diff --git a/benckmark.go b/benckmark.go index f656f6e..ceb34b0 100644 --- a/benckmark.go +++ b/benckmark.go @@ -68,7 +68,7 @@ func attack() { log.Println("Finish") } -func main() { +func main1() { cfg := profile.Config{ MemProfile: true, ProfilePath: "./profile", // store profiles in current directory diff --git a/communicator/receiver.go b/communicator/receiver.go new file mode 100644 index 0000000..df0156a --- /dev/null +++ b/communicator/receiver.go @@ -0,0 +1,104 @@ +package communicator + +import ( + "io" + + "github.com/TransX/log" + "github.com/TransX/tscipher" +) + +type Receiver interface { + //SendData(carrier *Carrier) (n int, err error) + Receive() (n int, err error) + Close() +} + +func NewNormalReceiver(carrier *tscipher.Carrier) *NormalReceiver { + r := new(NormalReceiver) + r.carrier = carrier + return r +} + +type NormalReceiver struct { + Receiver + carrier *tscipher.Carrier +} + +func (this *NormalReceiver) Receive() (n int, err error) { + carrier := this.carrier + wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache)) + var packageData []byte + var _rest []byte + cache, _ := carrier.Cache.Get() + for { + //首先检查这个是不是完整的包,是就返回好了,免得被阻塞 + data, rest, err := tscipher.UnwrapPackage(wrapedPackage) + packageData = data + _rest = rest + if err, ok := err.(*tscipher.NotPackageError); len(wrapedPackage) >= 18 && ok { + log.Debug("return NotPackageError %s", carrier.AttachedTunnelID) + return 0, err + } + if err == nil { + //够一个完整的包 + capBuff := cap(carrier.GetReceiveBuff()) + _buff := make([]byte, 0, capBuff) //释放 + _buff = append(_buff, _rest...) + carrier.SetReceiveBuff(_buff) + break + } + //如果读到的数据不够一个完整的包 + if len(wrapedPackage) > 0 { + n, err = carrier.Conn.Read(cache) + if err != nil { + log.Error("ERROR %s", err.Error()) + } + } else { + n, err = io.ReadAtLeast(carrier.Conn, cache, 18) + } + + if err != nil { + n = 0 + return n, err + } + wrapedPackage = append(wrapedPackage, cache[:n]...) + } + decrypted, err := carrier.Cipher.Decrypt(packageData) + if err != nil { + n = 0 + return + } + n = len(decrypted) + carrier.Msg.Put(decrypted, n) + return +} + +func (this *NormalReceiver) Close() { + this.carrier.Conn.Close() +} + +func NewDirectReceiver(carrier *tscipher.Carrier) *DirectReceiver { + r := new(DirectReceiver) + r.carrier = carrier + return r +} + +type DirectReceiver struct { + Receiver + carrier *tscipher.Carrier +} + +func (this *DirectReceiver) Receive() (n int, err error) { + carrier := this.carrier + cache, _ := carrier.Cache.Get() + n, err = carrier.Conn.Read(cache) + if err != nil { + n = 0 + } + carrier.Msg.Put(cache, n) + return +} + +func (this *DirectReceiver) Close() { + this.carrier.Conn.Close() +} diff --git a/communicator/sender.go b/communicator/sender.go new file mode 100644 index 0000000..9e167b9 --- /dev/null +++ b/communicator/sender.go @@ -0,0 +1,79 @@ +package communicator + +import ( + "log" + + "github.com/TransX/tscipher" +) + +type Sender interface { + //SendData(carrier *Carrier) (n int, err error) + Send() (n int, err error) + Close() +} + +func NewNormalSender(carrier *tscipher.Carrier) *NormalSender { + r := new(NormalSender) + r.carrier = carrier + return r +} + +type NormalSender struct { + Sender + carrier *tscipher.Carrier +} + +func (this *NormalSender) Send() (n int, err error) { + carrier := this.carrier + msg, nByte := carrier.Msg.Get() + if len(msg) < nByte { + log.Panic("Cache of send is too small") + } + // if carrier.Cipher == nil { + // n, err = carrier.Conn.Write(msg[:nByte]) + // carrier.Cache.Put(make([]byte, 1024*4), 1024*4) + // return + // } + encrypedByte, err := carrier.Cipher.Encrypt(msg[:nByte]) + if err != nil { + n = 0 + return + } + //打包 + wraped := tscipher.WrapPackage(encrypedByte[:nByte]) + n, err = carrier.Conn.Write(wraped) + carrier.Cache.Put(make([]byte, 1024*4), 1024*4) + return +} + +func (this *NormalSender) Close() { + this.carrier.Conn.Close() +} + +func NewDirectSender(carrier *tscipher.Carrier) *DirectSender { + r := new(DirectSender) + r.carrier = carrier + return r +} + +type DirectSender struct { + Sender + carrier *tscipher.Carrier +} + +func (this *DirectSender) Send() (n int, err error) { + carrier := this.carrier + msg, nByte := carrier.Msg.Get() + if len(msg) < nByte { + log.Panic("Cache of send is too small") + } + + n, err = carrier.Conn.Write(msg[:nByte]) + carrier.Cache.Put(make([]byte, 1024*4), 1024*4) + return + +} + +func (this *DirectSender) Close() { + this.carrier.Conn.Close() +} diff --git a/main.go b/main.go index 2341109..113c454 100644 --- a/main.go +++ b/main.go @@ -25,7 +25,7 @@ func tunnel() { } -func main1() { +func main() { // defer profile.Start(profile.CPUProfile).Stop() flag.Parse() fmt.Println("Hello World!") diff --git a/model/tunnel.go b/model/tunnel.go index d59f8ac..26fbf66 100644 --- a/model/tunnel.go +++ b/model/tunnel.go @@ -3,13 +3,15 @@ package model import ( // "fmt" + "net" + "time" + "github.com/TransX/cache" + "github.com/TransX/communicator" "github.com/TransX/constant" "github.com/TransX/log" "github.com/TransX/tscipher" "github.com/spf13/viper" - "net" - "time" ) type Tunnel struct { @@ -95,14 +97,18 @@ func (this *Tunnel) receive(revCarrier *tscipher.Carrier) { defer this.onError() var n int var err error + var receiver communicator.Receiver + if cipherDirection != constant.RECEIVE { + // revCarrier.Cipher = nil + // n, err = tscipher.RowReceiveData(revCarrier) + receiver = communicator.NewDirectReceiver(revCarrier) + } else { + receiver = communicator.NewNormalReceiver(revCarrier) + } + for { rTimer := time.Now() //receive timer - if cipherDirection != constant.RECEIVE { - revCarrier.Cipher = nil - n, err = tscipher.RowReceiveData(revCarrier) - } else { - n, err = tscipher.ReceiveData(revCarrier) - } + n, err = receiver.Receive() log.Info("id %s time to receive %d", id, time.Since(rTimer)/1000) if err != nil { log.Panic("Read panic. Tunnel id: %s. Remote Add: %s Local: %s. Err:%s", id, src.RemoteAddr().String(), src.LocalAddr().String(), err.Error()) @@ -115,13 +121,17 @@ func (this *Tunnel) send(sendCarrier *tscipher.Carrier) { dest := this.dest cipherDirection := this.cipherDirection id := this.id + var sender communicator.Sender defer this.onError() if cipherDirection != constant.SEND { - sendCarrier.Cipher = nil + // sendCarrier.Cipher = nil + sender = communicator.NewDirectSender(sendCarrier) + } else { + sender = communicator.NewNormalSender(sendCarrier) } for { sTimer := time.Now() //send timer - n, err := tscipher.SendData(sendCarrier) + n, err := sender.Send() log.Info("id %s time to send %d", id, time.Since(sTimer)/1000) if err != nil { log.Panic("Write panic. ID: %s, Err: %s, Remote Add: %s", id, err, dest.RemoteAddr().String()) diff --git a/protocol/tcp.go b/protocol/tcp.go index a24ff61..0653146 100644 --- a/protocol/tcp.go +++ b/protocol/tcp.go @@ -80,15 +80,9 @@ func (this *TransTCP) Start(listenPort, destIP, destPort string, clientOrServer if clientOrServer == "client" { //加密方向 sendID := utils.TunnelID() ntSend := model.NewTunnel(sendID, listenerConn, destConn, constant.SEND) - // ntSend.SetRegChan(tunMng.GetRegChan()) - // ntSend.SetUnRegChan(tunMng.GetUnregChan()) receiveID := utils.TunnelID() ntRev := model.NewTunnel(receiveID, destConn, listenerConn, constant.RECEIVE) - // ntRev.SetRegChan(tunMng.GetRegChan()) - // ntRev.SetUnRegChan(tunMng.GetUnregChan()) printModelDetail(sendID, receiveID) - // go ntSend.Run() - // go ntRev.Run() tunnelPair := model.NewTunnelPair(ntSend, ntRev) tunnelPair.SetRegChan(tunMng.GetRegChan()) tunnelPair.SetUnRegChan(tunMng.GetUnregChan()) @@ -97,15 +91,9 @@ func (this *TransTCP) Start(listenPort, destIP, destPort string, clientOrServer if clientOrServer == "server" { receiveID := utils.TunnelID() ntRev := model.NewTunnel(receiveID, listenerConn, destConn, constant.RECEIVE) - // ntRev.SetRegChan(tunMng.GetRegChan()) - // ntRev.SetUnRegChan(tunMng.GetUnregChan()) sendID := utils.TunnelID() ntSend := model.NewTunnel(sendID, destConn, listenerConn, constant.SEND) - // ntSend.SetRegChan(tunMng.GetRegChan()) - // ntSend.SetUnRegChan(tunMng.GetUnregChan()) printModelDetail(sendID, receiveID) - // go ntRev.Run() - // go ntSend.Run() tunnelPair := model.NewTunnelPair(ntSend, ntRev) tunnelPair.SetRegChan(tunMng.GetRegChan()) tunnelPair.SetUnRegChan(tunMng.GetUnregChan()) diff --git a/tscipher/cipher.go b/tscipher/cipher.go index 4fc97ba..8991afb 100644 --- a/tscipher/cipher.go +++ b/tscipher/cipher.go @@ -3,14 +3,13 @@ package tscipher import ( "bytes" "fmt" - "github.com/TransX/cache" - "github.com/TransX/log" - "github.com/TransX/utils" - "github.com/spf13/viper" - "io" "net" "strconv" "strings" + + "github.com/TransX/cache" + "github.com/TransX/utils" + "github.com/spf13/viper" ) // var StartMark = []byte("#2v!") //should be constant @@ -141,82 +140,82 @@ func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) { } -func SendData(carrier *Carrier) (n int, err error) { - msg, nByte := carrier.Msg.Get() - if len(msg) < nByte { - log.Panic("Cache of send is too small") - } - if carrier.Cipher == nil { - n, err = carrier.Conn.Write(msg[:nByte]) - carrier.Cache.Put(make([]byte, 1024*4), 1024*4) - return - } - encrypedByte, err := carrier.Cipher.Encrypt(msg[:nByte]) - if err != nil { - n = 0 - return - } - //打包 - wraped := WrapPackage(encrypedByte[:nByte]) - n, err = carrier.Conn.Write(wraped) - carrier.Cache.Put(make([]byte, 1024*4), 1024*4) - return -} - -func RowReceiveData(carrier *Carrier) (n int, err error) { - cache, _ := carrier.Cache.Get() - n, err = carrier.Conn.Read(cache) - if err != nil { - n = 0 - } - carrier.Msg.Put(cache, n) - return -} - -func ReceiveData(carrier *Carrier) (n int, err error) { - wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache)) - var packageData []byte - var _rest []byte - cache, _ := carrier.Cache.Get() - for { - //首先检查这个是不是完整的包,是就返回好了,免得被阻塞 - data, rest, err := UnwrapPackage(wrapedPackage) - packageData = data - _rest = rest - if err, ok := err.(*NotPackageError); len(wrapedPackage) >= 18 && ok { - log.Debug("return NotPackageError %s", carrier.AttachedTunnelID) - return 0, err - } - if err == nil { - //够一个完整的包 - capBuff := cap(carrier.GetReceiveBuff()) - _buff := make([]byte, 0, capBuff) //释放 - _buff = append(_buff, _rest...) - carrier.SetReceiveBuff(_buff) - break - } - //如果读到的数据不够一个完整的包 - if len(wrapedPackage) > 0 { - n, err = carrier.Conn.Read(cache) - if err != nil { - log.Error("ERROR %s", err) - } - } else { - n, err = io.ReadAtLeast(carrier.Conn, cache, 18) - } - - if err != nil { - n = 0 - return n, err - } - wrapedPackage = append(wrapedPackage, cache[:n]...) - } - decrypted, err := carrier.Cipher.Decrypt(packageData) - if err != nil { - n = 0 - return - } - n = len(decrypted) - carrier.Msg.Put(decrypted, n) - return -} +// func SendData(carrier *Carrier) (n int, err error) { +// msg, nByte := carrier.Msg.Get() +// if len(msg) < nByte { +// log.Panic("Cache of send is too small") +// } +// if carrier.Cipher == nil { +// n, err = carrier.Conn.Write(msg[:nByte]) +// carrier.Cache.Put(make([]byte, 1024*4), 1024*4) +// return +// } +// encrypedByte, err := carrier.Cipher.Encrypt(msg[:nByte]) +// if err != nil { +// n = 0 +// return +// } +// //打包 +// wraped := WrapPackage(encrypedByte[:nByte]) +// n, err = carrier.Conn.Write(wraped) +// carrier.Cache.Put(make([]byte, 1024*4), 1024*4) +// return +// } +// +// func RowReceiveData(carrier *Carrier) (n int, err error) { +// cache, _ := carrier.Cache.Get() +// n, err = carrier.Conn.Read(cache) +// if err != nil { +// n = 0 +// } +// carrier.Msg.Put(cache, n) +// return +// } +// +// func ReceiveData(carrier *Carrier) (n int, err error) { +// wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache)) +// var packageData []byte +// var _rest []byte +// cache, _ := carrier.Cache.Get() +// for { +// //首先检查这个是不是完整的包,是就返回好了,免得被阻塞 +// data, rest, err := UnwrapPackage(wrapedPackage) +// packageData = data +// _rest = rest +// if err, ok := err.(*NotPackageError); len(wrapedPackage) >= 18 && ok { +// log.Debug("return NotPackageError %s", carrier.AttachedTunnelID) +// return 0, err +// } +// if err == nil { +// //够一个完整的包 +// capBuff := cap(carrier.GetReceiveBuff()) +// _buff := make([]byte, 0, capBuff) //释放 +// _buff = append(_buff, _rest...) +// carrier.SetReceiveBuff(_buff) +// break +// } +// //如果读到的数据不够一个完整的包 +// if len(wrapedPackage) > 0 { +// n, err = carrier.Conn.Read(cache) +// if err != nil { +// log.Error("ERROR %s", err) +// } +// } else { +// n, err = io.ReadAtLeast(carrier.Conn, cache, 18) +// } +// +// if err != nil { +// n = 0 +// return n, err +// } +// wrapedPackage = append(wrapedPackage, cache[:n]...) +// } +// decrypted, err := carrier.Cipher.Decrypt(packageData) +// if err != nil { +// n = 0 +// return +// } +// n = len(decrypted) +// carrier.Msg.Put(decrypted, n) +// return +// }