mirror of https://github.com/fatedier/frp
fatedier
6 years ago
4 changed files with 140 additions and 152 deletions
@ -1,127 +1,105 @@
|
||||
package net |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"net/url" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/utils/log" |
||||
|
||||
"golang.org/x/net/websocket" |
||||
) |
||||
|
||||
var ( |
||||
ErrWebsocketListenerClosed = errors.New("websocket listener closed") |
||||
) |
||||
|
||||
const ( |
||||
FrpWebsocketPath = "/#frp" |
||||
) |
||||
|
||||
type WebsocketListener struct { |
||||
net.Addr |
||||
ln net.Listener |
||||
accept chan Conn |
||||
log.Logger |
||||
|
||||
server *http.Server |
||||
httpMutex *http.ServeMux |
||||
connChan chan *WebsocketConn |
||||
closeFlag bool |
||||
} |
||||
|
||||
func NewWebsocketListener(ln net.Listener, |
||||
filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) { |
||||
l = &WebsocketListener{ |
||||
httpMutex: http.NewServeMux(), |
||||
connChan: make(chan *WebsocketConn), |
||||
Logger: log.NewPrefixLogger(""), |
||||
// ln: tcp listener for websocket connections
|
||||
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { |
||||
wl = &WebsocketListener{ |
||||
Addr: ln.Addr(), |
||||
accept: make(chan Conn), |
||||
Logger: log.NewPrefixLogger(""), |
||||
} |
||||
l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) { |
||||
conn := NewWebScoketConn(c) |
||||
l.connChan <- conn |
||||
conn.waitClose() |
||||
|
||||
muxer := http.NewServeMux() |
||||
muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) { |
||||
notifyCh := make(chan struct{}) |
||||
conn := WrapCloseNotifyConn(c, func() { |
||||
close(notifyCh) |
||||
}) |
||||
wl.accept <- conn |
||||
<-notifyCh |
||||
})) |
||||
l.server = &http.Server{ |
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
if filter != nil && !filter(w, r) { |
||||
return |
||||
} |
||||
l.httpMutex.ServeHTTP(w, r) |
||||
}), |
||||
|
||||
wl.server = &http.Server{ |
||||
Addr: ln.Addr().String(), |
||||
Handler: muxer, |
||||
} |
||||
ch := make(chan struct{}) |
||||
go func() { |
||||
close(ch) |
||||
err = l.server.Serve(ln) |
||||
}() |
||||
<-ch |
||||
<-time.After(time.Millisecond) |
||||
|
||||
go wl.server.Serve(ln) |
||||
return |
||||
} |
||||
|
||||
func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) { |
||||
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) |
||||
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) { |
||||
tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) |
||||
if err != nil { |
||||
return |
||||
return nil, err |
||||
} |
||||
l, err = NewWebsocketListener(ln, nil) |
||||
return |
||||
l := NewWebsocketListener(tcpLn) |
||||
return l, nil |
||||
} |
||||
|
||||
func (p *WebsocketListener) Accept() (Conn, error) { |
||||
c := <-p.connChan |
||||
c, ok := <-p.accept |
||||
if !ok { |
||||
return nil, ErrWebsocketListenerClosed |
||||
} |
||||
return c, nil |
||||
} |
||||
|
||||
func (p *WebsocketListener) Close() error { |
||||
if !p.closeFlag { |
||||
p.closeFlag = true |
||||
p.server.Close() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
type WebsocketConn struct { |
||||
net.Conn |
||||
log.Logger |
||||
closed int32 |
||||
wait chan struct{} |
||||
} |
||||
|
||||
func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) { |
||||
c = &WebsocketConn{ |
||||
Conn: conn, |
||||
Logger: log.NewPrefixLogger(""), |
||||
wait: make(chan struct{}), |
||||
} |
||||
return |
||||
return p.server.Close() |
||||
} |
||||
|
||||
func (p *WebsocketConn) Close() error { |
||||
if atomic.SwapInt32(&p.closed, 1) == 1 { |
||||
return nil |
||||
} |
||||
close(p.wait) |
||||
return p.Conn.Close() |
||||
} |
||||
|
||||
func (p *WebsocketConn) waitClose() { |
||||
<-p.wait |
||||
} |
||||
|
||||
// ConnectWebsocketServer :
|
||||
// addr: ws://domain:port
|
||||
func ConnectWebsocketServer(addr string) (c Conn, err error) { |
||||
addr = "ws://" + addr |
||||
// addr: domain:port
|
||||
func ConnectWebsocketServer(addr string) (Conn, error) { |
||||
addr = "ws://" + addr + FrpWebsocketPath |
||||
uri, err := url.Parse(addr) |
||||
if err != nil { |
||||
return |
||||
return nil, err |
||||
} |
||||
|
||||
origin := "http://" + uri.Host |
||||
cfg, err := websocket.NewConfig(addr, origin) |
||||
if err != nil { |
||||
return |
||||
return nil, err |
||||
} |
||||
cfg.Dialer = &net.Dialer{ |
||||
Timeout: time.Second * 10, |
||||
Timeout: 10 * time.Second, |
||||
} |
||||
|
||||
conn, err := websocket.DialConfig(cfg) |
||||
if err != nil { |
||||
return |
||||
return nil, err |
||||
} |
||||
c = NewWebScoketConn(conn) |
||||
return |
||||
c := WrapConn(conn) |
||||
return c, nil |
||||
} |
||||
|
Loading…
Reference in new issue