package tscipher

import (
	"bytes"
	"errors"
	"io"
	"net"

	"github.com/TransX/log"
	"github.com/TransX/utils"
)

// StartMark is the magic prefix of every package header (before masking).
// It is a variable only because Go has no constant byte slices; treat it as
// read-only, since WrapPackage and UnwrapPackage depend on its exact value.
var StartMark = []byte("#2v!")

// EndMark is the magic suffix of every package header (before masking).
// Like StartMark, it must be treated as read-only.
var EndMark = []byte("_=1z")

const (
	// packageHeaderKey is XORed (repeating) over the header bytes on the wire.
	// NOTE(review): hard-coded key; this only obscures the framing, it does
	// not authenticate or protect it.
	packageHeaderKey = "hahahehe~-1!"
	// packageMarkLen is len(StartMark), which equals len(EndMark).
	packageMarkLen = 4
	// packageSizeFieldLen is the width passed to utils.Int2binary for the
	// encoded payload size.
	packageSizeFieldLen = 10
	// packageHeaderLen is the full header size: StartMark + size field + EndMark.
	packageHeaderLen = 2*packageMarkLen + packageSizeFieldLen
)

// errNotPackage reports that a buffer does not begin with a valid header.
var errNotPackage = errors.New("not a package")

// Cipher encrypts and decrypts whole payloads. This file does not assume that
// Encrypt preserves the payload length: the ciphertext length is carried
// explicitly in the package header.
type Cipher interface {
	Decrypt(data []byte) (decrypted []byte, err error)
	Encrypt(data []byte) (encrypted []byte, err error)
}

// Carrier bundles a connection with the cipher used on it and the buffer that
// SendData reads from and ReceiveData writes into.
type Carrier struct {
	Conn             net.Conn
	Cipher           Cipher // nil disables encryption and framing
	Cache            []byte // payload buffer for SendData / ReceiveData
	AttachedTunnelID string // used only in log messages in this file
}

// NewCipher returns the Cipher registered under cipherName ("default" selects
// ChaCha, plus "AES" and "XOR"), or nil for any other name.
func NewCipher(cipherName string) (cipher Cipher) {
	switch cipherName {
	case "default":
		return NewChaCha()
	case "AES":
		return NewAES()
	case "XOR":
		// NOTE(review): hard-coded key.
		return NewXOR([]byte("fasdfasdf!3297!jfsl12*&!HHHFds"))
	}
	return nil // TODO: temporary handling for unknown cipher names
}

// maskPackageHeader writes src XOR packageHeaderKey (key repeated) into dst.
// XOR is its own inverse, so the same routine masks and unmasks a header;
// dst and src may be the same slice.
func maskPackageHeader(dst, src []byte) {
	for i, b := range src {
		dst[i] = b ^ packageHeaderKey[i%len(packageHeaderKey)]
	}
}

// WrapPackage frames data for transmission: a masked header (StartMark, the
// encoded length of data, EndMark) followed by data itself. The explicit
// length lets the receiver reassemble a package that arrives split across
// several reads (e.g. 100 bytes sent but only 90 received by one Read).
func WrapPackage(data []byte) []byte {
	binSize := utils.Int2binary(len(data), packageSizeFieldLen)
	// Build into a fresh buffer: appending to the package-level StartMark could
	// write into its spare capacity, which is shared by every caller.
	pkg := make([]byte, 0, len(StartMark)+len(binSize)+len(EndMark)+len(data))
	pkg = append(pkg, StartMark...)
	pkg = append(pkg, binSize...)
	pkg = append(pkg, EndMark...)
	maskPackageHeader(pkg, pkg)
	return append(pkg, data...)
}

// UnwrapPackage parses the masked header at the start of raw. On success it
// returns the payload size recorded in the header and the bytes of raw that
// follow the header (which may be fewer or more than packageSize). If raw is
// shorter than a header or the marks do not match, it returns packageSize 0,
// data == raw and a "not a package" error.
func UnwrapPackage(raw []byte) (packageSize int, data []byte, err error) {
	// The header is 18 bytes: 4-byte StartMark, 10-byte size, 4-byte EndMark.
	if len(raw) < packageHeaderLen {
		return 0, raw, errNotPackage
	}
	header := make([]byte, packageHeaderLen)
	maskPackageHeader(header, raw[:packageHeaderLen])

	sizeEnd := packageMarkLen + packageSizeFieldLen
	if !bytes.Equal(header[:packageMarkLen], StartMark) || !bytes.Equal(header[sizeEnd:], EndMark) {
		return 0, raw, errNotPackage
	}
	return utils.Binary2Int(header[packageMarkLen:sizeEnd]), raw[packageHeaderLen:], nil
}

// SendData sends the first nByte bytes of carrier.Cache. Without a cipher the
// bytes are written as-is; otherwise they are encrypted, framed with
// WrapPackage and written, and n is the number of bytes written to the
// connection (header included). It panics (via log.Panic) if Cache is shorter
// than nByte.
func SendData(carrier *Carrier, nByte int) (n int, err error) {
	if len(carrier.Cache) < nByte {
		log.Panic("Cache of send is too small")
	}
	plain := carrier.Cache[:nByte]
	if carrier.Cipher == nil {
		return carrier.Conn.Write(plain)
	}

	encrypted, err := carrier.Cipher.Encrypt(plain)
	if err != nil {
		return 0, err
	}
	// Frame the whole ciphertext: truncating it to nByte would lose data
	// whenever the cipher does not preserve length.
	n, err = carrier.Conn.Write(WrapPackage(encrypted))
	copy(carrier.Cache[:nByte], encrypted) // in case of debugging
	return n, err
}

// SendData2 writes the first nByte bytes of carrier.Cache to the connection
// without encryption or framing.
func SendData2(carrier *Carrier, nByte int) (n int, err error) {
	return carrier.Conn.Write(carrier.Cache[:nByte])
}

// ReceiveData reads one package from carrier.Conn into carrier.Cache.
// Without a cipher it behaves like ReceiveData2. With a cipher it reads one
// complete framed package, decrypts it, copies the plaintext to the start of
// carrier.Cache and returns its length. On any error n is 0.
func ReceiveData(carrier *Carrier) (n int, err error) {
	if carrier.Cipher == nil {
		return ReceiveData2(carrier)
	}

	// A single Read on a stream connection may return part of a package or
	// parts of several, so read exactly one header, then exactly its payload.
	var rawHeader [packageHeaderLen]byte
	if _, err = io.ReadFull(carrier.Conn, rawHeader[:]); err != nil {
		return 0, err
	}
	packageSize, _, err := UnwrapPackage(rawHeader[:])
	if err != nil {
		return 0, err
	}
	if packageSize < 0 {
		return 0, errors.New("invalid package size")
	}

	// NOTE(review): packageSize comes from the peer and is not capped; a
	// hostile peer can force a large allocation here. Consider a maximum.
	ciphertext := carrier.Cache
	if packageSize > len(ciphertext) {
		ciphertext = make([]byte, packageSize)
	}
	ciphertext = ciphertext[:packageSize]
	if _, err = io.ReadFull(carrier.Conn, ciphertext); err != nil {
		return 0, err
	}
	log.Debug("read a complete package")

	decrypted, err := carrier.Cipher.Decrypt(ciphertext)
	if err != nil {
		return 0, err
	}
	if len(decrypted) > len(carrier.Cache) {
		return 0, errors.New("cache of receive is too small")
	}
	// copy handles overlap if Decrypt worked in place on carrier.Cache.
	return copy(carrier.Cache, decrypted), nil
}

// ReceiveData2 performs a single raw Read into carrier.Cache, with no
// decryption or framing. On error n is reported as 0.
func ReceiveData2(carrier *Carrier) (n int, err error) {
	n, err = carrier.Conn.Read(carrier.Cache)
	if err != nil {
		return 0, err
	}
	return n, nil
}