transx/tscipher/cipher.go

222 lines
5.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tscipher
import (
"bytes"
"fmt"
"net"
"strconv"
"strings"
"github.com/TransX/cache"
"github.com/TransX/utils"
"github.com/spf13/viper"
)
// var StartMark = []byte("#2v!") //should be constant
// var EndMark = []byte("_=1z") //should be constant
// LackDataError is returned by UnwrapPackage when the buffer does not yet
// hold a complete package (header or payload is still incomplete).
type LackDataError struct {
	e string
}

// NotPackageError is returned by UnwrapPackage when the buffer does not start
// with a valid package header.
type NotPackageError struct {
	e string
}

// Compile-time checks that both types satisfy the error interface.
var (
	_ error = (*LackDataError)(nil)
	_ error = (*NotPackageError)(nil)
)

// Error returns the message the error was created with.
func (l *LackDataError) Error() string {
	return l.e
}

// Error returns the message the error was created with.
func (n *NotPackageError) Error() string {
	return n.e
}
// Cipher encrypts and decrypts the payloads carried by a Carrier.
// NewCipher selects an implementation by name.
type Cipher interface {
	// Decrypt returns the plaintext recovered from data.
	Decrypt(data []byte) (decrypted []byte, err error)
	// Encrypt returns the ciphertext produced from data.
	Encrypt(data []byte) (encrypted []byte, err error)
}
// Carrier couples a connection with the Cipher used on it, two
// BlockingQueueCache queues, the ID of the tunnel it is attached to, and a
// private buffer of received bytes.
//
// NOTE(review): judging by the commented-out SendData/ReceiveData helpers in
// this file, Cache supplies raw read buffers and Msg holds plaintext
// messages — confirm against the live callers.
type Carrier struct {
	Conn             net.Conn
	Cipher           Cipher
	Cache            *cache.BlockingQueueCache
	Msg              *cache.BlockingQueueCache
	AttachedTunnelID string
	receiveBuff      []byte // pending received bytes; accessed via GetReceiveBuff/SetReceiveBuff
}

// NewCarrier returns a Carrier wired to the given connection, cipher, queues
// and tunnel ID, with an empty receive buffer of 4 KiB capacity.
func NewCarrier(conn net.Conn, cipher Cipher, queCache *cache.BlockingQueueCache, msg *cache.BlockingQueueCache, id string) *Carrier {
	return &Carrier{
		Conn:             conn,
		Cipher:           cipher,
		Cache:            queCache,
		Msg:              msg,
		AttachedTunnelID: id,
		receiveBuff:      make([]byte, 0, 4*1024),
	}
}

// GetReceiveBuff returns a private copy of the receive buffer. The copy has
// the same length AND the same capacity as the stored buffer.
func (c *Carrier) GetReceiveBuff() []byte {
	src := c.receiveBuff
	// Capacity is preserved deliberately (the original note read "it must be
	// written exactly like this"). NOTE(review): presumably callers reslice
	// or append up to cap — confirm before replacing with slices.Clone.
	dup := make([]byte, len(src), cap(src))
	copy(dup, src)
	return dup
}

// SetReceiveBuff replaces the receive buffer with buff. The slice is stored
// as-is (no copy), so the caller must not modify it afterwards.
func (c *Carrier) SetReceiveBuff(buff []byte) {
	c.receiveBuff = buff
}
// NewCipher builds the Cipher named by cipherName:
//   - "default" -> NewChaCha()
//   - "AES"     -> NewAES()
//   - "XOR"     -> NewXOR keyed with the auth.XORKey config value
//
// Any other name yields nil, so callers must check the result.
func NewCipher(cipherName string) (cipher Cipher) {
	switch cipherName {
	case "default":
		return NewChaCha()
	case "AES":
		return NewAES()
	case "XOR":
		xorKey := viper.GetString("auth.XORKey")
		return NewXOR([]byte(xorKey))
	default:
		// TODO: unknown names are not reported yet; this nil return is a
		// temporary stopgap carried over from the original code.
		return nil
	}
}
// WrapPackage frames data so that the receiver can detect where it ends, even
// when one read returns only part of what was sent (e.g. 90 of 100 bytes).
//
// Frame layout:
//
//	[auth.startMark | utils.Int2binary(len(data), 10) | auth.endMark] XOR auth.headerKey
//	followed by data, unchanged.
//
// UnwrapPackage expects the size field to be 10 bytes wide.
// NOTE(review): the largest encodable size depends on utils.Int2binary —
// confirm it covers the payload sizes in use.
//
// An empty auth.headerKey leaves the header un-XORed instead of panicking on
// a modulo by zero.
func WrapPackage(data []byte) []byte {
	binSize := utils.Int2binary(len(data), 10)
	startMark := viper.GetString("auth.startMark")
	endMark := viper.GetString("auth.endMark")
	key := []byte(viper.GetString("auth.headerKey"))

	headerLen := len(startMark) + len(binSize) + len(endMark)
	// Allocate the whole frame once: header first, payload appended after.
	frame := make([]byte, 0, headerLen+len(data))
	frame = append(frame, startMark...)
	frame = append(frame, binSize...)
	frame = append(frame, endMark...)

	// XOR only the header (frame holds nothing else yet) with the repeating key.
	if len(key) > 0 {
		for i := range frame {
			frame[i] ^= key[i%len(key)]
		}
	}
	return append(frame, data...)
}
// UnwrapPackage extracts the first frame built by WrapPackage from the front
// of pacakge.
//
// The header is len(auth.startMark) + 10 + len(auth.endMark) bytes (18 with
// the 4-byte marks this protocol uses). It is XOR-ed with auth.headerKey; an
// empty key means "not XOR-ed", matching WrapPackage.
//
// On success it returns the payload (data) and the bytes that follow it
// (rest). Both are sub-slices of pacakge, not copies. data's capacity is
// clipped to its length so appending to it cannot overwrite rest.
//
// Errors:
//   - *LackDataError: pacakge is shorter than the header, or the header is
//     valid but the announced payload has not fully arrived (read more, retry).
//   - *NotPackageError: the decoded marks do not match, or the decoded size is
//     negative, i.e. pacakge does not start with a frame.
func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) {
	const sizeFieldLen = 10 // matches utils.Int2binary(..., 10) in WrapPackage

	startMarkStr := viper.GetString("auth.startMark")
	endMarkStr := viper.GetString("auth.endMark")
	startMark := []byte(startMarkStr)
	endMark := []byte(endMarkStr)

	// Derive offsets from the configured marks so they agree with WrapPackage.
	sizeOff := len(startMark)
	endOff := sizeOff + sizeFieldLen
	headerLen := endOff + len(endMark)

	// Guard before slicing: pacakge[:headerLen] would panic when
	// cap(pacakge) < headerLen, or silently decode stale bytes past len.
	if len(pacakge) < headerLen {
		return nil, nil, &LackDataError{"LackDataError"}
	}

	// Undo the header XOR on a copy so the caller's buffer is left untouched.
	header := make([]byte, headerLen)
	copy(header, pacakge[:headerLen])
	if key := []byte(viper.GetString("auth.headerKey")); len(key) > 0 {
		for i := range header {
			header[i] ^= key[i%len(key)]
		}
	}

	start := header[:sizeOff]
	binSize := header[sizeOff:endOff]
	end := header[endOff:]

	if !bytes.Equal(start, startMark) || !bytes.Equal(end, endMark) {
		// Diagnostic only: if both marks occur somewhere in the buffer, report
		// their positions (the stream is likely misaligned, not garbage).
		s := string(pacakge)
		a := strings.Index(s, startMarkStr)
		b := strings.Index(s, endMarkStr)
		if a >= 0 && b >= 0 {
			return nil, nil, &NotPackageError{"NotPackageError(contains)" + "start:" + string(start) + " end:" + string(end) + "pacakge " + strconv.Itoa(len(pacakge)) + "start" + strconv.Itoa(a) + "end" + strconv.Itoa(b)}
		}
		return nil, nil, &NotPackageError{fmt.Sprintf("NotPackageError start: %s end: %s whole %x", string(start), string(end), header)}
	}

	packageSize := utils.Binary2Int(binSize)
	if packageSize < 0 {
		// NOTE(review): whether Binary2Int can return a negative value is not
		// visible here; guard so the slicing below can never panic.
		return nil, nil, &NotPackageError{fmt.Sprintf("NotPackageError invalid size %d", packageSize)}
	}

	payload := pacakge[headerLen:]
	if len(payload) < packageSize {
		return nil, nil, &LackDataError{"LackDataError"}
	}
	return payload[:packageSize:packageSize], payload[packageSize:], nil
}
// func SendData(carrier *Carrier) (n int, err error) {
// msg, nByte := carrier.Msg.Get()
// if len(msg) < nByte {
// log.Panic("Cache of send is too small")
// }
// if carrier.Cipher == nil {
// n, err = carrier.Conn.Write(msg[:nByte])
// carrier.Cache.Put(make([]byte, 1024*4), 1024*4)
// return
// }
// encrypedByte, err := carrier.Cipher.Encrypt(msg[:nByte])
// if err != nil {
// n = 0
// return
// }
// //打包
// wraped := WrapPackage(encrypedByte[:nByte])
// n, err = carrier.Conn.Write(wraped)
// carrier.Cache.Put(make([]byte, 1024*4), 1024*4)
// return
// }
//
// func RowReceiveData(carrier *Carrier) (n int, err error) {
// cache, _ := carrier.Cache.Get()
// n, err = carrier.Conn.Read(cache)
// if err != nil {
// n = 0
// }
// carrier.Msg.Put(cache, n)
// return
// }
//
// func ReceiveData(carrier *Carrier) (n int, err error) {
// wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache))
// var packageData []byte
// var _rest []byte
// cache, _ := carrier.Cache.Get()
// for {
// //首先检查这个是不是完整的包,是就返回好了,免得被阻塞
// data, rest, err := UnwrapPackage(wrapedPackage)
// packageData = data
// _rest = rest
// if err, ok := err.(*NotPackageError); len(wrapedPackage) >= 18 && ok {
// log.Debug("return NotPackageError %s", carrier.AttachedTunnelID)
// return 0, err
// }
// if err == nil {
// //够一个完整的包
// capBuff := cap(carrier.GetReceiveBuff())
// _buff := make([]byte, 0, capBuff) //释放
// _buff = append(_buff, _rest...)
// carrier.SetReceiveBuff(_buff)
// break
// }
// //如果读到的数据不够一个完整的包
// if len(wrapedPackage) > 0 {
// n, err = carrier.Conn.Read(cache)
// if err != nil {
// log.Error("ERROR %s", err)
// }
// } else {
// n, err = io.ReadAtLeast(carrier.Conn, cache, 18)
// }
//
// if err != nil {
// n = 0
// return n, err
// }
// wrapedPackage = append(wrapedPackage, cache[:n]...)
// }
// decrypted, err := carrier.Cipher.Decrypt(packageData)
// if err != nil {
// n = 0
// return
// }
// n = len(decrypted)
// carrier.Msg.Put(decrypted, n)
// return
// }