diff --git a/cache/queue.go b/cache/queue.go new file mode 100644 index 0000000..4055cce --- /dev/null +++ b/cache/queue.go @@ -0,0 +1,57 @@ +package cache + +type Cache struct { + cache []byte + l int //length of content, not length of slice. +} + +type QueueCache struct { //FIFO + queue chan *Cache +} + +type BlockingQueueCache QueueCache + +type UnblockingQueueCache QueueCache + +func NewUnblockingQueueCache(quN int) *UnblockingQueueCache { + t := &UnblockingQueueCache{ + queue: make(chan *Cache, quN), + } + return t +} + +func (this *UnblockingQueueCache) Put(bits []byte, l int) { + this.queue <- &Cache{ + cache: bits, + l: l, + } +} + +func (this *UnblockingQueueCache) Get() ([]byte, int) { + var t *Cache + select { + case t = <-this.queue: + default: + return make([]byte, 1024*4), 1024 * 4 + } + return t.cache, t.l +} + +func NewBlockingQueueCache(quN int) *BlockingQueueCache { + t := &BlockingQueueCache{ + queue: make(chan *Cache, quN), + } + return t +} + +func (this *BlockingQueueCache) Put(bits []byte, l int) { + this.queue <- &Cache{ + cache: bits, + l: l, + } +} + +func (this *BlockingQueueCache) Get() ([]byte, int) { + t := <-this.queue + return t.cache, t.l +} \ No newline at end of file diff --git a/main_test.go b/main_test.go index bc6133b..00d321e 100644 --- a/main_test.go +++ b/main_test.go @@ -2,7 +2,7 @@ package main import ( "bufio" - "fmt" + // "fmt" "github.com/TransX/log" "github.com/TransX/protocol" "io" @@ -10,6 +10,7 @@ import ( "os" "testing" "time" + "fmt" ) func serverBin(t *testing.T) { @@ -17,7 +18,7 @@ func serverBin(t *testing.T) { reader := bufio.NewReader(file) listener, err := net.Listen("tcp4", "127.0.0.1:1244") if err != nil { - t.Fatal(err) + log.Error("Test Server %s",err.Error()) } var nCount byte nCount = 0 @@ -42,6 +43,10 @@ func serverBin(t *testing.T) { } log.Info("Test Server read per time %d", n) ////<- + if add+n >= len(bytes) { + log.Panic("serverBin reseive out of bound n:%d,add+n :%d", n, add+n) + break + } copy(bytes[add:add+n], _bytes[:n]) add += n log.Info("add %d from %s", add, conn.RemoteAddr().String()) @@ -94,9 +99,9 @@ func serverBin(t *testing.T) { log.Info("Test Server. All Matches.") // log.Info("Test Server Receive %s", string(bytes[:n])) _, err = conn.Write([]byte("OK")) - log.Info("Test Server write") + log.Info("Test Server write OK from %s to %s",conn.LocalAddr().String(),conn.RemoteAddr().String(),) if err != nil { - t.Fatal(err) + log.Error("Test Server %s ",err.Error()) } conn.Close() log.Info("Test Server closed") @@ -150,12 +155,16 @@ func clientBin(t *testing.T) { for { conn, err := net.Dial("tcp4", "127.0.0.1:1200") if err != nil { - t.Fatal(err) + log.Error("Test Client %s",err.Error()) } binBytes := make([]byte, 1024*4) nBinBytes, err := reader.Read(binBytes) if err != nil { - log.Error("client read %", err) + if err==io.EOF{ + fmt.Println("Test Finished.") + break + } + log.Error("client read %s", err.Error()) } toBinWrite := make([]byte, len(binBytes)+1) @@ -165,21 +174,17 @@ func clientBin(t *testing.T) { if n != nBinBytes+1 { log.Error("client not write enough bytes") } - // log.Info("Test Client write %s", string(binBytes[:n])) log.Info("Test Client write %d bytes with count %d", n, nCount) nCount++ bytes := make([]byte, 1024*32) n, err = conn.Read(bytes) - log.Info("Test Client read") + log.Info("Test Client read % bytes (local add %s) msg:%s from %s",n,conn.LocalAddr().String(),bytes[:n],conn.RemoteAddr().String()) if err != nil { - t.Fatal(err) + log.Error("Test Client %s",err.Error()) } - log.Info("Test Client Receive %s", bytes[:n]) - fmt.Println("Test Client Receive ", string(bytes[:n])) - time.Sleep(time.Second * 2) + time.Sleep(time.Second * 0) conn.Close() log.Info("Test Client closed") - } } @@ -206,7 +211,7 @@ func clientText(t *testing.T) { if err != nil { t.Fatal(err) } - log.Info("Test Client Receive %s", bytes[:n]) + // log.Info("Test Client Receive %s", bytes[:n]) time.Sleep(time.Second * 2) conn.Close() log.Info("Test Client closed") diff --git a/protocol/tunnel.go b/protocol/tunnel.go index e63db6f..db6c49a 100644 --- a/protocol/tunnel.go +++ b/protocol/tunnel.go @@ -5,6 +5,7 @@ import ( "crypto/md5" "encoding/hex" // "fmt" + "github.com/TransX/cache" "github.com/TransX/log" "github.com/TransX/tscipher" "net" @@ -63,9 +64,10 @@ func (this *Tunnel) Run() { //单向的,从src发送到dest id := this.id // cache := make([]byte, 1024*4) //4kB //构建Carrier - queCache := tscipher.NewQueueCache(1) - revCarrier := tscipher.NewCarrier(src, tscipher.NewCipher("XOR"), queCache, id) - sendCarrier := tscipher.NewCarrier(dest, tscipher.NewCipher("XOR"), queCache, id) + queCache := cache.NewUnblockingQueueCache(1) + msg := cache.NewBlockingQueueCache(1) + revCarrier := tscipher.NewCarrier(src, tscipher.NewCipher("XOR"), queCache, msg, id) + sendCarrier := tscipher.NewCarrier(dest, tscipher.NewCipher("XOR"), queCache, msg, id) //timer // for { @@ -111,7 +113,7 @@ func (this *Tunnel) receive(revCarrier *tscipher.Carrier) { if err != nil { log.Panic("Read panic. Tunnel id: %s. Remote Add: %s Local: %s. Err:%s", id, src.RemoteAddr().String(), src.LocalAddr().String(), err.Error()) } - log.Debug("Reived %d bytes from %s. Tunnel: id %s", n, src.RemoteAddr().String(), id) + log.Debug("Reived %d bytes (local add %s) from %s. Tunnel: id %s", n, src.LocalAddr().String(),src.RemoteAddr().String(), id) // nCh <- n } } @@ -143,7 +145,7 @@ func (this *Tunnel) send(sendCarrier *tscipher.Carrier) { sTimer := time.Now() //send timer _, err := tscipher.SendData(sendCarrier) // fmt.Println("send") - log.Info("id %s time to sned %d", id, time.Since(sTimer)/1000) + log.Info("id %s time to send %d", id, time.Since(sTimer)/1000) if err != nil { log.Panic("Write panic. ID: %s, Err: %s, Remote Add: %s", id, err, dest.RemoteAddr().String()) } diff --git a/tscipher/cipher.go b/tscipher/cipher.go index 51e99d7..8cd4b44 100644 --- a/tscipher/cipher.go +++ b/tscipher/cipher.go @@ -3,6 +3,7 @@ package tscipher import ( "bytes" "fmt" + "github.com/TransX/cache" "github.com/TransX/log" "github.com/TransX/utils" "io" @@ -11,38 +12,6 @@ import ( "strings" ) -type Cache struct { - cache []byte - l int //length, not capability -} - -type QueueCache struct { //FIFO - queue chan *Cache -} - -func NewQueueCache(quN int) *QueueCache { - t := &QueueCache{ - queue: make(chan *Cache, quN), - } - return t -} - -func (this *QueueCache) Put(bits []byte, l int) { - // fmt.Printf("len of chan Put %d\n", len(this.queue)) - this.queue <- &Cache{ - cache: bits, - l: l, - } - // fmt.Printf("QueueCache put %d\n", len(this.queue)) -} - -func (this *QueueCache) Get() ([]byte, int) { - // fmt.Printf("len of chan Get %d\n", len(this.queue)) - t := <-this.queue - // fmt.Printf("QueueCache got %d\n", len(this.queue)) - return t.cache, t.l -} - var StartMark = []byte("#2v!") //should be constant var EndMark = []byte("_=1z") //should be constant @@ -71,23 +40,25 @@ type Carrier struct { Conn net.Conn Cipher Cipher // Cache []byte - Cache *QueueCache + Cache *cache.UnblockingQueueCache + Msg *cache.BlockingQueueCache AttachedTunnelID string receiveBuff []byte } -func NewCarrier(conn net.Conn, cipher Cipher, queCache *QueueCache, id string) *Carrier { +func NewCarrier(conn net.Conn, cipher Cipher, queCache *cache.UnblockingQueueCache, msg *cache.BlockingQueueCache, id string) *Carrier { t := new(Carrier) t.Conn = conn t.Cipher = cipher t.Cache = queCache + t.Msg = msg t.AttachedTunnelID = id t.receiveBuff = make([]byte, 0, 1024*4) return t } func (this *Carrier) GetReceiveBuff() []byte { - log.Debug("id %d receivebuff Get, len %d", this.AttachedTunnelID, len(this.receiveBuff)) + // log.Debug("id %d receivebuff Get, len %d", this.AttachedTunnelID, len(this.receiveBuff)) buff := this.receiveBuff _b := make([]byte, len(buff), cap(buff)) //必须这样写,没错。 copy(_b, buff) @@ -96,7 +67,7 @@ func (this *Carrier) GetReceiveBuff() []byte { func (this *Carrier) SetReceiveBuff(buff []byte) { this.receiveBuff = buff - log.Debug("id %d receivebuff set, len %d", this.AttachedTunnelID, len(this.receiveBuff)) + // log.Debug("id %d receivebuff set, len %d", this.AttachedTunnelID, len(this.receiveBuff)) } func NewCipher(cipherName string) (cipher Cipher) { @@ -173,16 +144,18 @@ func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) { } func SendData(carrier *Carrier) (n int, err error) { - cache, nByte := carrier.Cache.Get() + // cache, nByte := carrier.Cache.Get() + msg, nByte := carrier.Msg.Get() + // fmt.Printf("SendData id %s get msg %d\n", carrier.AttachedTunnelID, nByte) // fmt.Printf("id %s get cache\n", carrier.AttachedTunnelID) - if len(cache) < nByte { + if len(msg) < nByte { log.Panic("Cache of send is too small") } if carrier.Cipher == nil { - n, err = carrier.Conn.Write(cache[:nByte]) + n, err = carrier.Conn.Write(msg[:nByte]) return } - encrypedByte, err := carrier.Cipher.Encrypt(cache[:nByte]) + encrypedByte, err := carrier.Cipher.Encrypt(msg[:nByte]) if err != nil { n = 0 return @@ -190,19 +163,24 @@ func SendData(carrier *Carrier) (n int, err error) { //打包 wraped := WrapPackage(encrypedByte[:nByte]) n, err = carrier.Conn.Write(wraped) - log.Debug("Ready to write id %s, 18 byte %s", carrier.AttachedTunnelID, string(wraped[:18])) + // log.Debug("Ready to write id %s, 18 byte %s", carrier.AttachedTunnelID, string(wraped[:18])) // copy(carrier.Cache, encrypedByte[:nByte]) // in case of debugging + // carrier.Cache.Put(msg, len(msg)) + carrier.Cache.Put(make([]byte, 1024*4), 1024*4) + // fmt.Printf("SendData %s id put cache\n", carrier.AttachedTunnelID) return } func RowReceiveData(carrier *Carrier) (n int, err error) { - cache := make([]byte, 1024*4) + cache, _ := carrier.Cache.Get() + // fmt.Printf("RowReceiveData id %s get cache\n", carrier.AttachedTunnelID) n, err = carrier.Conn.Read(cache) if err != nil { n = 0 } - carrier.Cache.Put(cache, n) - // fmt.Printf("id %s put cache\n", carrier.AttachedTunnelID) + // carrier.Cache.Put(cache, n) + // fmt.Printf("RowReceiveData id %s pet msg\n", carrier.AttachedTunnelID) + carrier.Msg.Put(cache, n) return } @@ -212,11 +190,11 @@ func ReceiveData(carrier *Carrier) (n int, err error) { // log.Error("ReceiveData err %s", r) // } // }() - log.Debug("id %s wrapedPackage := carrier.GetReceiveBuff()", carrier.AttachedTunnelID) + // log.Debug("id %s wrapedPackage := carrier.GetReceiveBuff()", carrier.AttachedTunnelID) wrapedPackage := carrier.GetReceiveBuff() //make([]byte, 0, cap(carrier.Cache)) var packageData []byte var _rest []byte - cache := make([]byte, 1024*4) + cache, _ := carrier.Cache.Get() for { //首先检查这个是不是完整的包,是就返回好了,免得被阻塞 data, rest, err := UnwrapPackage(wrapedPackage) @@ -226,28 +204,27 @@ func ReceiveData(carrier *Carrier) (n int, err error) { log.Debug("return NotPackageError %s", carrier.AttachedTunnelID) return 0, err } - log.Debug("id %s length of package %d", carrier.AttachedTunnelID, len(packageData)) + // log.Debug("id %s length of package %d", carrier.AttachedTunnelID, len(packageData)) if err == nil { //够一个完整的包 - log.Debug("id %s capBuff := cap(carrier.GetReceiveBuff())", carrier.AttachedTunnelID) capBuff := cap(carrier.GetReceiveBuff()) _buff := make([]byte, 0, capBuff) //释放 _buff = append(_buff, _rest...) - log.Debug("id %s carrier.SetReceiveBuff(_buff)", carrier.AttachedTunnelID) + // log.Debug("id %s carrier.SetReceiveBuff(_buff)", carrier.AttachedTunnelID) carrier.SetReceiveBuff(_buff) break } //如果读到的数据不够一个完整的包 - log.Debug("id %s to read wrapedPackage %d", carrier.AttachedTunnelID, len(wrapedPackage)) + // log.Debug("id %s to read wrapedPackage %d", carrier.AttachedTunnelID, len(wrapedPackage)) if len(wrapedPackage) > 0 { n, err = carrier.Conn.Read(cache) if err != nil { log.Error("ERROR %s", err) } - log.Debug("id %s to Conn.Read %d", carrier.AttachedTunnelID, n) + // log.Debug("id %s to Conn.Read %d", carrier.AttachedTunnelID, n) } else { n, err = io.ReadAtLeast(carrier.Conn, cache, 18) - log.Debug("id %s to ReadAtLeast", carrier.AttachedTunnelID) + // log.Debug("id %s to ReadAtLeast", carrier.AttachedTunnelID) } if err != nil { @@ -256,15 +233,15 @@ func ReceiveData(carrier *Carrier) (n int, err error) { } wrapedPackage = append(wrapedPackage, cache[:n]...) - log.Debug("id %s length of conn %d", carrier.AttachedTunnelID, n) - log.Debug("id %s first 18 %s from %s", carrier.AttachedTunnelID, string(wrapedPackage[:18]), carrier.Conn.RemoteAddr().String()) + // log.Debug("id %s length of conn %d", carrier.AttachedTunnelID, n) + // log.Debug("id %s first 18 %s from %s", carrier.AttachedTunnelID, string(wrapedPackage[:18]), carrier.Conn.RemoteAddr().String()) } - if len(carrier.GetReceiveBuff()) > 0 { - log.Debug("id %s trailing %d from %s", carrier.AttachedTunnelID, len(carrier.GetReceiveBuff()), carrier.Conn.RemoteAddr().String()) - log.Debug("id %s 18 byte of trailing %s", carrier.AttachedTunnelID, string(carrier.GetReceiveBuff()[:18])) - } + // if len(carrier.GetReceiveBuff()) > 0 { + // log.Debug("id %s trailing %d from %s", carrier.AttachedTunnelID, len(carrier.GetReceiveBuff()), carrier.Conn.RemoteAddr().String()) + // log.Debug("id %s 18 byte of trailing %s", carrier.AttachedTunnelID, string(carrier.GetReceiveBuff()[:18])) + // } decrypted, err := carrier.Cipher.Decrypt(packageData) if err != nil { n = 0 @@ -272,7 +249,9 @@ func ReceiveData(carrier *Carrier) (n int, err error) { } n = len(decrypted) // copy(cache, decrypted) - carrier.Cache.Put(decrypted, n) + // carrier.Cache.Put(decrypted, n) + // fmt.Printf("ReceiveData id %s get msg\n", carrier.AttachedTunnelID) + carrier.Msg.Put(decrypted, n) // fmt.Printf("id %s put cache\n", carrier.AttachedTunnelID) return }