mirror of https://github.com/fatedier/frp
				
				
				
			websocket: update muxer for websocket
							parent
							
								
									64136a3b3e
								
							
						
					
					
						commit
						7793f55545
					
				| 
						 | 
				
			
			@ -186,9 +186,10 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	if tmpStr, ok = conf.Get("common", "protocol"); ok {
 | 
			
		||||
		// Now it only support tcp and kcp.
 | 
			
		||||
		if tmpStr != "kcp" && tmpStr != "websocket" {
 | 
			
		||||
			tmpStr = "tcp"
 | 
			
		||||
		// Now it only support tcp and kcp and websocket.
 | 
			
		||||
		if tmpStr != "tcp" && tmpStr != "kcp" && tmpStr != "websocket" {
 | 
			
		||||
			err = fmt.Errorf("Parse conf error: invalid protocol")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		cfg.Protocol = tmpStr
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,11 +15,11 @@
 | 
			
		|||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/fatedier/frp/assets"
 | 
			
		||||
| 
						 | 
				
			
			@ -139,6 +139,13 @@ func NewService() (svr *Service, err error) {
 | 
			
		|||
		log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KcpBindPort)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Listen for accepting connections from client using websocket protocol.
 | 
			
		||||
	websocketPrefix := []byte("GET /%23frp")
 | 
			
		||||
	websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool {
 | 
			
		||||
		return bytes.Equal(data, websocketPrefix)
 | 
			
		||||
	})
 | 
			
		||||
	svr.websocketListener = frpNet.NewWebsocketListener(websocketLn)
 | 
			
		||||
 | 
			
		||||
	// Create http vhost muxer.
 | 
			
		||||
	if cfg.VhostHttpPort > 0 {
 | 
			
		||||
		rp := vhost.NewHttpReverseProxy()
 | 
			
		||||
| 
						 | 
				
			
			@ -150,7 +157,9 @@ func NewService() (svr *Service, err error) {
 | 
			
		|||
			Handler: rp,
 | 
			
		||||
		}
 | 
			
		||||
		var l net.Listener
 | 
			
		||||
		if !httpMuxOn {
 | 
			
		||||
		if httpMuxOn {
 | 
			
		||||
			l = svr.muxer.ListenHttp(1)
 | 
			
		||||
		} else {
 | 
			
		||||
			l, err = net.Listen("tcp", address)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				err = fmt.Errorf("Create vhost http listener error, %v", err)
 | 
			
		||||
| 
						 | 
				
			
			@ -165,7 +174,7 @@ func NewService() (svr *Service, err error) {
 | 
			
		|||
	if cfg.VhostHttpsPort > 0 {
 | 
			
		||||
		var l net.Listener
 | 
			
		||||
		if httpsMuxOn {
 | 
			
		||||
			l = svr.muxer.ListenHttps(0)
 | 
			
		||||
			l = svr.muxer.ListenHttps(1)
 | 
			
		||||
		} else {
 | 
			
		||||
			l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -205,37 +214,6 @@ func NewService() (svr *Service, err error) {
 | 
			
		|||
		log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !httpMuxOn {
 | 
			
		||||
		svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), nil)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// server := &http.Server{}
 | 
			
		||||
	if httpMuxOn {
 | 
			
		||||
		rp := svr.httpReverseProxy
 | 
			
		||||
		svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0),
 | 
			
		||||
			func(w http.ResponseWriter, req *http.Request) bool {
 | 
			
		||||
				domain := getHostFromAddr(req.Host)
 | 
			
		||||
				location := req.URL.Path
 | 
			
		||||
				headers := rp.GetHeaders(domain, location)
 | 
			
		||||
				if headers == nil {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				rp.ServeHTTP(w, req)
 | 
			
		||||
				return false
 | 
			
		||||
			})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getHostFromAddr(addr string) (host string) {
 | 
			
		||||
	strs := strings.Split(addr, ":")
 | 
			
		||||
	if len(strs) > 1 {
 | 
			
		||||
		host = strs[0]
 | 
			
		||||
	} else {
 | 
			
		||||
		host = addr
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -246,9 +224,9 @@ func (svr *Service) Run() {
 | 
			
		|||
	if g.GlbServerCfg.KcpBindPort > 0 {
 | 
			
		||||
		go svr.HandleListener(svr.kcpListener)
 | 
			
		||||
	}
 | 
			
		||||
	if svr.websocketListener != nil {
 | 
			
		||||
		go svr.HandleListener(svr.websocketListener)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go svr.HandleListener(svr.websocketListener)
 | 
			
		||||
 | 
			
		||||
	svr.HandleListener(svr.listener)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -260,6 +238,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) {
 | 
			
		|||
			log.Warn("Listener for incoming connections from client closed")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Start a new goroutine for dealing connections.
 | 
			
		||||
		go func(frpConn frpNet.Conn) {
 | 
			
		||||
			dealFn := func(conn frpNet.Conn) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,6 +96,75 @@ func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
 | 
			
		|||
	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CloseNotifyConn struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	log.Logger
 | 
			
		||||
 | 
			
		||||
	// 1 means closed
 | 
			
		||||
	closeFlag int32
 | 
			
		||||
 | 
			
		||||
	closeFn func()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// closeFn will be only called once
 | 
			
		||||
func WrapCloseNotifyConn(c net.Conn, closeFn func()) Conn {
 | 
			
		||||
	return &CloseNotifyConn{
 | 
			
		||||
		Conn:    c,
 | 
			
		||||
		Logger:  log.NewPrefixLogger(""),
 | 
			
		||||
		closeFn: closeFn,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cc *CloseNotifyConn) Close() (err error) {
 | 
			
		||||
	pflag := atomic.SwapInt32(&cc.closeFlag, 1)
 | 
			
		||||
	if pflag == 0 {
 | 
			
		||||
		err = cc.Close()
 | 
			
		||||
		if cc.closeFn != nil {
 | 
			
		||||
			cc.closeFn()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StatsConn struct {
 | 
			
		||||
	Conn
 | 
			
		||||
 | 
			
		||||
	closed     int64 // 1 means closed
 | 
			
		||||
	totalRead  int64
 | 
			
		||||
	totalWrite int64
 | 
			
		||||
	statsFunc  func(totalRead, totalWrite int64)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WrapStatsConn(conn Conn, statsFunc func(total, totalWrite int64)) *StatsConn {
 | 
			
		||||
	return &StatsConn{
 | 
			
		||||
		Conn:      conn,
 | 
			
		||||
		statsFunc: statsFunc,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (statsConn *StatsConn) Read(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = statsConn.Conn.Read(p)
 | 
			
		||||
	statsConn.totalRead += int64(n)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = statsConn.Conn.Write(p)
 | 
			
		||||
	statsConn.totalWrite += int64(n)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (statsConn *StatsConn) Close() (err error) {
 | 
			
		||||
	old := atomic.SwapInt64(&statsConn.closed, 1)
 | 
			
		||||
	if old != 1 {
 | 
			
		||||
		err = statsConn.Conn.Close()
 | 
			
		||||
		if statsConn.statsFunc != nil {
 | 
			
		||||
			statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConnectServer(protocol string, addr string) (c Conn, err error) {
 | 
			
		||||
	switch protocol {
 | 
			
		||||
	case "tcp":
 | 
			
		||||
| 
						 | 
				
			
			@ -138,42 +207,3 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
 | 
			
		|||
		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StatsConn struct {
 | 
			
		||||
	Conn
 | 
			
		||||
 | 
			
		||||
	closed     int64 // 1 means closed
 | 
			
		||||
	totalRead  int64
 | 
			
		||||
	totalWrite int64
 | 
			
		||||
	statsFunc  func(totalRead, totalWrite int64)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WrapStatsConn(conn Conn, statsFunc func(total, totalWrite int64)) *StatsConn {
 | 
			
		||||
	return &StatsConn{
 | 
			
		||||
		Conn:      conn,
 | 
			
		||||
		statsFunc: statsFunc,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (statsConn *StatsConn) Read(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = statsConn.Conn.Read(p)
 | 
			
		||||
	statsConn.totalRead += int64(n)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
 | 
			
		||||
	n, err = statsConn.Conn.Write(p)
 | 
			
		||||
	statsConn.totalWrite += int64(n)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (statsConn *StatsConn) Close() (err error) {
 | 
			
		||||
	old := atomic.SwapInt64(&statsConn.closed, 1)
 | 
			
		||||
	if old != 1 {
 | 
			
		||||
		err = statsConn.Conn.Close()
 | 
			
		||||
		if statsConn.statsFunc != nil {
 | 
			
		||||
			statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,127 +1,105 @@
 | 
			
		|||
package net
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/fatedier/frp/utils/log"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrWebsocketListenerClosed = errors.New("websocket listener closed")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	FrpWebsocketPath = "/#frp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type WebsocketListener struct {
 | 
			
		||||
	net.Addr
 | 
			
		||||
	ln     net.Listener
 | 
			
		||||
	accept chan Conn
 | 
			
		||||
	log.Logger
 | 
			
		||||
 | 
			
		||||
	server    *http.Server
 | 
			
		||||
	httpMutex *http.ServeMux
 | 
			
		||||
	connChan  chan *WebsocketConn
 | 
			
		||||
	closeFlag bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWebsocketListener(ln net.Listener,
 | 
			
		||||
	filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
 | 
			
		||||
	l = &WebsocketListener{
 | 
			
		||||
		httpMutex: http.NewServeMux(),
 | 
			
		||||
		connChan:  make(chan *WebsocketConn),
 | 
			
		||||
		Logger:    log.NewPrefixLogger(""),
 | 
			
		||||
// ln: tcp listener for websocket connections
 | 
			
		||||
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
 | 
			
		||||
	wl = &WebsocketListener{
 | 
			
		||||
		Addr:   ln.Addr(),
 | 
			
		||||
		accept: make(chan Conn),
 | 
			
		||||
		Logger: log.NewPrefixLogger(""),
 | 
			
		||||
	}
 | 
			
		||||
	l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
 | 
			
		||||
		conn := NewWebScoketConn(c)
 | 
			
		||||
		l.connChan <- conn
 | 
			
		||||
		conn.waitClose()
 | 
			
		||||
 | 
			
		||||
	muxer := http.NewServeMux()
 | 
			
		||||
	muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
 | 
			
		||||
		notifyCh := make(chan struct{})
 | 
			
		||||
		conn := WrapCloseNotifyConn(c, func() {
 | 
			
		||||
			close(notifyCh)
 | 
			
		||||
		})
 | 
			
		||||
		wl.accept <- conn
 | 
			
		||||
		<-notifyCh
 | 
			
		||||
	}))
 | 
			
		||||
	l.server = &http.Server{
 | 
			
		||||
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			if filter != nil && !filter(w, r) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			l.httpMutex.ServeHTTP(w, r)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
	wl.server = &http.Server{
 | 
			
		||||
		Addr:    ln.Addr().String(),
 | 
			
		||||
		Handler: muxer,
 | 
			
		||||
	}
 | 
			
		||||
	ch := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		close(ch)
 | 
			
		||||
		err = l.server.Serve(ln)
 | 
			
		||||
	}()
 | 
			
		||||
	<-ch
 | 
			
		||||
	<-time.After(time.Millisecond)
 | 
			
		||||
 | 
			
		||||
	go wl.server.Serve(ln)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) {
 | 
			
		||||
	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
 | 
			
		||||
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
 | 
			
		||||
	tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	l, err = NewWebsocketListener(ln, nil)
 | 
			
		||||
	return
 | 
			
		||||
	l := NewWebsocketListener(tcpLn)
 | 
			
		||||
	return l, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *WebsocketListener) Accept() (Conn, error) {
 | 
			
		||||
	c := <-p.connChan
 | 
			
		||||
	c, ok := <-p.accept
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, ErrWebsocketListenerClosed
 | 
			
		||||
	}
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *WebsocketListener) Close() error {
 | 
			
		||||
	if !p.closeFlag {
 | 
			
		||||
		p.closeFlag = true
 | 
			
		||||
		p.server.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
	return p.server.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WebsocketConn struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	log.Logger
 | 
			
		||||
	closed int32
 | 
			
		||||
	wait   chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
 | 
			
		||||
	c = &WebsocketConn{
 | 
			
		||||
		Conn:   conn,
 | 
			
		||||
		Logger: log.NewPrefixLogger(""),
 | 
			
		||||
		wait:   make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *WebsocketConn) Close() error {
 | 
			
		||||
	if atomic.SwapInt32(&p.closed, 1) == 1 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	close(p.wait)
 | 
			
		||||
	return p.Conn.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *WebsocketConn) waitClose() {
 | 
			
		||||
	<-p.wait
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnectWebsocketServer :
 | 
			
		||||
// addr: ws://domain:port
 | 
			
		||||
func ConnectWebsocketServer(addr string) (c Conn, err error) {
 | 
			
		||||
	addr = "ws://" + addr
 | 
			
		||||
// addr: domain:port
 | 
			
		||||
func ConnectWebsocketServer(addr string) (Conn, error) {
 | 
			
		||||
	addr = "ws://" + addr + FrpWebsocketPath
 | 
			
		||||
	uri, err := url.Parse(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	origin := "http://" + uri.Host
 | 
			
		||||
	cfg, err := websocket.NewConfig(addr, origin)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	cfg.Dialer = &net.Dialer{
 | 
			
		||||
		Timeout: time.Second * 10,
 | 
			
		||||
		Timeout: 10 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := websocket.DialConfig(cfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	c = NewWebScoketConn(conn)
 | 
			
		||||
	return
 | 
			
		||||
	c := WrapConn(conn)
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue