transx/tscipher/cipher.go

258 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tscipher
import (
"bytes"
"fmt"
"github.com/TransX/cache"
"github.com/TransX/log"
"github.com/TransX/utils"
"io"
"net"
"strconv"
"strings"
)
// StartMark and EndMark bracket the 10-byte length field inside every
// package header built by WrapPackage and checked by UnwrapPackage.
// Go has no []byte constants, so these are variables; treat them as
// read-only.
var (
	StartMark = []byte("#2v!")
	EndMark   = []byte("_=1z")
)
// LackDataError reports that a buffer holds only part of a package, so more
// bytes must be read before it can be unwrapped.
type LackDataError struct {
	e string
}

// Error returns the stored message.
func (le *LackDataError) Error() string { return le.e }

// NotPackageError reports that a buffer does not start with a valid package
// header (the start/end marks did not match).
type NotPackageError struct {
	e string
}

// Error returns the stored message.
func (ne *NotPackageError) Error() string { return ne.e }
// Cipher is the pluggable transform applied to payloads. SendData encrypts
// a message before framing it, and ReceiveData decrypts a payload after
// unframing it.
type Cipher interface {
	// Encrypt returns the encrypted form of data.
	Encrypt(data []byte) (encrypted []byte, err error)
	// Decrypt returns the plaintext recovered from data.
	Decrypt(data []byte) (decrypted []byte, err error)
}
// Carrier couples one network connection with the cipher and the buffer
// queues used to move data across it.
type Carrier struct {
	// Conn is the connection that data is read from and written to.
	Conn net.Conn
	// Cipher encrypts outgoing payloads (SendData) and decrypts incoming
	// ones (ReceiveData). SendData writes raw, unframed bytes when it is nil.
	Cipher Cipher
	// Cache []byte
	// Cache is a queue of scratch buffers. RowReceiveData and ReceiveData
	// take read buffers from it. SendData puts a fresh 4 KiB buffer back
	// after each encrypted send.
	Cache *cache.UnblockingQueueCache
	// Msg is the message queue. RowReceiveData and ReceiveData put received
	// data into it, and SendData takes the next message to send from it.
	Msg *cache.BlockingQueueCache
	// AttachedTunnelID identifies the tunnel this carrier belongs to. Within
	// this file it is only used in log messages.
	AttachedTunnelID string
	// receiveBuff holds bytes read from Conn that did not yet form a complete
	// package. It is accessed via GetReceiveBuff/SetReceiveBuff.
	receiveBuff []byte
}
// NewCarrier builds a Carrier over conn with the given fields:
//   - cipher encrypts and decrypts payloads.
//   - queCache is the scratch-buffer queue.
//   - msg is the message queue.
//   - id is the tunnel ID.
//
// The receive buffer starts empty with a 4 KiB capacity.
func NewCarrier(conn net.Conn, cipher Cipher, queCache *cache.UnblockingQueueCache, msg *cache.BlockingQueueCache, id string) *Carrier {
	return &Carrier{
		Conn:             conn,
		Cipher:           cipher,
		Cache:            queCache,
		Msg:              msg,
		AttachedTunnelID: id,
		receiveBuff:      make([]byte, 0, 1024*4),
	}
}
// GetReceiveBuff returns a private copy of the pending receive buffer.
//
// The copy keeps both the length and the capacity of the stored slice. The
// original author insisted this is required: ReceiveData hands the result
// straight to UnwrapPackage. Do not replace this with a len-sized copy
// without checking how the result is sliced downstream.
func (c *Carrier) GetReceiveBuff() []byte {
	dup := make([]byte, 0, cap(c.receiveBuff))
	return append(dup, c.receiveBuff...)
}
// SetReceiveBuff replaces the pending receive buffer with buff. The slice
// is stored as-is, not copied, so callers should not modify it afterwards.
func (c *Carrier) SetReceiveBuff(buff []byte) {
	c.receiveBuff = buff
}
// NewCipher maps a cipher name to an implementation:
//   - "default" returns NewChaCha().
//   - "AES" returns NewAES().
//   - "XOR" returns NewXOR with a built-in key.
//
// Any other name yields a nil Cipher.
// NOTE(review): with a nil Cipher, SendData writes raw bytes, while
// ReceiveData calls Cipher.Decrypt unconditionally.
func NewCipher(cipherName string) (cipher Cipher) {
	switch cipherName {
	case "default":
		return NewChaCha()
	case "AES":
		return NewAES()
	case "XOR":
		return NewXOR([]byte("fasdfasdf!3297!jfsl12*&!HHHFds"))
	default:
		// TODO: temporary handling for unknown names.
		return nil
	}
}
// WrapPackage frames data so the receiver can find package boundaries in
// the byte stream. Without framing, a 100-byte write may show up as a
// 90-byte read followed by the rest.
//
// The result is an 18-byte header followed by data, unchanged. The header
// consists of:
//   - StartMark,
//   - the payload length encoded by utils.Int2binary(len(data), 10),
//   - EndMark,
//
// all XORed with a fixed key. UnwrapPackage reverses this.
func WrapPackage(data []byte) []byte {
	binSize := utils.Int2binary(len(data), 10)
	key := []byte("#2GD+.>dt`Qdp")

	// Build the header in a fresh buffer sized for header plus payload.
	// The previous append(append(StartMark, ...), ...) appended onto the
	// shared package-level StartMark slice. Whenever its backing array had
	// spare capacity, concurrent calls would write into and read back the
	// same memory.
	headerLen := len(StartMark) + len(binSize) + len(EndMark)
	out := make([]byte, 0, headerLen+len(data))
	out = append(out, StartMark...)
	out = append(out, binSize...)
	out = append(out, EndMark...)

	// Obfuscate only the header. SendData encrypts the payload before
	// calling this.
	for i := range out {
		out[i] ^= key[i%len(key)]
	}
	return append(out, data...)
}
func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) {
//前14个字节是header
cipheredHeader := pacakge[:18]
header := make([]byte, len(cipheredHeader))
key := []byte("#2GD+.>dt`Qdp")
key = key
for i, v := range cipheredHeader {
header[i] = v ^ key[i%len(key)]
}
// log.Error("receive header %d %x", len(header), header)
// log.Error("receive pacakge %d %x", len(pacakge[:180]), pacakge[:180])
start := header[:4]
end := header[14:]
binSize := header[4:14]
packageSize := 0
if bytes.Compare(start, StartMark) == 0 && bytes.Compare(end, EndMark) == 0 {
packageSize = utils.Binary2Int(binSize)
if len(pacakge[18:]) < packageSize {
packageSize = 0
data = nil
rest = nil
err = &LackDataError{"LackDataError"}
return
}
data = pacakge[18 : 18+packageSize]
rest = pacakge[18+len(data):]
err = nil
} else {
packageSize = 0
data = nil
rest = nil
if strings.Contains(string(pacakge), "#2v!") && strings.Contains(string(pacakge), "_=1z") {
a := strings.Index(string(pacakge), "#2v!")
b := strings.Index(string(pacakge), "_=1z")
err = &NotPackageError{"NotPackageError(contains)" + "start:" + string(start) + " end:" + string(end) + "pacakge " + strconv.Itoa(len(pacakge)) + "start" + strconv.Itoa(a) + "end" + strconv.Itoa(b)}
} else {
err = &NotPackageError{fmt.Sprintf("NotPackageError start: %s end: %s whole %x", string(start), string(end), header)}
}
}
return
}
// SendData takes the next message from carrier.Msg and writes it to
// carrier.Conn.
//
// With a nil Cipher the message bytes are written raw and unframed.
// Otherwise the message is:
//   - encrypted,
//   - framed with WrapPackage and written,
//   - followed by a fresh 4 KiB buffer being put into carrier.Cache.
//
// n is the number of bytes written to the connection. When the message is
// framed, n includes the header.
func SendData(carrier *Carrier) (n int, err error) {
	msg, nByte := carrier.Msg.Get()
	if len(msg) < nByte {
		log.Panic("Cache of send is too small")
	}
	if carrier.Cipher == nil {
		return carrier.Conn.Write(msg[:nByte])
	}
	encrypted, err := carrier.Cipher.Encrypt(msg[:nByte])
	if err != nil {
		return 0, err
	}
	// Frame the whole ciphertext. Cutting it to nByte truncates ciphers
	// whose output is longer than the plaintext (padding, IV/nonce, auth
	// tag). For shorter output it exposes bytes past the ciphertext's
	// length, or panics.
	n, err = carrier.Conn.Write(WrapPackage(encrypted))
	carrier.Cache.Put(make([]byte, 1024*4), 1024*4)
	return n, err
}
// RowReceiveData performs one Read from carrier.Conn, without unframing or
// decrypting. The read goes into a buffer taken from carrier.Cache, and
// that buffer is then put on carrier.Msg.
//
// If the read fails, the buffer is still queued, with a length of 0, and
// the read error is returned with n == 0.
func RowReceiveData(carrier *Carrier) (n int, err error) {
	buf, _ := carrier.Cache.Get()
	read, readErr := carrier.Conn.Read(buf)
	if readErr != nil {
		read = 0
	}
	carrier.Msg.Put(buf, read)
	return read, readErr
}
// ReceiveData reads from carrier.Conn until one complete package (see
// UnwrapPackage) is available. It then decrypts the payload with
// carrier.Cipher and puts the plaintext on carrier.Msg.
//
// Bytes that follow the package are saved back into the carrier's receive
// buffer for the next call. On success, n is the length of the decrypted
// payload.
//
// Errors:
//   - A header mismatch with at least 18 bytes buffered returns
//     (0, *NotPackageError).
//   - A read error returns (0, err).
//   - A decryption error returns (0, err).
//
// NOTE(review): a nil carrier.Cipher is not checked here, so Decrypt would
// be called on a nil interface.
func ReceiveData(carrier *Carrier) (n int, err error) {
	// defer func() {
	// 	if r := recover(); r != nil {
	// 		log.Error("ReceiveData err %s", r)
	// 	}
	// }()
	// log.Debug("id %s wrapedPackage := carrier.GetReceiveBuff()", carrier.AttachedTunnelID)

	// Start from a copy of the bytes left over by the previous call.
	wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache))
	var packageData []byte
	var _rest []byte
	// Scratch read buffer. It is taken up front even when a complete package
	// is already buffered.
	// NOTE(review): it is never put back into carrier.Cache in this function.
	cache, _ := carrier.Cache.Get()
	for {
		// First check whether the buffered data already forms a complete
		// package. If so, stop here so we do not block on another read.
		// The `err` declared here shadows the named result `err` for the
		// rest of the loop body.
		data, rest, err := UnwrapPackage(wrapedPackage)
		packageData = data
		_rest = rest
		// A header mismatch is only fatal once a full 18-byte header is
		// buffered. Shorter buffers cannot hold a header yet, so keep reading.
		// NOTE(review): on this return the receive buffer is left unchanged,
		// so a later call sees the same bad header again.
		if err, ok := err.(*NotPackageError); len(wrapedPackage) >= 18 && ok {
			log.Debug("return NotPackageError %s", carrier.AttachedTunnelID)
			return 0, err
		}
		// log.Debug("id %s length of package %d", carrier.AttachedTunnelID, len(packageData))
		if err == nil {
			// Enough data for a complete package. Store only the trailing
			// bytes in a new buffer with the same capacity as the stored
			// receive buffer, dropping the reference to the old one.
			capBuff := cap(carrier.GetReceiveBuff())
			_buff := make([]byte, 0, capBuff) // releases the old buffer
			_buff = append(_buff, _rest...)
			// log.Debug("id %s carrier.SetReceiveBuff(_buff)", carrier.AttachedTunnelID)
			carrier.SetReceiveBuff(_buff)
			break
		}
		// The data read so far is not enough for a complete package.
		// log.Debug("id %s to read wrapedPackage %d", carrier.AttachedTunnelID, len(wrapedPackage))
		if len(wrapedPackage) > 0 {
			// Something is already buffered: take whatever the next Read returns.
			n, err = carrier.Conn.Read(cache)
			if err != nil {
				log.Error("ERROR %s", err)
			}
			// log.Debug("id %s to Conn.Read %d", carrier.AttachedTunnelID, n)
		} else {
			// Nothing buffered: wait for at least one header's worth (18 bytes).
			n, err = io.ReadAtLeast(carrier.Conn, cache, 18)
			// log.Debug("id %s to ReadAtLeast", carrier.AttachedTunnelID)
		}
		// NOTE(review): any bytes returned together with an error are dropped,
		// and the data accumulated during this call is not saved to the
		// receive buffer.
		if err != nil {
			n = 0
			return n, err
		}
		wrapedPackage = append(wrapedPackage, cache[:n]...)
		// log.Debug("id %s length of conn %d", carrier.AttachedTunnelID, n)
		// log.Debug("id %s first 18 %s from %s", carrier.AttachedTunnelID, string(wrapedPackage[:18]), carrier.Conn.RemoteAddr().String())
	}
	// if len(carrier.GetReceiveBuff()) > 0 {
	// 	log.Debug("id %s trailing %d from %s", carrier.AttachedTunnelID, len(carrier.GetReceiveBuff()), carrier.Conn.RemoteAddr().String())
	// 	log.Debug("id %s 18 byte of trailing %s", carrier.AttachedTunnelID, string(carrier.GetReceiveBuff()[:18]))
	// }

	// Back at function scope: this `err` is the named result again, so the
	// bare return below reports the decryption error.
	decrypted, err := carrier.Cipher.Decrypt(packageData)
	if err != nil {
		n = 0
		return
	}
	n = len(decrypted)
	// copy(cache, decrypted)
	// carrier.Cache.Put(decrypted, n)
	// fmt.Printf("ReceiveData id %s get msg\n", carrier.AttachedTunnelID)
	carrier.Msg.Put(decrypted, n)
	// fmt.Printf("id %s put cache\n", carrier.AttachedTunnelID)
	return
}