diff --git a/protocol/tunnel.go b/protocol/tunnel.go index 9996757..e63db6f 100644 --- a/protocol/tunnel.go +++ b/protocol/tunnel.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/md5" "encoding/hex" + // "fmt" "github.com/TransX/log" "github.com/TransX/tscipher" "net" @@ -56,12 +57,72 @@ func (this *Tunnel) SetUnRegChan(c chan interface{}) { func (this *Tunnel) Run() { //单向的,从src发送到dest //进行注册 this.regChan <- this + src := this.src + dest := this.dest + // cipherDirection := this.cipherDirection + id := this.id + // cache := make([]byte, 1024*4) //4kB + //构建Carrier + queCache := tscipher.NewQueueCache(1) + revCarrier := tscipher.NewCarrier(src, tscipher.NewCipher("XOR"), queCache, id) + sendCarrier := tscipher.NewCarrier(dest, tscipher.NewCipher("XOR"), queCache, id) + //timer + + // for { + // nCh := make(chan int) + go this.receive(revCarrier) + go this.send(sendCarrier) + // log.Info("id %s send %d /receive %d duration %d ms", n, nByte, id, time.Since(srTimer)/1000) + // } + +} + +func (this *Tunnel) receive(revCarrier *tscipher.Carrier) { src := this.src dest := this.dest cipherDirection := this.cipherDirection id := this.id defer func() { - log.Debug("tunnel id %s ends", id) + // log.Debug("tunnel id %s ends", id) + //注销 + // this.unregChan <- this + if r := recover(); r != nil { + if src != nil { + src.Close() + } + if dest != nil { + dest.Close() + } + } + }() + // srTimer := time.Now() //send receive timer + var n int + var err error + for { + rTimer := time.Now() //receive timer + if cipherDirection != RECEIVE { + revCarrier.Cipher = nil + n, err = tscipher.RowReceiveData(revCarrier) + } else { + n, err = tscipher.ReceiveData(revCarrier) + } + // fmt.Println("receive") + log.Info("id %s time to receive %d", id, time.Since(rTimer)/1000) + if err != nil { + log.Panic("Read panic. Tunnel id: %s. Remote Add: %s Local: %s. Err:%s", id, src.RemoteAddr().String(), src.LocalAddr().String(), err.Error()) + } + log.Debug("Reived %d bytes from %s. Tunnel: id %s", n, src.RemoteAddr().String(), id) + // nCh <- n + } +} + +func (this *Tunnel) send(sendCarrier *tscipher.Carrier) { + src := this.src + dest := this.dest + cipherDirection := this.cipherDirection + id := this.id + defer func() { + // log.Debug("tunnel id %s ends", id) //注销 this.unregChan <- this if r := recover(); r != nil { @@ -73,41 +134,21 @@ func (this *Tunnel) Run() { //单向的,从src发送到dest } } }() - cache := make([]byte, 1024*4) //4kB - //构建Carrier - revCarrier := tscipher.NewCarrier(src, tscipher.NewCipher("XOR"), cache, id) - sendCarrier := tscipher.NewCarrier(dest, tscipher.NewCipher("XOR"), cache, id) - //timer - + if cipherDirection != SEND { + sendCarrier.Cipher = nil + } for { - srTimer := time.Now() //send receive timer - var nByte int - var err error - rTimer := time.Now() //receive timer - if cipherDirection != RECEIVE { - revCarrier.Cipher = nil - nByte, err = tscipher.RowReceiveData(revCarrier) - } else { - nByte, err = tscipher.ReceiveData(revCarrier) - } - log.Info("id %s time to receive %d", id, time.Since(rTimer)/1000) - if err != nil { - log.Panic("Read panic. Tunnel id: %s. Remote Add: %s Local: %s. Err:%s", id, src.RemoteAddr().String(), src.LocalAddr().String(), err.Error()) - } - log.Debug("Reived %d bytes from %s. Tunnel: id %s", nByte, src.RemoteAddr().String(), id) - if cipherDirection != SEND { - sendCarrier.Cipher = nil - } + // nByte := <-nCh + // fmt.Println("in send loop\n") sTimer := time.Now() //send timer - n, err := tscipher.SendData(sendCarrier, nByte) + _, err := tscipher.SendData(sendCarrier) + // fmt.Println("send") log.Info("id %s time to sned %d", id, time.Since(sTimer)/1000) if err != nil { log.Panic("Write panic. ID: %s, Err: %s, Remote Add: %s", id, err, dest.RemoteAddr().String()) } - log.Debug("Write %d bytes from %s to %s. Tunnel: %s . 18 bytes %x", n, dest.LocalAddr(), dest.RemoteAddr().String(), id, sendCarrier.Cache[:18]) - log.Info("id %s send %d /receive %d duration %d ms", n, nByte, id, time.Since(srTimer)/1000) + // log.Debug("Write %d bytes from %s to %s. Tunnel: %s . 18 bytes %x", n, dest.LocalAddr(), dest.RemoteAddr().String(), id, sendCarrier.Cache[:18]) } - } func tunnelID() string { diff --git a/tscipher/cipher.go b/tscipher/cipher.go index e61d29f..51e99d7 100644 --- a/tscipher/cipher.go +++ b/tscipher/cipher.go @@ -11,6 +11,38 @@ import ( "strings" ) +type Cache struct { + cache []byte + l int //length, not capability +} + +type QueueCache struct { //FIFO + queue chan *Cache +} + +func NewQueueCache(quN int) *QueueCache { + t := &QueueCache{ + queue: make(chan *Cache, quN), + } + return t +} + +func (this *QueueCache) Put(bits []byte, l int) { + // fmt.Printf("len of chan Put %d\n", len(this.queue)) + this.queue <- &Cache{ + cache: bits, + l: l, + } + // fmt.Printf("QueueCache put %d\n", len(this.queue)) +} + +func (this *QueueCache) Get() ([]byte, int) { + // fmt.Printf("len of chan Get %d\n", len(this.queue)) + t := <-this.queue + // fmt.Printf("QueueCache got %d\n", len(this.queue)) + return t.cache, t.l +} + var StartMark = []byte("#2v!") //should be constant var EndMark = []byte("_=1z") //should be constant @@ -36,20 +68,21 @@ type Cipher interface { } type Carrier struct { - Conn net.Conn - Cipher Cipher - Cache []byte + Conn net.Conn + Cipher Cipher + // Cache []byte + Cache *QueueCache AttachedTunnelID string receiveBuff []byte } -func NewCarrier(conn net.Conn, cipher Cipher, cache []byte, id string) *Carrier { +func NewCarrier(conn net.Conn, cipher Cipher, queCache *QueueCache, id string) *Carrier { t := new(Carrier) t.Conn = conn t.Cipher = cipher - t.Cache = cache + t.Cache = queCache t.AttachedTunnelID = id - t.receiveBuff = make([]byte, 0, len(cache)) + t.receiveBuff = make([]byte, 0, 1024*4) return t } @@ -139,15 +172,17 @@ func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) { } -func SendData(carrier *Carrier, nByte int) (n int, err error) { - if len(carrier.Cache) < nByte { +func SendData(carrier *Carrier) (n int, err error) { + cache, nByte := carrier.Cache.Get() + // fmt.Printf("id %s get cache\n", carrier.AttachedTunnelID) + if len(cache) < nByte { log.Panic("Cache of send is too small") } if carrier.Cipher == nil { - n, err = carrier.Conn.Write(carrier.Cache[:nByte]) + n, err = carrier.Conn.Write(cache[:nByte]) return } - encrypedByte, err := carrier.Cipher.Encrypt(carrier.Cache[:nByte]) + encrypedByte, err := carrier.Cipher.Encrypt(cache[:nByte]) if err != nil { n = 0 return @@ -156,16 +191,18 @@ func SendData(carrier *Carrier, nByte int) (n int, err error) { wraped := WrapPackage(encrypedByte[:nByte]) n, err = carrier.Conn.Write(wraped) log.Debug("Ready to write id %s, 18 byte %s", carrier.AttachedTunnelID, string(wraped[:18])) - copy(carrier.Cache, encrypedByte[:nByte]) // in case of debugging + // copy(carrier.Cache, encrypedByte[:nByte]) // in case of debugging return } func RowReceiveData(carrier *Carrier) (n int, err error) { - n, err = carrier.Conn.Read(carrier.Cache) + cache := make([]byte, 1024*4) + n, err = carrier.Conn.Read(cache) if err != nil { n = 0 - } + carrier.Cache.Put(cache, n) + // fmt.Printf("id %s put cache\n", carrier.AttachedTunnelID) return } @@ -179,6 +216,7 @@ func ReceiveData(carrier *Carrier) (n int, err error) { wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache)) var packageData []byte var _rest []byte + cache := make([]byte, 1024*4) for { //首先检查这个是不是完整的包,是就返回好了,免得被阻塞 data, rest, err := UnwrapPackage(wrapedPackage) @@ -202,13 +240,13 @@ func ReceiveData(carrier *Carrier) (n int, err error) { //如果读到的数据不够一个完整的包 log.Debug("id %s to read wrapedPackage %d", carrier.AttachedTunnelID, len(wrapedPackage)) if len(wrapedPackage) > 0 { - n, err = carrier.Conn.Read(carrier.Cache) + n, err = carrier.Conn.Read(cache) if err != nil { log.Error("ERROR %s", err) } log.Debug("id %s to Conn.Read %d", carrier.AttachedTunnelID, n) } else { - n, err = io.ReadAtLeast(carrier.Conn, carrier.Cache, 18) + n, err = io.ReadAtLeast(carrier.Conn, cache, 18) log.Debug("id %s to ReadAtLeast", carrier.AttachedTunnelID) } @@ -216,7 +254,7 @@ func ReceiveData(carrier *Carrier) (n int, err error) { n = 0 return n, err } - wrapedPackage = append(wrapedPackage, carrier.Cache[:n]...) + wrapedPackage = append(wrapedPackage, cache[:n]...) log.Debug("id %s length of conn %d", carrier.AttachedTunnelID, n) log.Debug("id %s first 18 %s from %s", carrier.AttachedTunnelID, string(wrapedPackage[:18]), carrier.Conn.RemoteAddr().String()) @@ -233,6 +271,8 @@ func ReceiveData(carrier *Carrier) (n int, err error) { return } n = len(decrypted) - copy(carrier.Cache, decrypted) + // copy(cache, decrypted) + carrier.Cache.Put(decrypted, n) + // fmt.Printf("id %s put cache\n", carrier.AttachedTunnelID) return }