From 7793f5554518d785d1b69b3818ece4d001972907 Mon Sep 17 00:00:00 2001 From: fatedier Date: Fri, 10 Aug 2018 11:43:08 +0800 Subject: [PATCH] websocket: update muxer for websocket --- models/config/client_common.go | 7 +- server/service.go | 53 ++++---------- utils/net/conn.go | 102 +++++++++++++++++--------- utils/net/websocket.go | 130 ++++++++++++++------------------- 4 files changed, 140 insertions(+), 152 deletions(-) diff --git a/models/config/client_common.go b/models/config/client_common.go index c1d61cb..5dc49aa 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -186,9 +186,10 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c } if tmpStr, ok = conf.Get("common", "protocol"); ok { - // Now it only support tcp and kcp. - if tmpStr != "kcp" && tmpStr != "websocket" { - tmpStr = "tcp" + // Now it only support tcp and kcp and websocket. + if tmpStr != "tcp" && tmpStr != "kcp" && tmpStr != "websocket" { + err = fmt.Errorf("Parse conf error: invalid protocol") + return } cfg.Protocol = tmpStr } diff --git a/server/service.go b/server/service.go index dcb7a2b..024b683 100644 --- a/server/service.go +++ b/server/service.go @@ -15,11 +15,11 @@ package server import ( + "bytes" "fmt" "io/ioutil" "net" "net/http" - "strings" "time" "github.com/fatedier/frp/assets" @@ -139,6 +139,13 @@ func NewService() (svr *Service, err error) { log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KcpBindPort) } + // Listen for accepting connections from client using websocket protocol. + websocketPrefix := []byte("GET /%23frp") + websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool { + return bytes.Equal(data, websocketPrefix) + }) + svr.websocketListener = frpNet.NewWebsocketListener(websocketLn) + // Create http vhost muxer. if cfg.VhostHttpPort > 0 { rp := vhost.NewHttpReverseProxy() @@ -150,7 +157,9 @@ func NewService() (svr *Service, err error) { Handler: rp, } var l net.Listener - if !httpMuxOn { + if httpMuxOn { + l = svr.muxer.ListenHttp(1) + } else { l, err = net.Listen("tcp", address) if err != nil { err = fmt.Errorf("Create vhost http listener error, %v", err) @@ -165,7 +174,7 @@ func NewService() (svr *Service, err error) { if cfg.VhostHttpsPort > 0 { var l net.Listener if httpsMuxOn { - l = svr.muxer.ListenHttps(0) + l = svr.muxer.ListenHttps(1) } else { l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort)) if err != nil { @@ -205,37 +214,6 @@ func NewService() (svr *Service, err error) { log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort) } - if !httpMuxOn { - svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), nil) - return - } - - // server := &http.Server{} - if httpMuxOn { - rp := svr.httpReverseProxy - svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), - func(w http.ResponseWriter, req *http.Request) bool { - domain := getHostFromAddr(req.Host) - location := req.URL.Path - headers := rp.GetHeaders(domain, location) - if headers == nil { - return true - } - rp.ServeHTTP(w, req) - return false - }) - } - - return -} - -func getHostFromAddr(addr string) (host string) { - strs := strings.Split(addr, ":") - if len(strs) > 1 { - host = strs[0] - } else { - host = addr - } return } @@ -246,9 +224,9 @@ func (svr *Service) Run() { if g.GlbServerCfg.KcpBindPort > 0 { go svr.HandleListener(svr.kcpListener) } - if svr.websocketListener != nil { - go svr.HandleListener(svr.websocketListener) - } + + go svr.HandleListener(svr.websocketListener) + svr.HandleListener(svr.listener) } @@ -260,6 +238,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) { log.Warn("Listener for incoming connections from client closed") return } + // Start a new goroutine for dealing connections. go func(frpConn frpNet.Conn) { dealFn := func(conn frpNet.Conn) { diff --git a/utils/net/conn.go b/utils/net/conn.go index 825a989..6dab2bd 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -96,47 +96,34 @@ func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func ConnectServer(protocol string, addr string) (c Conn, err error) { - switch protocol { - case "tcp": - return ConnectTcpServer(addr) - case "kcp": - kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) - if errRet != nil { - err = errRet - return - } - kcpConn.SetStreamMode(true) - kcpConn.SetWriteDelay(true) - kcpConn.SetNoDelay(1, 20, 2, 1) - kcpConn.SetWindowSize(128, 512) - kcpConn.SetMtu(1350) - kcpConn.SetACKNoDelay(false) - kcpConn.SetReadBuffer(4194304) - kcpConn.SetWriteBuffer(4194304) - c = WrapConn(kcpConn) - return - default: - return nil, fmt.Errorf("unsupport protocol: %s", protocol) +type CloseNotifyConn struct { + net.Conn + log.Logger + + // 1 means closed + closeFlag int32 + + closeFn func() +} + +// closeFn will be only called once +func WrapCloseNotifyConn(c net.Conn, closeFn func()) Conn { + return &CloseNotifyConn{ + Conn: c, + Logger: log.NewPrefixLogger(""), + closeFn: closeFn, } } -func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn, err error) { - switch protocol { - case "tcp": - var conn net.Conn - if conn, err = gnet.DialTcpByProxy(proxyUrl, addr); err != nil { - return +func (cc *CloseNotifyConn) Close() (err error) { + pflag := atomic.SwapInt32(&cc.closeFlag, 1) + if pflag == 0 { + err = cc.Close() + if cc.closeFn != nil { + cc.closeFn() } - return WrapConn(conn), nil - case "kcp": - // http proxy is not supported for kcp - return ConnectServer(protocol, addr) - case "websocket": - return ConnectWebsocketServer(addr) - default: - return nil, fmt.Errorf("unsupport protocol: %s", protocol) } + return } type StatsConn struct { @@ -177,3 +164,46 @@ func (statsConn *StatsConn) Close() (err error) { } return } + +func ConnectServer(protocol string, addr string) (c Conn, err error) { + switch protocol { + case "tcp": + return ConnectTcpServer(addr) + case "kcp": + kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) + if errRet != nil { + err = errRet + return + } + kcpConn.SetStreamMode(true) + kcpConn.SetWriteDelay(true) + kcpConn.SetNoDelay(1, 20, 2, 1) + kcpConn.SetWindowSize(128, 512) + kcpConn.SetMtu(1350) + kcpConn.SetACKNoDelay(false) + kcpConn.SetReadBuffer(4194304) + kcpConn.SetWriteBuffer(4194304) + c = WrapConn(kcpConn) + return + default: + return nil, fmt.Errorf("unsupport protocol: %s", protocol) + } +} + +func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn, err error) { + switch protocol { + case "tcp": + var conn net.Conn + if conn, err = gnet.DialTcpByProxy(proxyUrl, addr); err != nil { + return + } + return WrapConn(conn), nil + case "kcp": + // http proxy is not supported for kcp + return ConnectServer(protocol, addr) + case "websocket": + return ConnectWebsocketServer(addr) + default: + return nil, fmt.Errorf("unsupport protocol: %s", protocol) + } +} diff --git a/utils/net/websocket.go b/utils/net/websocket.go index 0411112..a3bf0f0 100644 --- a/utils/net/websocket.go +++ b/utils/net/websocket.go @@ -1,127 +1,105 @@ package net import ( + "errors" "fmt" "net" "net/http" "net/url" - "sync/atomic" "time" "github.com/fatedier/frp/utils/log" + "golang.org/x/net/websocket" ) +var ( + ErrWebsocketListenerClosed = errors.New("websocket listener closed") +) + +const ( + FrpWebsocketPath = "/#frp" +) + type WebsocketListener struct { + net.Addr + ln net.Listener + accept chan Conn log.Logger + server *http.Server httpMutex *http.ServeMux - connChan chan *WebsocketConn - closeFlag bool } -func NewWebsocketListener(ln net.Listener, - filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) { - l = &WebsocketListener{ - httpMutex: http.NewServeMux(), - connChan: make(chan *WebsocketConn), - Logger: log.NewPrefixLogger(""), +// ln: tcp listener for websocket connections +func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { + wl = &WebsocketListener{ + Addr: ln.Addr(), + accept: make(chan Conn), + Logger: log.NewPrefixLogger(""), } - l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) { - conn := NewWebScoketConn(c) - l.connChan <- conn - conn.waitClose() + + muxer := http.NewServeMux() + muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) { + notifyCh := make(chan struct{}) + conn := WrapCloseNotifyConn(c, func() { + close(notifyCh) + }) + wl.accept <- conn + <-notifyCh })) - l.server = &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if filter != nil && !filter(w, r) { - return - } - l.httpMutex.ServeHTTP(w, r) - }), + + wl.server = &http.Server{ + Addr: ln.Addr().String(), + Handler: muxer, } - ch := make(chan struct{}) - go func() { - close(ch) - err = l.server.Serve(ln) - }() - <-ch - <-time.After(time.Millisecond) + + go wl.server.Serve(ln) return } -func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) { - ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) +func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) { + tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { - return + return nil, err } - l, err = NewWebsocketListener(ln, nil) - return + l := NewWebsocketListener(tcpLn) + return l, nil } func (p *WebsocketListener) Accept() (Conn, error) { - c := <-p.connChan + c, ok := <-p.accept + if !ok { + return nil, ErrWebsocketListenerClosed + } return c, nil } func (p *WebsocketListener) Close() error { - if !p.closeFlag { - p.closeFlag = true - p.server.Close() - } - return nil -} - -type WebsocketConn struct { - net.Conn - log.Logger - closed int32 - wait chan struct{} -} - -func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) { - c = &WebsocketConn{ - Conn: conn, - Logger: log.NewPrefixLogger(""), - wait: make(chan struct{}), - } - return + return p.server.Close() } -func (p *WebsocketConn) Close() error { - if atomic.SwapInt32(&p.closed, 1) == 1 { - return nil - } - close(p.wait) - return p.Conn.Close() -} - -func (p *WebsocketConn) waitClose() { - <-p.wait -} - -// ConnectWebsocketServer : -// addr: ws://domain:port -func ConnectWebsocketServer(addr string) (c Conn, err error) { - addr = "ws://" + addr +// addr: domain:port +func ConnectWebsocketServer(addr string) (Conn, error) { + addr = "ws://" + addr + FrpWebsocketPath uri, err := url.Parse(addr) if err != nil { - return + return nil, err } origin := "http://" + uri.Host cfg, err := websocket.NewConfig(addr, origin) if err != nil { - return + return nil, err } cfg.Dialer = &net.Dialer{ - Timeout: time.Second * 10, + Timeout: 10 * time.Second, } conn, err := websocket.DialConfig(cfg) if err != nil { - return + return nil, err } - c = NewWebScoketConn(conn) - return + c := WrapConn(conn) + return c, nil }