mirror of https://github.com/fatedier/frp
				
				
				
			feat(nathole): use serverUDPPort in nathole discovery when available (#3382)
							parent
							
								
									a22d6c9504
								
							
						
					
					
						commit
						3faae194d0
					
				| 
						 | 
					@ -53,8 +53,12 @@ var natholeDiscoveryCmd = &cobra.Command{
 | 
				
			||||||
			os.Exit(1)
 | 
								os.Exit(1)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							serverAddr := ""
 | 
				
			||||||
 | 
							if cfg.ServerUDPPort != 0 {
 | 
				
			||||||
 | 
								serverAddr = net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		addresses, err := nathole.Discover(
 | 
							addresses, err := nathole.Discover(
 | 
				
			||||||
			net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort)),
 | 
								serverAddr,
 | 
				
			||||||
			[]string{cfg.NatHoleSTUNServer},
 | 
								[]string{cfg.NatHoleSTUNServer},
 | 
				
			||||||
			[]byte(cfg.Token),
 | 
								[]byte(cfg.Token),
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
| 
						 | 
					@ -62,6 +66,10 @@ var natholeDiscoveryCmd = &cobra.Command{
 | 
				
			||||||
			fmt.Println("discover error:", err)
 | 
								fmt.Println("discover error:", err)
 | 
				
			||||||
			os.Exit(1)
 | 
								os.Exit(1)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							if len(addresses) < 2 {
 | 
				
			||||||
 | 
								fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addresses)
 | 
				
			||||||
 | 
								os.Exit(1)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		natType, behavior, err := nathole.ClassifyNATType(addresses)
 | 
							natType, behavior, err := nathole.ClassifyNATType(addresses)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -79,8 +87,5 @@ func validateForNatHoleDiscovery(cfg config.ClientCommonConf) error {
 | 
				
			||||||
	if cfg.NatHoleSTUNServer == "" {
 | 
						if cfg.NatHoleSTUNServer == "" {
 | 
				
			||||||
		return fmt.Errorf("nat_hole_stun_server can not be empty")
 | 
							return fmt.Errorf("nat_hole_stun_server can not be empty")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if cfg.ServerUDPPort == 0 {
 | 
					 | 
				
			||||||
		return fmt.Errorf("server udp port can not be empty")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,31 +26,12 @@ import (
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var responseTimeout = 3 * time.Second
 | 
					var responseTimeout = 3 * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Address struct {
 | 
					 | 
				
			||||||
	IP   string
 | 
					 | 
				
			||||||
	Port int
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Message struct {
 | 
					type Message struct {
 | 
				
			||||||
	Body []byte
 | 
						Body []byte
 | 
				
			||||||
	Addr string
 | 
						Addr string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) {
 | 
					func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) {
 | 
				
			||||||
	// parse address to net.Address
 | 
					 | 
				
			||||||
	stunAddresses := make([]net.Addr, 0, len(stunServers))
 | 
					 | 
				
			||||||
	for _, stunServer := range stunServers {
 | 
					 | 
				
			||||||
		addr, err := net.ResolveUDPAddr("udp4", stunServer)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		stunAddresses = append(stunAddresses, addr)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	serverAddr, err := net.ResolveUDPAddr("udp4", serverAddress)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// create a discoverConn and get response from messageChan
 | 
						// create a discoverConn and get response from messageChan
 | 
				
			||||||
	discoverConn, err := listen()
 | 
						discoverConn, err := listen()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -61,90 +42,29 @@ func Discover(serverAddress string, stunServers []string, key []byte) ([]string,
 | 
				
			||||||
	go discoverConn.readLoop()
 | 
						go discoverConn.readLoop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	addresses := make([]string, 0, len(stunServers)+1)
 | 
						addresses := make([]string, 0, len(stunServers)+1)
 | 
				
			||||||
	// get external address from frp server
 | 
						if serverAddress != "" {
 | 
				
			||||||
	externalAddr, err := discoverFromServer(discoverConn, serverAddr, key)
 | 
							// get external address from frp server
 | 
				
			||||||
	if err != nil {
 | 
							externalAddr, err := discoverConn.discoverFromServer(serverAddress, key)
 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	addresses = append(addresses, externalAddr)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, stunAddr := range stunAddresses {
 | 
					 | 
				
			||||||
		// get external address from stun server
 | 
					 | 
				
			||||||
		externalAddr, err = discoverFromStunServer(discoverConn, stunAddr)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		addresses = append(addresses, externalAddr)
 | 
							addresses = append(addresses, externalAddr)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, addr := range stunServers {
 | 
				
			||||||
 | 
							// get external address from stun server
 | 
				
			||||||
 | 
							externalAddrs, err := discoverConn.discoverFromStunServer(addr)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							addresses = append(addresses, externalAddrs...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return addresses, nil
 | 
						return addresses, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func discoverFromServer(c *discoverConn, addr net.Addr, key []byte) (string, error) {
 | 
					type stunResponse struct {
 | 
				
			||||||
	m := &msg.NatHoleBinding{
 | 
						externalAddr string
 | 
				
			||||||
		TransactionID: NewTransactionID(),
 | 
						otherAddr    string
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	buf, err := EncodeMessage(m, key)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if _, err := c.conn.WriteTo(buf, addr); err != nil {
 | 
					 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var respMsg msg.NatHoleBindingResp
 | 
					 | 
				
			||||||
	select {
 | 
					 | 
				
			||||||
	case rawMsg := <-c.messageChan:
 | 
					 | 
				
			||||||
		if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
 | 
					 | 
				
			||||||
			return "", err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case <-time.After(responseTimeout):
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("wait response from frp server timeout")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if respMsg.TransactionID == "" {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("error format: no transaction id found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if respMsg.Error != "" {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return respMsg.Address, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func discoverFromStunServer(c *discoverConn, addr net.Addr) (string, error) {
 | 
					 | 
				
			||||||
	request, err := stun.Build(stun.TransactionID, stun.BindingRequest)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err = request.NewTransactionID(); err != nil {
 | 
					 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if _, err := c.conn.WriteTo(request.Raw, addr); err != nil {
 | 
					 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var m stun.Message
 | 
					 | 
				
			||||||
	select {
 | 
					 | 
				
			||||||
	case msg := <-c.messageChan:
 | 
					 | 
				
			||||||
		m.Raw = msg.Body
 | 
					 | 
				
			||||||
		if err := m.Decode(); err != nil {
 | 
					 | 
				
			||||||
			return "", err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case <-time.After(responseTimeout):
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("wait response from stun server timeout")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	xorAddr := &stun.XORMappedAddress{}
 | 
					 | 
				
			||||||
	mappedAddr := &stun.MappedAddress{}
 | 
					 | 
				
			||||||
	if err := xorAddr.GetFrom(&m); err == nil {
 | 
					 | 
				
			||||||
		return xorAddr.String(), nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := mappedAddr.GetFrom(&m); err == nil {
 | 
					 | 
				
			||||||
		return mappedAddr.String(), nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return "", fmt.Errorf("no address found")
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type discoverConn struct {
 | 
					type discoverConn struct {
 | 
				
			||||||
| 
						 | 
					@ -190,3 +110,115 @@ func (c *discoverConn) readLoop() {
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) {
 | 
				
			||||||
 | 
						serverAddr, err := net.ResolveUDPAddr("udp4", addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						request, err := stun.Build(stun.TransactionID, stun.BindingRequest)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = request.NewTransactionID(); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if _, err := c.conn.WriteTo(request.Raw, serverAddr); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var m stun.Message
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case msg := <-c.messageChan:
 | 
				
			||||||
 | 
							m.Raw = msg.Body
 | 
				
			||||||
 | 
							if err := m.Decode(); err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case <-time.After(responseTimeout):
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("wait response from stun server timeout")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						xorAddrGetter := &stun.XORMappedAddress{}
 | 
				
			||||||
 | 
						mappedAddrGetter := &stun.MappedAddress{}
 | 
				
			||||||
 | 
						changedAddrGetter := ChangedAddress{}
 | 
				
			||||||
 | 
						otherAddrGetter := &stun.OtherAddress{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp := &stunResponse{}
 | 
				
			||||||
 | 
						if err := mappedAddrGetter.GetFrom(&m); err == nil {
 | 
				
			||||||
 | 
							resp.externalAddr = mappedAddrGetter.String()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := xorAddrGetter.GetFrom(&m); err == nil {
 | 
				
			||||||
 | 
							resp.externalAddr = xorAddrGetter.String()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := changedAddrGetter.GetFrom(&m); err == nil {
 | 
				
			||||||
 | 
							resp.otherAddr = changedAddrGetter.String()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := otherAddrGetter.GetFrom(&m); err == nil {
 | 
				
			||||||
 | 
							resp.otherAddr = otherAddrGetter.String()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return resp, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) {
 | 
				
			||||||
 | 
						addr, err := net.ResolveUDPAddr("udp4", serverAddress)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						m := &msg.NatHoleBinding{
 | 
				
			||||||
 | 
							TransactionID: NewTransactionID(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						buf, err := EncodeMessage(m, key)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := c.conn.WriteTo(buf, addr); err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var respMsg msg.NatHoleBindingResp
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case rawMsg := <-c.messageChan:
 | 
				
			||||||
 | 
							if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case <-time.After(responseTimeout):
 | 
				
			||||||
 | 
							return "", fmt.Errorf("wait response from frp server timeout")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if respMsg.TransactionID == "" {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("error format: no transaction id found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if respMsg.Error != "" {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return respMsg.Address, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) {
 | 
				
			||||||
 | 
						resp, err := c.doSTUNRequest(addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if resp.externalAddr == "" {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("no external address found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						externalAddrs := make([]string, 0, 2)
 | 
				
			||||||
 | 
						externalAddrs = append(externalAddrs, resp.externalAddr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if resp.otherAddr == "" {
 | 
				
			||||||
 | 
							return externalAddrs, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// find external address from changed address
 | 
				
			||||||
 | 
						resp, err = c.doSTUNRequest(resp.otherAddr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if resp.externalAddr != "" {
 | 
				
			||||||
 | 
							externalAddrs = append(externalAddrs, resp.externalAddr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return externalAddrs, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,8 +16,11 @@ package nathole
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/fatedier/golib/crypto"
 | 
						"github.com/fatedier/golib/crypto"
 | 
				
			||||||
 | 
						"github.com/pion/stun"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/fatedier/frp/pkg/msg"
 | 
						"github.com/fatedier/frp/pkg/msg"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -46,3 +49,17 @@ func DecodeMessageInto(data, key []byte, m msg.Message) error {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChangedAddress struct {
 | 
				
			||||||
 | 
						IP   net.IP
 | 
				
			||||||
 | 
						Port int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *ChangedAddress) GetFrom(m *stun.Message) error {
 | 
				
			||||||
 | 
						a := (*stun.MappedAddress)(s)
 | 
				
			||||||
 | 
						return a.GetFromAs(m, stun.AttrChangedAddress)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *ChangedAddress) String() string {
 | 
				
			||||||
 | 
						return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue