diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index f08d3ead..55e87a7c 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -20,6 +20,8 @@ import ( "io" "io/ioutil" "net" + "strconv" + "strings" "sync" "time" @@ -280,25 +282,56 @@ func (pxy *XtcpProxy) InWorkConn(conn frpNet.Conn) { return } - pxy.Trace("get natHoleRespMsg, sid [%s], client address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr) + pxy.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) - // Send sid to visitor udp address. - time.Sleep(time.Second) + // Send detect message + array := strings.Split(natHoleRespMsg.VisitorAddr, ":") + if len(array) <= 1 { + pxy.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) + } laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) - daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.VisitorAddr) + /* + for i := 1000; i < 65000; i++ { + pxy.sendDetectMsg(array[0], int64(i), laddr, "a") + } + */ + port, err := strconv.ParseInt(array[1], 10, 64) if err != nil { - pxy.Error("resolve visitor udp address error: %v", err) + pxy.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) return } + pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) + pxy.Trace("send all detect msg done") - lConn, err := net.DialUDP("udp", laddr, daddr) + msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}) + + // Listen for clientConn's address and wait for visitor connection + lConn, err := net.ListenUDP("udp", laddr) if err != nil { - pxy.Error("dial visitor udp address error: %v", err) + pxy.Error("listen on visitorConn's local adress error: %v", err) return } - lConn.Write([]byte(natHoleRespMsg.Sid)) + defer lConn.Close() - kcpConn, err := frpNet.NewKcpConnFromUdp(lConn, true, natHoleRespMsg.VisitorAddr) + lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) + sidBuf := pool.GetBuf(1024) + var uAddr *net.UDPAddr + n, uAddr, err = lConn.ReadFromUDP(sidBuf) + if err != nil { + pxy.Warn("get sid from visitor error: %v", err) + return + } + lConn.SetReadDeadline(time.Time{}) + if string(sidBuf[:n]) != natHoleRespMsg.Sid { + pxy.Warn("incorrect sid from visitor") + return + } + pool.PutBuf(sidBuf) + pxy.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + + lConn.WriteToUDP(sidBuf[:n], uAddr) + + kcpConn, err := frpNet.NewKcpConnFromUdp(lConn, false, natHoleRespMsg.VisitorAddr) if err != nil { pxy.Error("create kcp connection from udp connection error: %v", err) return @@ -323,6 +356,25 @@ func (pxy *XtcpProxy) InWorkConn(conn frpNet.Conn) { frpNet.WrapConn(muxConn), []byte(pxy.cfg.Sk)) } +func (pxy *XtcpProxy) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) { + daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + return err + } + + tConn, err := net.DialUDP("udp", laddr, daddr) + if err != nil { + return err + } + + //uConn := ipv4.NewConn(tConn) + //uConn.SetTTL(3) + + tConn.Write(content) + tConn.Close() + return nil +} + // UDP type UdpProxy struct { *BaseProxy diff --git a/client/visitor.go b/client/visitor.go index 35abcf5a..6eb3688c 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -20,13 +20,9 @@ import ( "io" "io/ioutil" "net" - "strconv" - "strings" "sync" "time" - "golang.org/x/net/ipv4" - "github.com/fatedier/frp/g" "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" @@ -251,42 +247,31 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { return } - sv.Trace("get natHoleRespMsg, sid [%s], client address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr) + sv.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) // Close visitorConn, so we can use it's local address. visitorConn.Close() - // Send detect message. - array := strings.Split(natHoleRespMsg.ClientAddr, ":") - if len(array) <= 1 { - sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr) - return - } + // send sid message to client laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) - /* - for i := 1000; i < 65000; i++ { - sv.sendDetectMsg(array[0], int64(i), laddr, "a") - } - */ - port, err := strconv.ParseInt(array[1], 10, 64) + daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) if err != nil { - sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr) + sv.Error("resolve client udp address error: %v", err) return } - sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) - sv.Trace("send all detect msg done") - - // Listen for visitorConn's address and wait for client connection. - lConn, err := net.ListenUDP("udp", laddr) + lConn, err := net.DialUDP("udp", laddr, daddr) if err != nil { - sv.Error("listen on visitorConn's local adress error: %v", err) + sv.Error("dial client udp address error: %v", err) return } defer lConn.Close() - lConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + lConn.Write([]byte(natHoleRespMsg.Sid)) + + // read ack sid from client sidBuf := pool.GetBuf(1024) - n, _, err = lConn.ReadFromUDP(sidBuf) + lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) + n, err = lConn.Read(sidBuf) if err != nil { sv.Warn("get sid from client error: %v", err) return @@ -296,11 +281,13 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { sv.Warn("incorrect sid from client") return } - sv.Info("nat hole connection make success, sid [%s]", string(sidBuf[:n])) pool.PutBuf(sidBuf) + sv.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + + // wrap kcp connection var remote io.ReadWriteCloser - remote, err = frpNet.NewKcpConnFromUdp(lConn, false, natHoleRespMsg.ClientAddr) + remote, err = frpNet.NewKcpConnFromUdp(lConn, true, natHoleRespMsg.ClientAddr) if err != nil { sv.Error("create kcp connection from udp connection error: %v", err) return @@ -336,22 +323,3 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { frpIo.Join(userConn, muxConn) sv.Debug("join connections closed") } - -func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) { - daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - return err - } - - tConn, err := net.DialUDP("udp", laddr, daddr) - if err != nil { - return err - } - - uConn := ipv4.NewConn(tConn) - uConn.SetTTL(3) - - tConn.Write(content) - tConn.Close() - return nil -} diff --git a/models/msg/msg.go b/models/msg/msg.go index e06fa371..2d5985c4 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -17,44 +17,46 @@ package msg import "net" const ( - TypeLogin = 'o' - TypeLoginResp = '1' - TypeNewProxy = 'p' - TypeNewProxyResp = '2' - TypeCloseProxy = 'c' - TypeNewWorkConn = 'w' - TypeReqWorkConn = 'r' - TypeStartWorkConn = 's' - TypeNewVisitorConn = 'v' - TypeNewVisitorConnResp = '3' - TypePing = 'h' - TypePong = '4' - TypeUdpPacket = 'u' - TypeNatHoleVisitor = 'i' - TypeNatHoleClient = 'n' - TypeNatHoleResp = 'm' - TypeNatHoleSid = '5' + TypeLogin = 'o' + TypeLoginResp = '1' + TypeNewProxy = 'p' + TypeNewProxyResp = '2' + TypeCloseProxy = 'c' + TypeNewWorkConn = 'w' + TypeReqWorkConn = 'r' + TypeStartWorkConn = 's' + TypeNewVisitorConn = 'v' + TypeNewVisitorConnResp = '3' + TypePing = 'h' + TypePong = '4' + TypeUdpPacket = 'u' + TypeNatHoleVisitor = 'i' + TypeNatHoleClient = 'n' + TypeNatHoleResp = 'm' + TypeNatHoleClientDetectOK = 'd' + TypeNatHoleSid = '5' ) var ( msgTypeMap = map[byte]interface{}{ - TypeLogin: Login{}, - TypeLoginResp: LoginResp{}, - TypeNewProxy: NewProxy{}, - TypeNewProxyResp: NewProxyResp{}, - TypeCloseProxy: CloseProxy{}, - TypeNewWorkConn: NewWorkConn{}, - TypeReqWorkConn: ReqWorkConn{}, - TypeStartWorkConn: StartWorkConn{}, - TypeNewVisitorConn: NewVisitorConn{}, - TypeNewVisitorConnResp: NewVisitorConnResp{}, - TypePing: Ping{}, - TypePong: Pong{}, - TypeUdpPacket: UdpPacket{}, - TypeNatHoleVisitor: NatHoleVisitor{}, - TypeNatHoleClient: NatHoleClient{}, - TypeNatHoleResp: NatHoleResp{}, - TypeNatHoleSid: NatHoleSid{}, + TypeLogin: Login{}, + TypeLoginResp: LoginResp{}, + TypeNewProxy: NewProxy{}, + TypeNewProxyResp: NewProxyResp{}, + TypeCloseProxy: CloseProxy{}, + TypeNewWorkConn: NewWorkConn{}, + TypeReqWorkConn: ReqWorkConn{}, + TypeStartWorkConn: StartWorkConn{}, + TypeNewVisitorConn: NewVisitorConn{}, + TypeNewVisitorConnResp: NewVisitorConnResp{}, + TypePing: Ping{}, + TypePong: Pong{}, + TypeUdpPacket: UdpPacket{}, + TypeNatHoleVisitor: NatHoleVisitor{}, + TypeNatHoleClient: NatHoleClient{}, + TypeNatHoleResp: NatHoleResp{}, + TypeNatHoleClientDetectOK: NatHoleClientDetectOK{}, + TypeNatHoleSid: NatHoleSid{}, } ) @@ -169,6 +171,9 @@ type NatHoleResp struct { Error string `json:"error"` } +type NatHoleClientDetectOK struct { +} + type NatHoleSid struct { Sid string `json:"sid"` } diff --git a/models/nathole/nathole.go b/models/nathole/nathole.go index 1e120ae2..0c33dfe4 100644 --- a/models/nathole/nathole.go +++ b/models/nathole/nathole.go @@ -18,6 +18,11 @@ import ( // Timeout seconds. var NatHoleTimeout int64 = 10 +type SidRequest struct { + Sid string + NotifyCh chan struct{} +} + type NatHoleController struct { listener *net.UDPConn @@ -44,11 +49,11 @@ func NewNatHoleController(udpBindAddr string) (nc *NatHoleController, err error) return nc, nil } -func (nc *NatHoleController) ListenClient(name string, sk string) (sidCh chan string) { +func (nc *NatHoleController) ListenClient(name string, sk string) (sidCh chan *SidRequest) { clientCfg := &NatHoleClientCfg{ Name: name, Sk: sk, - SidCh: make(chan string), + SidCh: make(chan *SidRequest), } nc.mu.Lock() nc.clientCfgs[name] = clientCfg @@ -132,7 +137,10 @@ func (nc *NatHoleController) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDP }() err := errors.PanicToError(func() { - clientCfg.SidCh <- sid + clientCfg.SidCh <- &SidRequest{ + Sid: sid, + NotifyCh: session.NotifyCh, + } }) if err != nil { return @@ -158,7 +166,6 @@ func (nc *NatHoleController) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAd } log.Trace("handle client message, sid [%s]", session.Sid) session.ClientAddr = raddr - session.NotifyCh <- struct{}{} resp := nc.GenNatHoleResponse(session, "") log.Trace("send nat hole response to client") @@ -201,5 +208,5 @@ type NatHoleSession struct { type NatHoleClientCfg struct { Name string Sk string - SidCh chan string + SidCh chan *SidRequest } diff --git a/server/proxy/xtcp.go b/server/proxy/xtcp.go index 9c5f9112..87266669 100644 --- a/server/proxy/xtcp.go +++ b/server/proxy/xtcp.go @@ -42,18 +42,40 @@ func (pxy *XtcpProxy) Run() (remoteAddr string, err error) { select { case <-pxy.closeCh: break - case sid := <-sidCh: + case sidRequest := <-sidCh: + sr := sidRequest workConn, errRet := pxy.GetWorkConnFromPool() if errRet != nil { continue } m := &msg.NatHoleSid{ - Sid: sid, + Sid: sr.Sid, } errRet = msg.WriteMsg(workConn, m) if errRet != nil { pxy.Warn("write nat hole sid package error, %v", errRet) + workConn.Close() + break } + + go func() { + raw, errRet := msg.ReadMsg(workConn) + if errRet != nil { + pxy.Warn("read nat hole client ok package error: %v", errRet) + workConn.Close() + return + } + if _, ok := raw.(*msg.NatHoleClientDetectOK); !ok { + pxy.Warn("read nat hole client ok package format error") + workConn.Close() + return + } + + select { + case sr.NotifyCh <- struct{}{}: + default: + } + }() } } }()