diff --git a/cache/queue.go b/cache/queue.go index 4055cce..97eea15 100644 --- a/cache/queue.go +++ b/cache/queue.go @@ -29,11 +29,11 @@ func (this *UnblockingQueueCache) Put(bits []byte, l int) { func (this *UnblockingQueueCache) Get() ([]byte, int) { var t *Cache - select { - case t = <-this.queue: - default: - return make([]byte, 1024*4), 1024 * 4 - } + // select { + t = <-this.queue + // default: + // return make([]byte, 1024*4), 1024 * 4 + // } return t.cache, t.l } diff --git a/cache/queue_test.go b/cache/queue_test.go new file mode 100644 index 0000000..454edc3 --- /dev/null +++ b/cache/queue_test.go @@ -0,0 +1,25 @@ +package cache + +import( + "testing" + "fmt" +) + + +func GetAndPut(name string,get *BlockingQueueCache,put *BlockingQueueCache){ + for{ + r,_:=get.Get() + fmt.Printf("%s Get\n",name) + put.Put(r,1) + fmt.Printf("%s Put\n",name) +} + +} + +func TestAsynchronousCache(t *testing.T){ + a:=NewBlockingQueueCache(1) + b:=NewBlockingQueueCache(1) + a.Put(make([]byte,1),1) + go GetAndPut("A",a,b) + GetAndPut("B",b,a) +} \ No newline at end of file diff --git a/protocol/tcp.go b/protocol/tcp.go index 3af5ea8..8fc3100 100644 --- a/protocol/tcp.go +++ b/protocol/tcp.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/TransX/log" "github.com/TransX/stat" + "github.com/TransX/utils" "net" "os" ) @@ -75,22 +76,28 @@ func (this *TransTCP) Start(listenPort, destIP, destPort string, clientOrServer log.Info("Dial %s", destConn.RemoteAddr().String()) //tunnel model : [ -->>server ---- client -->> ](this is a tunnel) if clientOrServer == "client" { - ntSend := NewTunnel(listenerConn, destConn, SEND) + sendID := utils.TunnelID() + ntSend := NewTunnel(sendID, listenerConn, destConn, SEND) ntSend.SetRegChan(tunMng.GetRegChan()) ntSend.SetUnRegChan(tunMng.GetUnregChan()) - ntRev := NewTunnel(destConn, listenerConn, RECEIVE) + receiveID := utils.TunnelID() + ntRev := NewTunnel(receiveID, destConn, listenerConn, RECEIVE) ntRev.SetRegChan(tunMng.GetRegChan()) ntRev.SetUnRegChan(tunMng.GetUnregChan()) + printModelDetail(sendID, receiveID) go ntSend.Run() go ntRev.Run() } if clientOrServer == "server" { - ntRev := NewTunnel(listenerConn, destConn, RECEIVE) + receiveID := utils.TunnelID() + ntRev := NewTunnel(receiveID, listenerConn, destConn, RECEIVE) ntRev.SetRegChan(tunMng.GetRegChan()) ntRev.SetUnRegChan(tunMng.GetUnregChan()) - ntSend := NewTunnel(destConn, listenerConn, SEND) + sendID := utils.TunnelID() + ntSend := NewTunnel(sendID, destConn, listenerConn, SEND) ntSend.SetRegChan(tunMng.GetRegChan()) ntSend.SetUnRegChan(tunMng.GetUnregChan()) + printModelDetail(sendID, receiveID) go ntRev.Run() go ntSend.Run() } @@ -101,3 +108,7 @@ func (this *TransTCP) Start(listenPort, destIP, destPort string, clientOrServer } } } + +func printModelDetail(tunnelIDA, tunnelIDB string) { + log.Info("id %s and id %s belong to a model", tunnelIDA, tunnelIDB) +} diff --git a/protocol/tunnel.go b/protocol/tunnel.go index db6c49a..6ca234e 100644 --- a/protocol/tunnel.go +++ b/protocol/tunnel.go @@ -1,25 +1,15 @@ package protocol import ( - "bytes" - "crypto/md5" - "encoding/hex" + // "fmt" "github.com/TransX/cache" "github.com/TransX/log" "github.com/TransX/tscipher" "net" - "strconv" - "sync/atomic" "time" ) -var seed int32 - -func init() { - seed = 0 -} - type Tunnel struct { id string src net.Conn @@ -29,9 +19,9 @@ type Tunnel struct { unregChan chan interface{} } -func NewTunnel(src, dest net.Conn, cipherDirection Direction) *Tunnel { +func NewTunnel(id string, src, dest net.Conn, cipherDirection Direction) *Tunnel { return &Tunnel{ - id: tunnelID(), + id: id, src: src, dest: dest, cipherDirection: cipherDirection, @@ -65,17 +55,14 @@ func (this *Tunnel) Run() { //单向的,从src发送到dest // cache := make([]byte, 1024*4) //4kB //构建Carrier queCache := cache.NewUnblockingQueueCache(1) + for i := 0; i < 1; i++ { + queCache.Put(make([]byte, 1024*4), 0) + } msg := cache.NewBlockingQueueCache(1) revCarrier := tscipher.NewCarrier(src, tscipher.NewCipher("XOR"), queCache, msg, id) sendCarrier := tscipher.NewCarrier(dest, tscipher.NewCipher("XOR"), queCache, msg, id) - //timer - - // for { - // nCh := make(chan int) go this.receive(revCarrier) go this.send(sendCarrier) - // log.Info("id %s send %d /receive %d duration %d ms", n, nByte, id, time.Since(srTimer)/1000) - // } } @@ -97,7 +84,6 @@ func (this *Tunnel) receive(revCarrier *tscipher.Carrier) { } } }() - // srTimer := time.Now() //send receive timer var n int var err error for { @@ -108,13 +94,11 @@ func (this *Tunnel) receive(revCarrier *tscipher.Carrier) { } else { n, err = tscipher.ReceiveData(revCarrier) } - // fmt.Println("receive") log.Info("id %s time to receive %d", id, time.Since(rTimer)/1000) if err != nil { log.Panic("Read panic. Tunnel id: %s. Remote Add: %s Local: %s. Err:%s", id, src.RemoteAddr().String(), src.LocalAddr().String(), err.Error()) } - log.Debug("Reived %d bytes (local add %s) from %s. Tunnel: id %s", n, src.LocalAddr().String(),src.RemoteAddr().String(), id) - // nCh <- n + log.Debug("Reived %d bytes (local add %s) from %s. Tunnel: id %s", n, src.LocalAddr().String(), src.RemoteAddr().String(), id) } } @@ -124,7 +108,6 @@ func (this *Tunnel) send(sendCarrier *tscipher.Carrier) { cipherDirection := this.cipherDirection id := this.id defer func() { - // log.Debug("tunnel id %s ends", id) //注销 this.unregChan <- this if r := recover(); r != nil { @@ -140,22 +123,12 @@ func (this *Tunnel) send(sendCarrier *tscipher.Carrier) { sendCarrier.Cipher = nil } for { - // nByte := <-nCh - // fmt.Println("in send loop\n") sTimer := time.Now() //send timer - _, err := tscipher.SendData(sendCarrier) - // fmt.Println("send") + n, err := tscipher.SendData(sendCarrier) log.Info("id %s time to send %d", id, time.Since(sTimer)/1000) if err != nil { log.Panic("Write panic. ID: %s, Err: %s, Remote Add: %s", id, err, dest.RemoteAddr().String()) } - // log.Debug("Write %d bytes from %s to %s. Tunnel: %s . 18 bytes %x", n, dest.LocalAddr(), dest.RemoteAddr().String(), id, sendCarrier.Cache[:18]) + log.Info("id %s Send %d bytes (local add %s),to %s", id, n, dest.LocalAddr().String(), dest.RemoteAddr().String()) } } - -func tunnelID() string { - nowString := time.Now().String() + strconv.Itoa(int(seed)) - atomic.AddInt32(&seed, 1) //避免多线程情况下获得的种子相同 - md5Byte := md5.Sum(bytes.NewBufferString(nowString).Bytes()) - return hex.EncodeToString(md5Byte[:]) -} diff --git a/tscipher/cipher.go b/tscipher/cipher.go index b73badb..d0af920 100644 --- a/tscipher/cipher.go +++ b/tscipher/cipher.go @@ -37,9 +37,8 @@ type Cipher interface { } type Carrier struct { - Conn net.Conn - Cipher Cipher - // Cache []byte + Conn net.Conn + Cipher Cipher Cache *cache.UnblockingQueueCache Msg *cache.BlockingQueueCache AttachedTunnelID string @@ -138,13 +137,17 @@ func UnwrapPackage(pacakge []byte) (data []byte, rest []byte, err error) { func SendData(carrier *Carrier) (n int, err error) { msg, nByte := carrier.Msg.Get() + id := carrier.AttachedTunnelID + log.Info("id %s Get Msg", id) if len(msg) < nByte { log.Panic("Cache of send is too small") } if carrier.Cipher == nil { n, err = carrier.Conn.Write(msg[:nByte]) + carrier.Cache.Put(make([]byte, 1024*4), 1024*4) return } + log.Info("id %s AAAAAAAaaa", id) encrypedByte, err := carrier.Cipher.Encrypt(msg[:nByte]) if err != nil { n = 0 @@ -154,16 +157,20 @@ func SendData(carrier *Carrier) (n int, err error) { wraped := WrapPackage(encrypedByte[:nByte]) n, err = carrier.Conn.Write(wraped) carrier.Cache.Put(make([]byte, 1024*4), 1024*4) + log.Info("id %s give back cache", id) return } func RowReceiveData(carrier *Carrier) (n int, err error) { cache, _ := carrier.Cache.Get() + log.Info("id %s get Cache", carrier.AttachedTunnelID) n, err = carrier.Conn.Read(cache) if err != nil { n = 0 } carrier.Msg.Put(cache, n) + id := carrier.AttachedTunnelID + log.Info("id %s put Msg", id) return } @@ -178,6 +185,7 @@ func ReceiveData(carrier *Carrier) (n int, err error) { var packageData []byte var _rest []byte cache, _ := carrier.Cache.Get() + log.Info("id %s get Cache", carrier.AttachedTunnelID) for { //首先检查这个是不是完整的包,是就返回好了,免得被阻塞 data, rest, err := UnwrapPackage(wrapedPackage) @@ -218,5 +226,7 @@ func ReceiveData(carrier *Carrier) (n int, err error) { } n = len(decrypted) carrier.Msg.Put(decrypted, n) + id := carrier.AttachedTunnelID + log.Info("id %s put Msg", id) return } diff --git a/utils/tunnelid.go b/utils/tunnelid.go new file mode 100644 index 0000000..e7e4a35 --- /dev/null +++ b/utils/tunnelid.go @@ -0,0 +1,23 @@ +package utils + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "strconv" + "sync/atomic" + "time" +) + +var seed int32 + +func init() { + seed = 0 +} + +func TunnelID() string { + nowString := time.Now().String() + strconv.Itoa(int(seed)) + atomic.AddInt32(&seed, 1) //避免多线程情况下获得的种子相同 + md5Byte := md5.Sum(bytes.NewBufferString(nowString).Bytes()) + return hex.EncodeToString(md5Byte[:]) +}