diff --git a/main_test.go b/main_test.go index 00d321e..f28bdf1 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,7 @@ package main import ( "bufio" // "fmt" + "fmt" "github.com/TransX/log" "github.com/TransX/protocol" "io" @@ -10,7 +11,6 @@ import ( "os" "testing" "time" - "fmt" ) func serverBin(t *testing.T) { @@ -18,7 +18,7 @@ func serverBin(t *testing.T) { reader := bufio.NewReader(file) listener, err := net.Listen("tcp4", "127.0.0.1:1244") if err != nil { - log.Error("Test Server %s",err.Error()) + log.Error("Test Server %s", err.Error()) } var nCount byte nCount = 0 @@ -99,9 +99,9 @@ func serverBin(t *testing.T) { log.Info("Test Server. All Matches.") // log.Info("Test Server Receive %s", string(bytes[:n])) _, err = conn.Write([]byte("OK")) - log.Info("Test Server write OK from %s to %s",conn.LocalAddr().String(),conn.RemoteAddr().String(),) + log.Info("Test Server write OK from %s to %s", conn.LocalAddr().String(), conn.RemoteAddr().String()) if err != nil { - log.Error("Test Server %s ",err.Error()) + log.Error("Test Server %s ", err.Error()) } conn.Close() log.Info("Test Server closed") @@ -155,12 +155,12 @@ func clientBin(t *testing.T) { for { conn, err := net.Dial("tcp4", "127.0.0.1:1200") if err != nil { - log.Error("Test Client %s",err.Error()) + log.Error("Test Client %s", err.Error()) } binBytes := make([]byte, 1024*4) nBinBytes, err := reader.Read(binBytes) if err != nil { - if err==io.EOF{ + if err == io.EOF { fmt.Println("Test Finished.") break } @@ -178,9 +178,9 @@ func clientBin(t *testing.T) { nCount++ bytes := make([]byte, 1024*32) n, err = conn.Read(bytes) - log.Info("Test Client read % bytes (local add %s) msg:%s from %s",n,conn.LocalAddr().String(),bytes[:n],conn.RemoteAddr().String()) + log.Info("Test Client read %d bytes (local add %s) msg:%s from %s", n, conn.LocalAddr().String(), bytes[:n], conn.RemoteAddr().String()) if err != nil { - log.Error("Test Client %s",err.Error()) + log.Error("Test Client %s", err.Error()) } time.Sleep(time.Second * 0) conn.Close() diff --git a/model/tunnel.go b/model/tunnel.go index e25dd23..4982e2c 100644 --- a/model/tunnel.go +++ b/model/tunnel.go @@ -18,6 +18,7 @@ type Tunnel struct { cipherDirection constant.Direction regChan chan interface{} unregChan chan interface{} + unregistered bool } func NewTunnel(id string, src, dest net.Conn, cipherDirection constant.Direction) *Tunnel { @@ -26,6 +27,7 @@ func NewTunnel(id string, src, dest net.Conn, cipherDirection constant.Direction src: src, dest: dest, cipherDirection: cipherDirection, + unregistered: false, } } @@ -73,9 +75,11 @@ func (this *Tunnel) receive(revCarrier *tscipher.Carrier) { cipherDirection := this.cipherDirection id := this.id defer func() { - // log.Debug("tunnel id %s ends", id) //注销 - // this.unregChan <- this + if !this.unregistered { // 应该不存在异步问题 + this.unregChan <- this + this.unregistered = true + } if r := recover(); r != nil { if src != nil { src.Close() @@ -110,7 +114,10 @@ func (this *Tunnel) send(sendCarrier *tscipher.Carrier) { id := this.id defer func() { //注销 - this.unregChan <- this + if !this.unregistered { + this.unregChan <- this + this.unregistered = true + } if r := recover(); r != nil { if src != nil { src.Close() diff --git a/model/tunnelpair.go b/model/tunnelpair.go new file mode 100644 index 0000000..e6eb60d --- /dev/null +++ b/model/tunnelpair.go @@ -0,0 +1,35 @@ +package model + +type TunnelPair struct { + pair []*Tunnel +} + +func NewTunnelPair(a *Tunnel, b *Tunnel) *TunnelPair { + r := new(TunnelPair) + r.pair = make([]*Tunnel, 2) + p := r.pair + p[0] = a + p[1] = b + return r +} + +func (this *TunnelPair) Run() { + p := this.pair + for _, v := range p { + go v.Run() + } +} + +func (this *TunnelPair) SetRegChan(reg chan interface{}) { + p := this.pair + for _, v := range p { + v.SetRegChan(reg) + } +} + +func (this *TunnelPair) SetUnRegChan(unreg chan interface{}) { + p := this.pair + for _, v := range p { + v.SetUnRegChan(unreg) + } +} diff --git a/protocol/tcp.go b/protocol/tcp.go index 109f392..a24ff61 100644 --- a/protocol/tcp.go +++ b/protocol/tcp.go @@ -77,31 +77,39 @@ func (this *TransTCP) Start(listenPort, destIP, destPort string, clientOrServer } log.Info("Dial %s", destConn.RemoteAddr().String()) //tunnel model : [ -->>server ---- client -->> ](this is a tunnel) - if clientOrServer == "client" { + if clientOrServer == "client" { //加密方向 sendID := utils.TunnelID() ntSend := model.NewTunnel(sendID, listenerConn, destConn, constant.SEND) - ntSend.SetRegChan(tunMng.GetRegChan()) - ntSend.SetUnRegChan(tunMng.GetUnregChan()) + // ntSend.SetRegChan(tunMng.GetRegChan()) + // ntSend.SetUnRegChan(tunMng.GetUnregChan()) receiveID := utils.TunnelID() ntRev := model.NewTunnel(receiveID, destConn, listenerConn, constant.RECEIVE) - ntRev.SetRegChan(tunMng.GetRegChan()) - ntRev.SetUnRegChan(tunMng.GetUnregChan()) + // ntRev.SetRegChan(tunMng.GetRegChan()) + // ntRev.SetUnRegChan(tunMng.GetUnregChan()) printModelDetail(sendID, receiveID) - go ntSend.Run() - go ntRev.Run() + // go ntSend.Run() + // go ntRev.Run() + tunnelPair := model.NewTunnelPair(ntSend, ntRev) + tunnelPair.SetRegChan(tunMng.GetRegChan()) + tunnelPair.SetUnRegChan(tunMng.GetUnregChan()) + tunnelPair.Run() } if clientOrServer == "server" { receiveID := utils.TunnelID() ntRev := model.NewTunnel(receiveID, listenerConn, destConn, constant.RECEIVE) - ntRev.SetRegChan(tunMng.GetRegChan()) - ntRev.SetUnRegChan(tunMng.GetUnregChan()) + // ntRev.SetRegChan(tunMng.GetRegChan()) + // ntRev.SetUnRegChan(tunMng.GetUnregChan()) sendID := utils.TunnelID() ntSend := model.NewTunnel(sendID, destConn, listenerConn, constant.SEND) - ntSend.SetRegChan(tunMng.GetRegChan()) - ntSend.SetUnRegChan(tunMng.GetUnregChan()) + // ntSend.SetRegChan(tunMng.GetRegChan()) + // ntSend.SetUnRegChan(tunMng.GetUnregChan()) printModelDetail(sendID, receiveID) - go ntRev.Run() - go ntSend.Run() + // go ntRev.Run() + // go ntSend.Run() + tunnelPair := model.NewTunnelPair(ntSend, ntRev) + tunnelPair.SetRegChan(tunMng.GetRegChan()) + tunnelPair.SetUnRegChan(tunMng.GetUnregChan()) + tunnelPair.Run() } }()