transx/tscipher/cipher.go

237 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tscipher
import (
"bytes"
"fmt"
"github.com/TransX/log"
"github.com/TransX/utils"
"io"
"net"
"strconv"
"strings"
)
// StartMark and EndMark bracket the length field of every package header
// built by WrapPackage and checked by UnwrapPackage. They are variables only
// because Go has no []byte constants; treat them as read-only.
var (
	StartMark = []byte("#2v!")
	EndMark   = []byte("_=1z")
)
// LackDataError reports that a buffer does not yet contain a complete
// package, so more bytes must be read before it can be unwrapped.
type LackDataError struct {
	e string // message returned by Error
}

// NotPackageError reports that a buffer does not start with a valid package
// header (the start/end markers do not match).
type NotPackageError struct {
	e string // message returned by Error
}

// Error implements the error interface.
func (le *LackDataError) Error() string { return le.e }

// Error implements the error interface.
func (ne *NotPackageError) Error() string { return ne.e }
// Cipher transforms package payloads: SendData calls Encrypt before framing
// and ReceiveData calls Decrypt after unframing.
type Cipher interface {
	// Decrypt returns the plaintext corresponding to data.
	Decrypt(data []byte) (decrypted []byte, err error)
	// Encrypt returns the ciphertext corresponding to data.
	Encrypt(data []byte) (encrypted []byte, err error)
}
// Carrier bundles a connection with the cipher and scratch buffer used by
// SendData / ReceiveData to exchange framed packages over it.
type Carrier struct {
	// Conn is the underlying connection packages are written to and read from.
	Conn net.Conn
	// Cipher encrypts/decrypts payloads; when nil, SendData writes plaintext.
	Cipher Cipher
	// Cache is the scratch buffer: outgoing plaintext for SendData, read
	// buffer and decrypted output for ReceiveData.
	Cache []byte
	// AttachedTunnelID identifies the tunnel in this file's log messages.
	AttachedTunnelID string
	// receiveBuff holds bytes read beyond the last returned package; use
	// GetReceiveBuff / SetReceiveBuff to access it.
	receiveBuff []byte
}
// NewCarrier wraps conn in a Carrier that uses cipher and the scratch buffer
// cache. The pending receive buffer starts empty, with capacity for eight
// times len(cache) bytes.
func NewCarrier(conn net.Conn, cipher Cipher, cache []byte, id string) *Carrier {
	return &Carrier{
		Conn:             conn,
		Cipher:           cipher,
		Cache:            cache,
		AttachedTunnelID: id,
		receiveBuff:      make([]byte, 0, len(cache)*8),
	}
}
// GetReceiveBuff returns a copy of the bytes buffered from the connection but
// not yet returned by ReceiveData. The copy has the same length and capacity
// as the internal buffer, so the caller may append to or modify it without
// affecting the carrier.
func (c *Carrier) GetReceiveBuff() []byte {
	// AttachedTunnelID is a string, so it takes %s like the other log calls in
	// this file; %d would print "%!d(string=...)".
	log.Debug("id %s receivebuff Get, len %d", c.AttachedTunnelID, len(c.receiveBuff))
	snapshot := make([]byte, len(c.receiveBuff), cap(c.receiveBuff))
	copy(snapshot, c.receiveBuff)
	return snapshot
}
// SetReceiveBuff replaces the carrier's pending receive buffer with buff.
// The slice is stored as-is (not copied), so later writes through buff are
// visible to the carrier.
func (c *Carrier) SetReceiveBuff(buff []byte) {
	c.receiveBuff = buff
	// AttachedTunnelID is a string: %s, not %d.
	log.Debug("id %s receivebuff set, len %d", c.AttachedTunnelID, len(c.receiveBuff))
}
// NewCipher returns the Cipher selected by cipherName:
//
//	"default" -> NewChaCha()
//	"AES"     -> NewAES()
//	"XOR"     -> NewXOR with a fixed built-in key
//
// Any other name yields nil.
//
// NOTE(review): SendData writes plaintext when Carrier.Cipher is nil, so if
// this result is used as a Carrier's cipher, a misspelled name silently
// disables encryption on the send path.
func NewCipher(cipherName string) (cipher Cipher) {
	switch cipherName {
	case "default":
		return NewChaCha()
	case "AES":
		return NewAES()
	case "XOR":
		return NewXOR([]byte("fasdfasdf!3297!jfsl12*&!HHHFds"))
	}
	// TODO: temporary handling for unknown cipher names.
	return nil
}
// WrapPackage frames data for transmission so that the receiver can find
// package boundaries even when one write (say 100 bytes) arrives as several
// shorter reads (say 90 bytes first). The frame is a header followed by data:
//
//	StartMark | utils.Int2binary(len(data), 10) | EndMark | data
//
// Only the header is XOR-obfuscated with a fixed key; data is appended as-is.
// NOTE(review): UnwrapPackage assumes an 18-byte header, i.e. that
// Int2binary(_, 10) always yields 10 bytes — confirm in utils.
func WrapPackage(data []byte) []byte {
	binSize := utils.Int2binary(len(data), 10)

	// Build the frame in a fresh slice. Appending straight onto the shared
	// package-level StartMark would write into its backing array whenever it
	// had spare capacity, corrupting it for every other caller.
	frame := make([]byte, 0, len(StartMark)+len(binSize)+len(EndMark)+len(data))
	frame = append(frame, StartMark...)
	frame = append(frame, binSize...)
	frame = append(frame, EndMark...)

	// Obfuscate the header bytes in place (frame holds only the header here).
	key := []byte("#2GD+.>dt`Qdp")
	for i := range frame {
		frame[i] ^= key[i%len(key)]
	}
	return append(frame, data...)
}
func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) {
//前14个字节是header
cipheredHeader := pacakge[:18]
header := make([]byte, len(cipheredHeader))
key := []byte("#2GD+.>dt`Qdp")
key = key
for i, v := range cipheredHeader {
header[i] = v ^ key[i%len(key)]
}
// log.Error("receive header %d %x", len(header), header)
// log.Error("receive pacakge %d %x", len(pacakge[:180]), pacakge[:180])
start := header[:4]
end := header[14:]
binSize := header[4:14]
packageSize := 0
if bytes.Compare(start, StartMark) == 0 && bytes.Compare(end, EndMark) == 0 {
packageSize = utils.Binary2Int(binSize)
if len(pacakge[18:]) < packageSize {
packageSize = 0
data = nil
rest = nil
err = &LackDataError{"LackDataError"}
return
}
data = pacakge[18 : 18+packageSize]
rest = pacakge[18+len(data):]
err = nil
} else {
packageSize = 0
data = nil
rest = nil
if strings.Contains(string(pacakge), "#2v!") && strings.Contains(string(pacakge), "_=1z") {
a := strings.Index(string(pacakge), "#2v!")
b := strings.Index(string(pacakge), "_=1z")
err = &NotPackageError{"NotPackageError(contains)" + "start:" + string(start) + " end:" + string(end) + "pacakge " + strconv.Itoa(len(pacakge)) + "start" + strconv.Itoa(a) + "end" + strconv.Itoa(b)}
} else {
err = &NotPackageError{fmt.Sprintf("NotPackageError start: %s end: %s whole %x", string(start), string(end), header)}
}
}
return
}
// SendData writes the first nByte bytes of carrier.Cache to carrier.Conn.
//
// With a nil Cipher the bytes are written unframed and unencrypted. Otherwise
// they are encrypted, framed with WrapPackage and written; afterwards the
// first nByte ciphertext bytes are copied back over carrier.Cache. The
// returned n is whatever Conn.Write reports, so on the encrypted path it
// counts the frame header too. Calls log.Panic if nByte > len(carrier.Cache).
//
// NOTE(review): the ciphertext is truncated to nByte before framing, which is
// only correct if Encrypt never lengthens its input — confirm for every
// Cipher implementation.
func SendData(carrier *Carrier, nByte int) (n int, err error) {
	plain := carrier.Cache
	if nByte > len(plain) {
		log.Panic("Cache of send is too small")
	}
	if carrier.Cipher == nil {
		return carrier.Conn.Write(plain[:nByte])
	}

	ciphertext, encErr := carrier.Cipher.Encrypt(plain[:nByte])
	if encErr != nil {
		return 0, encErr
	}
	payload := ciphertext[:nByte]

	// Frame the payload so the receiver can locate package boundaries.
	framed := WrapPackage(payload)
	n, err = carrier.Conn.Write(framed)
	log.Info("Ready to write id %s, 18 byte %s", carrier.AttachedTunnelID, string(framed[:18]))
	copy(carrier.Cache, payload) // in case of debugging
	return n, err
}
// RowReceiveData does a single raw Read from carrier.Conn into carrier.Cache,
// with no unframing or decryption. On error n is reported as 0.
//
// NOTE(review): io.Reader allows n > 0 together with a non-nil error; those
// bytes are left in carrier.Cache, but the returned count discards them.
func RowReceiveData(carrier *Carrier) (n int, err error) {
	got, readErr := carrier.Conn.Read(carrier.Cache)
	if readErr != nil {
		return 0, readErr
	}
	return got, nil
}
// ReceiveData returns the next complete package from carrier.Conn: it
// unframes it, decrypts it with carrier.Cipher and copies the plaintext to
// the start of carrier.Cache. n is the number of plaintext bytes placed in
// carrier.Cache; on any error n is 0.
//
// Bytes received beyond the current package are kept in the carrier's
// receive buffer for the next call. If that buffer already holds a complete
// package, it is returned without reading from the connection.
//
// Errors:
//   - *NotPackageError when at least 18 buffered bytes do not start with a
//     valid header (the buffer is left unchanged);
//   - the error from reading carrier.Conn;
//   - the error from Decrypt;
//   - an error when the plaintext is larger than carrier.Cache.
//
// NOTE(review): carrier.Cipher must be non-nil here, whereas SendData falls
// back to plaintext for a nil Cipher — confirm callers route that case
// through RowReceiveData.
func ReceiveData(carrier *Carrier) (n int, err error) {
	// preview bounds a buffer to at most 18 bytes for logging; slicing [:18]
	// on a shorter buffer would read past len (or panic if cap < 18).
	preview := func(b []byte) []byte {
		if len(b) > 18 {
			return b[:18]
		}
		return b
	}

	log.Info("id %s wrapedPackage := carrier.GetReceiveBuff()", carrier.AttachedTunnelID)
	// Work on a copy so the carrier's buffer is untouched if we return early.
	wrapedPackage := carrier.GetReceiveBuff()
	var packageData []byte
	for {
		// Check for a complete package first so an already-buffered one is
		// returned without blocking on the connection.
		data, rest, unwrapErr := UnwrapPackage(wrapedPackage)
		packageData = data
		// A bad header is only conclusive once a whole header has arrived.
		if npe, ok := unwrapErr.(*NotPackageError); ok && len(wrapedPackage) >= 18 {
			log.Debug("return NotPackageError %s", carrier.AttachedTunnelID)
			return 0, npe
		}
		log.Debug("id %s length of package %d", carrier.AttachedTunnelID, len(packageData))
		if unwrapErr == nil {
			// Complete package: keep only the trailing bytes, in a fresh
			// buffer so the consumed bytes can be released. Read the capacity
			// directly instead of via GetReceiveBuff, which allocates a copy.
			remaining := make([]byte, 0, cap(carrier.receiveBuff))
			remaining = append(remaining, rest...)
			log.Info("id %s carrier.SetReceiveBuff(_buff)", carrier.AttachedTunnelID)
			carrier.SetReceiveBuff(remaining)
			break
		}

		// Not enough data for a complete package yet: read more.
		log.Debug("id %s to read wrapedPackage %d", carrier.AttachedTunnelID, len(wrapedPackage))
		var readN int
		var readErr error
		if len(wrapedPackage) > 0 {
			readN, readErr = carrier.Conn.Read(carrier.Cache)
			if readErr != nil {
				log.Error("ERROR %s", readErr)
			}
			log.Debug("id %s to Conn.Read %d", carrier.AttachedTunnelID, readN)
		} else {
			// Nothing buffered yet: wait for at least a full header.
			readN, readErr = io.ReadAtLeast(carrier.Conn, carrier.Cache, 18)
			log.Debug("id %s to ReadAtLeast", carrier.AttachedTunnelID)
		}
		if readErr != nil {
			return 0, readErr
		}
		wrapedPackage = append(wrapedPackage, carrier.Cache[:readN]...)
		log.Debug("id %s length of conn %d", carrier.AttachedTunnelID, readN)
		log.Debug("id %s first 18 %s from %s", carrier.AttachedTunnelID, string(preview(wrapedPackage)), carrier.Conn.RemoteAddr().String())
	}

	// Only inspected for logging, so read the field directly rather than
	// allocating a full-capacity copy per use through GetReceiveBuff.
	if trailing := carrier.receiveBuff; len(trailing) > 0 {
		log.Debug("id %s trailing %d from %s", carrier.AttachedTunnelID, len(trailing), carrier.Conn.RemoteAddr().String())
		log.Debug("id %s 18 byte of trailing %s", carrier.AttachedTunnelID, string(preview(trailing)))
	}

	decrypted, decErr := carrier.Cipher.Decrypt(packageData)
	if decErr != nil {
		return 0, decErr
	}
	// Previously n was len(decrypted) even when copy truncated, so callers
	// slicing carrier.Cache[:n] got uncopied bytes or a panic.
	if len(decrypted) > len(carrier.Cache) {
		return 0, fmt.Errorf("tscipher: decrypted package of %d bytes does not fit in cache of %d bytes", len(decrypted), len(carrier.Cache))
	}
	n = copy(carrier.Cache, decrypted)
	return n, nil
}