diff --git a/tcp.go b/tcp.go index b572a27..847d18b 100644 --- a/tcp.go +++ b/tcp.go @@ -1,6 +1,8 @@ package main import ( + "errors" + "fmt" "github.com/TransX/log" "net" "os" @@ -16,14 +18,24 @@ func NewTransTCP() *TransTCP { func (this *TransTCP) createTCPClient(ip, port string) (conn net.Conn, err error) { conn, err = net.Dial("tcp4", ip+":"+port) - if err == nil { - - } else { + if err != nil { conn = nil } return } +func (this *TransTCP) createTCPClientWithRetry(ip, port string, retry int) (conn net.Conn, err error) { + for i := 0; i < retry; i++ { + c, e := this.createTCPClient(ip, port) + if e == nil { + return c, e + } + log.Error("Create Client Error: %s", e.Error()) + } + //failed with retry + return nil, errors.New(fmt.Sprintln("failed to create client after %d retry", retry)) +} + func (this *TransTCP) createTCPListener(ip, port string) (listen net.Listener, err error) { listener, _err := net.Listen("tcp4", ip+":"+port) if _err == nil { @@ -48,7 +60,7 @@ func (this *TransTCP) Start(listenPort, destIP, destPort string, clientOrServer go func() { log.Info("Incoming %s", listenerConn.RemoteAddr().String()) //创建到目标的连接 - destConn, err := this.createTCPClient(destIP, destPort) + destConn, err := this.createTCPClientWithRetry(destIP, destPort, 3) if err != nil { log.Panic("Failed to connect to destination. %s", err) os.Exit(0)