From ccb5027246934054803a536053c81bb1d8d64100 Mon Sep 17 00:00:00 2001 From: aarch64 <48624112+arm64v8a@users.noreply.github.com> Date: Sat, 1 May 2021 23:17:12 +0800 Subject: [PATCH] Fix XTCP timeout & Symmetric NAT client See #1585 #1795 --- client/visitor.go | 122 +++++++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 49 deletions(-) diff --git a/client/visitor.go b/client/visitor.go index 36f9bada..d4c6b55e 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -218,11 +218,12 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { return } - visitorConn, err := net.DialUDP("udp", nil, raddr) + visitorConn0, err := net.ListenUDP("udp", nil) if err != nil { xl.Warn("dial server udp addr error: %v", err) return } + visitorConn := &visitorConnWriter{UDPConn: visitorConn0, RemoteAddr: raddr} defer visitorConn.Close() now := time.Now().Unix() @@ -237,70 +238,84 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { return } - // Wait for client address at most 10 seconds. var natHoleRespMsg msg.NatHoleResp - visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) - buf := pool.GetBuf(1024) - n, err := visitorConn.Read(buf) - if err != nil { - xl.Warn("get natHoleRespMsg error: %v", err) + var daddr_true *net.UDPAddr // Client's address may change! + var daddr *net.UDPAddr + var sid []byte + + // natHoleRespMsg may come later than client's sid, so we do this. + for { + // Wait for client address at most 10 seconds. + visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + buf := pool.GetBuf(1024) + n, tmpaddr, err := visitorConn.ReadFromUDP(buf) + if err != nil { + if daddr_true == nil { + xl.Warn("get natHoleRespMsg error: %v", err) + } else { + xl.Warn("read sid from client error: %v", err) + } + return + } + + // Received from server or client? + if tmpaddr.String() == raddr.String() { + err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) + if err != nil { + xl.Warn("get natHoleRespMsg error: %v", err) + return + } + visitorConn.SetReadDeadline(time.Time{}) + pool.PutBuf(buf) + + if natHoleRespMsg.Error != "" { + xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) + return + } + + xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) + + daddr, err = net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) + if err != nil { + xl.Error("resolve client udp address error: %v", err) + return + } + + // send sid message to client + visitorConn.WriteToUDP([]byte(natHoleRespMsg.Sid), daddr) + } else { + daddr_true = tmpaddr + sid = buf[:n] + } + + if daddr_true != nil && natHoleRespMsg.ClientAddr != "" { + break + } + } + + if string(sid) != natHoleRespMsg.Sid { + xl.Warn("incorrect sid from client") return } - err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) - if err != nil { - xl.Warn("get natHoleRespMsg error: %v", err) - return - } - visitorConn.SetReadDeadline(time.Time{}) - pool.PutBuf(buf) + xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) - if natHoleRespMsg.Error != "" { - xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) - return - } - - xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) - - // Close visitorConn, so we can use it's local address. + // send sid message to client again + visitorConn.WriteToUDP([]byte(natHoleRespMsg.Sid), daddr_true) visitorConn.Close() - // send sid message to client + // make true connection laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) - daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) - if err != nil { - xl.Error("resolve client udp address error: %v", err) - return - } - lConn, err := net.DialUDP("udp", laddr, daddr) + lConn, err := net.DialUDP("udp", laddr, daddr_true) if err != nil { xl.Error("dial client udp address error: %v", err) return } defer lConn.Close() - lConn.Write([]byte(natHoleRespMsg.Sid)) - - // read ack sid from client - sidBuf := pool.GetBuf(1024) - lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) - n, err = lConn.Read(sidBuf) - if err != nil { - xl.Warn("get sid from client error: %v", err) - return - } - lConn.SetReadDeadline(time.Time{}) - if string(sidBuf[:n]) != natHoleRespMsg.Sid { - xl.Warn("incorrect sid from client") - return - } - pool.PutBuf(sidBuf) - - xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) - // wrap kcp connection var remote io.ReadWriteCloser - remote, err = frpNet.NewKCPConnFromUDP(lConn, true, natHoleRespMsg.ClientAddr) + remote, err = frpNet.NewKCPConnFromUDP(lConn, true, daddr_true.String()) if err != nil { xl.Error("create kcp connection from udp connection error: %v", err) return @@ -552,3 +567,12 @@ func (sv *SUDPVisitor) Close() { close(sv.readCh) close(sv.sendCh) } + +type visitorConnWriter struct { + *net.UDPConn + RemoteAddr *net.UDPAddr +} + +func (p *visitorConnWriter) Write(b []byte) (int, error) { + return p.WriteToUDP(b, p.RemoteAddr) +}