Fix XTCP timeout & Symmetric NAT client

See #1585 #1795
pull/2376/head
aarch64 2021-05-01 23:17:12 +08:00 committed by arm64v8a
parent 2408f1df04
commit ccb5027246
1 changed files with 73 additions and 49 deletions

View File

@ -218,11 +218,12 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
return return
} }
visitorConn, err := net.DialUDP("udp", nil, raddr) visitorConn0, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
xl.Warn("dial server udp addr error: %v", err) xl.Warn("dial server udp addr error: %v", err)
return return
} }
visitorConn := &visitorConnWriter{UDPConn: visitorConn0, RemoteAddr: raddr}
defer visitorConn.Close() defer visitorConn.Close()
now := time.Now().Unix() now := time.Now().Unix()
@ -237,70 +238,84 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
return return
} }
// Wait for client address at most 10 seconds.
var natHoleRespMsg msg.NatHoleResp var natHoleRespMsg msg.NatHoleResp
visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) var daddr_true *net.UDPAddr // Client's address may change!
buf := pool.GetBuf(1024) var daddr *net.UDPAddr
n, err := visitorConn.Read(buf) var sid []byte
if err != nil {
xl.Warn("get natHoleRespMsg error: %v", err) // natHoleRespMsg may come later than client's sid, so we do this.
for {
// Wait for client address at most 10 seconds.
visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
buf := pool.GetBuf(1024)
n, tmpaddr, err := visitorConn.ReadFromUDP(buf)
if err != nil {
if daddr_true == nil {
xl.Warn("get natHoleRespMsg error: %v", err)
} else {
xl.Warn("read sid from client error: %v", err)
}
return
}
// Received from server or client?
if tmpaddr.String() == raddr.String() {
err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg)
if err != nil {
xl.Warn("get natHoleRespMsg error: %v", err)
return
}
visitorConn.SetReadDeadline(time.Time{})
pool.PutBuf(buf)
if natHoleRespMsg.Error != "" {
xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error)
return
}
xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
daddr, err = net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr)
if err != nil {
xl.Error("resolve client udp address error: %v", err)
return
}
// send sid message to client
visitorConn.WriteToUDP([]byte(natHoleRespMsg.Sid), daddr)
} else {
daddr_true = tmpaddr
sid = buf[:n]
}
if daddr_true != nil && natHoleRespMsg.ClientAddr != "" {
break
}
}
if string(sid) != natHoleRespMsg.Sid {
xl.Warn("incorrect sid from client")
return return
} }
err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid)
if err != nil {
xl.Warn("get natHoleRespMsg error: %v", err)
return
}
visitorConn.SetReadDeadline(time.Time{})
pool.PutBuf(buf)
if natHoleRespMsg.Error != "" { // send sid message to client again
xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) visitorConn.WriteToUDP([]byte(natHoleRespMsg.Sid), daddr_true)
return
}
xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
// Close visitorConn, so we can use it's local address.
visitorConn.Close() visitorConn.Close()
// send sid message to client // make true connection
laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String())
daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) lConn, err := net.DialUDP("udp", laddr, daddr_true)
if err != nil {
xl.Error("resolve client udp address error: %v", err)
return
}
lConn, err := net.DialUDP("udp", laddr, daddr)
if err != nil { if err != nil {
xl.Error("dial client udp address error: %v", err) xl.Error("dial client udp address error: %v", err)
return return
} }
defer lConn.Close() defer lConn.Close()
lConn.Write([]byte(natHoleRespMsg.Sid))
// read ack sid from client
sidBuf := pool.GetBuf(1024)
lConn.SetReadDeadline(time.Now().Add(8 * time.Second))
n, err = lConn.Read(sidBuf)
if err != nil {
xl.Warn("get sid from client error: %v", err)
return
}
lConn.SetReadDeadline(time.Time{})
if string(sidBuf[:n]) != natHoleRespMsg.Sid {
xl.Warn("incorrect sid from client")
return
}
pool.PutBuf(sidBuf)
xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid)
// wrap kcp connection // wrap kcp connection
var remote io.ReadWriteCloser var remote io.ReadWriteCloser
remote, err = frpNet.NewKCPConnFromUDP(lConn, true, natHoleRespMsg.ClientAddr) remote, err = frpNet.NewKCPConnFromUDP(lConn, true, daddr_true.String())
if err != nil { if err != nil {
xl.Error("create kcp connection from udp connection error: %v", err) xl.Error("create kcp connection from udp connection error: %v", err)
return return
@ -552,3 +567,12 @@ func (sv *SUDPVisitor) Close() {
close(sv.readCh) close(sv.readCh)
close(sv.sendCh) close(sv.sendCh)
} }
type visitorConnWriter struct {
*net.UDPConn
RemoteAddr *net.UDPAddr
}
func (p *visitorConnWriter) Write(b []byte) (int, error) {
return p.WriteToUDP(b, p.RemoteAddr)
}