diff --git a/main_test.go b/main_test.go index c2fb721..4698ff3 100644 --- a/main_test.go +++ b/main_test.go @@ -27,7 +27,7 @@ func serverBin(t *testing.T) { bytes := make([]byte, 1024*32) n, err := conn.Read(bytes) if err != nil { - log.Error("Test Server read %", err) + log.Error("Test Server read %", err.Error()) break } log.Info("Test Server read %d bytes from %s", n, conn.RemoteAddr().String()) @@ -46,7 +46,6 @@ func serverBin(t *testing.T) { log.Error("Test Server read bin %s", err.Error()) break } - log.Info("Test Server read file %d, it should be %d", nBinByte, n-1) for { if nBinByte != n-1 { log.Info("Test Server read file %d, it should be %d", nBinByte, n-1) @@ -60,7 +59,7 @@ func serverBin(t *testing.T) { } break } - + log.Info("Test Server read file %d, it should be %d", nBinByte, n-1) for i := 0; i < nBinByte; i++ { if binBytes[i] != bytes[i+1] { log.Error("Test Server read not consistent at %d. read:%c receive:%c", i, binBytes[i], bytes[i+1]) diff --git a/tcp.go b/tcp.go index a272518..5788d23 100644 --- a/tcp.go +++ b/tcp.go @@ -90,8 +90,8 @@ func (this *TransTCP) tunnel(src, dest net.Conn, id string, encrypDirection stri sendCarrier.Cipher = nil log.Debug("Write not crypted. Tunnel: %s", id) } - _, err = tscipher.SendData(sendCarrier, nByte) - log.Info("Write %d bytes from %s to %s. Tunnel: %s", nByte, dest.LocalAddr(), dest.RemoteAddr().String(), id) + n, err := tscipher.SendData(sendCarrier, nByte) + log.Info("Write %d bytes from %s to %s. Tunnel: %s", n, dest.LocalAddr(), dest.RemoteAddr().String(), id) log.Debug("Write %s %s", id, cache[:nByte]) if err != nil { log.Panic("Write panic. ID: %s, Err: %s, Remote Add: %s", id, err, dest.RemoteAddr().String()) diff --git a/tscipher/cipher.go b/tscipher/cipher.go index 8f6dea7..cff0de8 100644 --- a/tscipher/cipher.go +++ b/tscipher/cipher.go @@ -1,9 +1,16 @@ package tscipher import ( + "bytes" + "errors" + "github.com/TransX/log" + "github.com/TransX/utils" "net" ) +var StartMark = []byte("#2v!") //should be constant +var EndMark = []byte("_=1z") //should be constant + type Cipher interface { Decrypt(data []byte) (decrypted []byte, err error) Encrypt(data []byte) (encryped []byte, err error) @@ -28,6 +35,43 @@ func NewCipher(cipherName string) (cipher Cipher) { return nil //TODO:临时这样处理 } +func WrapPackage(data []byte) []byte { //把要加密传输的数据打包成一定的格式,避免发送了100自己,只收到90字节的问题。 + sizeOfData := len(data) + binSize := utils.Int2binary(sizeOfData, 10) + header := append(append(StartMark, binSize...), EndMark...) + //加密 + key := []byte("hahahehe~-1!") + cipheredHeader := make([]byte, len(header)) + for i, v := range header { + cipheredHeader[i] = v ^ key[i%len(key)] + } + return append(cipheredHeader, data...) +} + +func UnwrapPackage(pacakge []byte) (packageSize int, data []byte, err error) { + //前14个字节是header + cipheredHeader := pacakge[:18] + header := make([]byte, len(cipheredHeader)) + key := []byte("hahahehe~-1!") + for i, v := range cipheredHeader { + header[i] = v ^ key[i%len(key)] + } + start := header[:4] + end := header[14:] + binSize := header[4:14] + if bytes.Compare(start, StartMark) == 0 && bytes.Compare(end, EndMark) == 0 { + packageSize = utils.Binary2Int(binSize) + data = pacakge[18:] + err = nil + } else { + packageSize = 0 + data = pacakge + err = errors.New("not a package") + } + return + +} + func SendData(carrier *Carrier, nByte int) (n int, err error) { if carrier.Cipher == nil { n, err = carrier.Conn.Write(carrier.Cache[:nByte]) @@ -38,7 +82,9 @@ func SendData(carrier *Carrier, nByte int) (n int, err error) { n = 0 return } - n, err = carrier.Conn.Write(encrypedByte[:nByte]) + //打包 + wraped := WrapPackage(encrypedByte[:nByte]) + n, err = carrier.Conn.Write(wraped) copy(carrier.Cache, encrypedByte[:nByte]) // in case of debugging return } @@ -60,12 +106,46 @@ func ReceiveData(carrier *Carrier) (n int, err error) { if carrier.Cipher == nil { return } - decrypted, err := carrier.Cipher.Decrypt(carrier.Cache[:n]) - copy(carrier.Cache, decrypted[:n]) + //解包 + wrapedPackage := carrier.Cache[:n] + packageSize, data, err := UnwrapPackage(wrapedPackage) + realData := make([]byte, 0, packageSize) + // log.Info("packageSize %d data size %d", packageSize, len(data)) + if err == nil && packageSize == len(data) { //读到的是一个完整的包 + realData = data + log.Info("read a complete package") + } else { + gotSize := len(data) + for { + n, err = carrier.Conn.Read(carrier.Cache) + if err != nil { + n = 0 + return + } + wrapedPackage = carrier.Cache[:n] + log.Info("got partial package size %d", n) + _, data, err = UnwrapPackage(wrapedPackage) + if err == nil { + n = 0 + err = errors.New("partial package lost") + return + } + gotSize += len(data) + realData = append(realData, data...) + if gotSize == packageSize { + log.Info("got enough:. packageSize %d, real size %d", packageSize, gotSize) + break + } + } + + } + decrypted, err := carrier.Cipher.Decrypt(realData) if err != nil { n = 0 return } + n = len(decrypted) + copy(carrier.Cache, decrypted) return } diff --git a/utils/int2binary.go b/utils/int2binary.go new file mode 100644 index 0000000..d810f90 --- /dev/null +++ b/utils/int2binary.go @@ -0,0 +1,34 @@ +package utils + +import ( + "encoding/binary" + "unsafe" +) + +func Int2binary(num int, b int) []byte { + var a int64 //Go的破毛病 + buf := make([]byte, unsafe.Sizeof(a)) + if len(buf) > b { + panic("int2binary: buff must be greater than size of int64") + } + binary.PutVarint(buf, int64(num)) + ret := make([]byte, b) + for i := 0; i < len(ret); i++ { + ret[i] = 0 + } + copy(ret, buf) + return ret +} + +func Binary2Int(bin []byte) int { + var a int64 //Go的破毛病 + b := unsafe.Sizeof(a) //只取前面几位 + num, n := binary.Varint(bin[:b]) + if n == 0 { + panic("Binary2Int:buf too small") + } + if n < 0 { + panic("value larger than 64 bits (overflow) sand -n is the number of bytes read") + } + return int(num) +}