diff --git a/Release.md b/Release.md index 16e8324..a834392 100644 --- a/Release.md +++ b/Release.md @@ -1,3 +1,4 @@ ### Fixes * frpc: Return code 1 when the first login attempt fails and exits. +* When auth.method is `oidc` and auth.additionalScopes contains `HeartBeats`, if obtaining AccessToken fails, the application will be unresponsive. diff --git a/client/admin_api.go b/client/admin_api.go index a348e8d..e775f52 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -144,7 +144,14 @@ func (svr *Service) apiStatus(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write(buf) }() - ps := svr.ctl.pm.GetAllProxyStatus() + svr.ctlMu.RLock() + ctl := svr.ctl + svr.ctlMu.RUnlock() + if ctl == nil { + return + } + + ps := ctl.pm.GetAllProxyStatus() for _, status := range ps { res[status.Type] = append(res[status.Type], NewProxyStatusResp(status, svr.cfg.ServerAddr)) } diff --git a/client/control.go b/client/control.go index 33fe2b5..c8d186c 100644 --- a/client/control.go +++ b/client/control.go @@ -16,13 +16,10 @@ package client import ( "context" - "io" "net" - "runtime/debug" + "sync/atomic" "time" - "github.com/fatedier/golib/control/shutdown" - "github.com/fatedier/golib/crypto" "github.com/samber/lo" "github.com/fatedier/frp/client/proxy" @@ -31,6 +28,8 @@ import ( v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/transport" + utilnet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -39,6 +38,12 @@ type Control struct { ctx context.Context xl *xlog.Logger + // The client configuration + clientCfg *v1.ClientCommonConfig + + // sets authentication based on selected method + authSetter auth.Setter + // Unique ID obtained from frps. // It should be attached to the login message when reconnecting. runID string @@ -50,36 +55,25 @@ type Control struct { // manage all visitors vm *visitor.Manager - // control connection + // control connection. Once conn is closed, the msgDispatcher and the entire Control will exit. conn net.Conn + // use cm to create new connections, which could be real TCP connections or virtual streams. cm *ConnectionManager - // put a message in this channel to send it over control connection to server - sendCh chan (msg.Message) - - // read from this channel to get the next message sent by server - readCh chan (msg.Message) - - // goroutines can block by reading from this channel, it will be closed only in reader() when control connection is closed - closedCh chan struct{} - - closedDoneCh chan struct{} + doneCh chan struct{} - // last time got the Pong message - lastPong time.Time - - // The client configuration - clientCfg *v1.ClientCommonConfig - - readerShutdown *shutdown.Shutdown - writerShutdown *shutdown.Shutdown - msgHandlerShutdown *shutdown.Shutdown - - // sets authentication based on selected method - authSetter auth.Setter + // of time.Time, last time got the Pong message + lastPong atomic.Value + // The role of msgTransporter is similar to HTTP2. + // It allows multiple messages to be sent simultaneously on the same control connection. + // The server's response messages will be dispatched to the corresponding waiting goroutines based on the laneKey and message type. msgTransporter transport.MessageTransporter + + // msgDispatcher is a wrapper for control connection. + // It provides a channel for sending messages, and you can register handlers to process messages based on their respective types. + msgDispatcher *msg.Dispatcher } func NewControl( @@ -88,31 +82,34 @@ func NewControl( pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer, authSetter auth.Setter, -) *Control { +) (*Control, error) { // new xlog instance ctl := &Control{ - ctx: ctx, - xl: xlog.FromContextSafe(ctx), - runID: runID, - conn: conn, - cm: cm, - pxyCfgs: pxyCfgs, - sendCh: make(chan msg.Message, 100), - readCh: make(chan msg.Message, 100), - closedCh: make(chan struct{}), - closedDoneCh: make(chan struct{}), - clientCfg: clientCfg, - readerShutdown: shutdown.New(), - writerShutdown: shutdown.New(), - msgHandlerShutdown: shutdown.New(), - authSetter: authSetter, + ctx: ctx, + xl: xlog.FromContextSafe(ctx), + clientCfg: clientCfg, + authSetter: authSetter, + runID: runID, + pxyCfgs: pxyCfgs, + conn: conn, + cm: cm, + doneCh: make(chan struct{}), + } + ctl.lastPong.Store(time.Now()) + + cryptoRW, err := utilnet.NewCryptoReadWriter(conn, []byte(clientCfg.Auth.Token)) + if err != nil { + return nil, err } - ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) - ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) + ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) + ctl.registerMsgHandlers() + ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel()) + + ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) ctl.vm.Reload(visitorCfgs) - return ctl + return ctl, nil } func (ctl *Control) Run() { @@ -125,7 +122,7 @@ func (ctl *Control) Run() { go ctl.vm.Run() } -func (ctl *Control) HandleReqWorkConn(_ *msg.ReqWorkConn) { +func (ctl *Control) handleReqWorkConn(_ msg.Message) { xl := ctl.xl workConn, err := ctl.connectServer() if err != nil { @@ -162,8 +159,9 @@ func (ctl *Control) HandleReqWorkConn(_ *msg.ReqWorkConn) { ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg) } -func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { +func (ctl *Control) handleNewProxyResp(m msg.Message) { xl := ctl.xl + inMsg := m.(*msg.NewProxyResp) // Server will return NewProxyResp message to each NewProxy message. // Start a new proxy handler if no error got err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) @@ -174,8 +172,9 @@ func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { } } -func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) { +func (ctl *Control) handleNatHoleResp(m msg.Message) { xl := ctl.xl + inMsg := m.(*msg.NatHoleResp) // Dispatch the NatHoleResp message to the related proxy. ok := ctl.msgTransporter.DispatchWithType(inMsg, msg.TypeNameNatHoleResp, inMsg.TransactionID) @@ -184,6 +183,19 @@ func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) { } } +func (ctl *Control) handlePong(m msg.Message) { + xl := ctl.xl + inMsg := m.(*msg.Pong) + + if inMsg.Error != "" { + xl.Error("Pong message contains error: %s", inMsg.Error) + ctl.conn.Close() + return + } + ctl.lastPong.Store(time.Now()) + xl.Debug("receive heartbeat from server") +} + func (ctl *Control) Close() error { return ctl.GracefulClose(0) } @@ -199,9 +211,9 @@ func (ctl *Control) GracefulClose(d time.Duration) error { return nil } -// ClosedDoneCh returns a channel that will be closed after all resources are released -func (ctl *Control) ClosedDoneCh() <-chan struct{} { - return ctl.closedDoneCh +// Done returns a channel that will be closed after all resources are released +func (ctl *Control) Done() <-chan struct{} { + return ctl.doneCh } // connectServer return a new connection to frps @@ -209,151 +221,70 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { return ctl.cm.Connect() } -// reader read all messages from frps and send to readCh -func (ctl *Control) reader() { - xl := ctl.xl - defer func() { - if err := recover(); err != nil { - xl.Error("panic error: %v", err) - xl.Error(string(debug.Stack())) - } - }() - defer ctl.readerShutdown.Done() - defer close(ctl.closedCh) - - encReader := crypto.NewReader(ctl.conn, []byte(ctl.clientCfg.Auth.Token)) - for { - m, err := msg.ReadMsg(encReader) - if err != nil { - if err == io.EOF { - xl.Debug("read from control connection EOF") - return - } - xl.Warn("read error: %v", err) - ctl.conn.Close() - return - } - ctl.readCh <- m - } +func (ctl *Control) registerMsgHandlers() { + ctl.msgDispatcher.RegisterHandler(&msg.ReqWorkConn{}, msg.AsyncHandler(ctl.handleReqWorkConn)) + ctl.msgDispatcher.RegisterHandler(&msg.NewProxyResp{}, ctl.handleNewProxyResp) + ctl.msgDispatcher.RegisterHandler(&msg.NatHoleResp{}, ctl.handleNatHoleResp) + ctl.msgDispatcher.RegisterHandler(&msg.Pong{}, ctl.handlePong) } -// writer writes messages got from sendCh to frps -func (ctl *Control) writer() { +// headerWorker sends heartbeat to server and check heartbeat timeout. +func (ctl *Control) heartbeatWorker() { xl := ctl.xl - defer ctl.writerShutdown.Done() - encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.clientCfg.Auth.Token)) - if err != nil { - xl.Error("crypto new writer error: %v", err) - ctl.conn.Close() - return - } - for { - m, ok := <-ctl.sendCh - if !ok { - xl.Info("control writer is closing") - return - } - if err := msg.WriteMsg(encWriter, m); err != nil { - xl.Warn("write message to control connection error: %v", err) - return - } - } -} - -// msgHandler handles all channel events and performs corresponding operations. -func (ctl *Control) msgHandler() { - xl := ctl.xl - defer func() { - if err := recover(); err != nil { - xl.Error("panic error: %v", err) - xl.Error(string(debug.Stack())) + // TODO(fatedier): Change default value of HeartbeatInterval to -1 if tcpmux is enabled. + // Users can still enable heartbeat feature by setting HeartbeatInterval to a positive value. + if ctl.clientCfg.Transport.HeartbeatInterval > 0 { + // send heartbeat to server + sendHeartBeat := func() error { + xl.Debug("send heartbeat to server") + pingMsg := &msg.Ping{} + if err := ctl.authSetter.SetPing(pingMsg); err != nil { + xl.Warn("error during ping authentication: %v, skip sending ping message", err) + return err + } + _ = ctl.msgDispatcher.Send(pingMsg) + return nil } - }() - defer ctl.msgHandlerShutdown.Done() - var hbSendCh <-chan time.Time - // TODO(fatedier): disable heartbeat if TCPMux is enabled. - // Just keep it here to keep compatible with old version frps. - if ctl.clientCfg.Transport.HeartbeatInterval > 0 { - hbSend := time.NewTicker(time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second) - defer hbSend.Stop() - hbSendCh = hbSend.C + go wait.BackoffUntil(sendHeartBeat, + wait.NewFastBackoffManager(wait.FastBackoffOptions{ + Duration: time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second, + InitDurationIfFail: time.Second, + Factor: 2.0, + Jitter: 0.1, + MaxDuration: time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second, + }), + true, ctl.doneCh, + ) } - var hbCheckCh <-chan time.Time // Check heartbeat timeout only if TCPMux is not enabled and users don't disable heartbeat feature. if ctl.clientCfg.Transport.HeartbeatInterval > 0 && ctl.clientCfg.Transport.HeartbeatTimeout > 0 && !lo.FromPtr(ctl.clientCfg.Transport.TCPMux) { - hbCheck := time.NewTicker(time.Second) - defer hbCheck.Stop() - hbCheckCh = hbCheck.C - } - ctl.lastPong = time.Now() - for { - select { - case <-hbSendCh: - // send heartbeat to server - xl.Debug("send heartbeat to server") - pingMsg := &msg.Ping{} - if err := ctl.authSetter.SetPing(pingMsg); err != nil { - xl.Warn("error during ping authentication: %v. skip sending ping message", err) - continue - } - ctl.sendCh <- pingMsg - case <-hbCheckCh: - if time.Since(ctl.lastPong) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second { + go wait.Until(func() { + if time.Since(ctl.lastPong.Load().(time.Time)) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second { xl.Warn("heartbeat timeout") - // let reader() stop ctl.conn.Close() return } - case rawMsg, ok := <-ctl.readCh: - if !ok { - return - } - - switch m := rawMsg.(type) { - case *msg.ReqWorkConn: - go ctl.HandleReqWorkConn(m) - case *msg.NewProxyResp: - ctl.HandleNewProxyResp(m) - case *msg.NatHoleResp: - ctl.HandleNatHoleResp(m) - case *msg.Pong: - if m.Error != "" { - xl.Error("Pong contains error: %s", m.Error) - ctl.conn.Close() - return - } - ctl.lastPong = time.Now() - xl.Debug("receive heartbeat from server") - } - } + }, time.Second, ctl.doneCh) } } -// If controler is notified by closedCh, reader and writer and handler will exit func (ctl *Control) worker() { - go ctl.msgHandler() - go ctl.reader() - go ctl.writer() + go ctl.heartbeatWorker() + go ctl.msgDispatcher.Run() - <-ctl.closedCh - // close related channels and wait until other goroutines done - close(ctl.readCh) - ctl.readerShutdown.WaitDone() - ctl.msgHandlerShutdown.WaitDone() - - close(ctl.sendCh) - ctl.writerShutdown.WaitDone() + <-ctl.msgDispatcher.Done() + ctl.conn.Close() ctl.pm.Close() ctl.vm.Close() - - close(ctl.closedDoneCh) ctl.cm.Close() + + close(ctl.doneCh) } func (ctl *Control) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error { diff --git a/client/service.go b/client/service.go index 4b394d8..66a642c 100644 --- a/client/service.go +++ b/client/service.go @@ -17,6 +17,7 @@ package client import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -24,7 +25,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/fatedier/golib/crypto" @@ -40,8 +40,8 @@ import ( "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/log" utilnet "github.com/fatedier/frp/pkg/util/net" - "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/version" + "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -70,12 +70,11 @@ type Service struct { // string if no configuration file was used. cfgFile string - exit uint32 // 0 means not exit - // service context ctx context.Context // call cancel to stop service - cancel context.CancelFunc + cancel context.CancelFunc + gracefulDuration time.Duration } func NewService( @@ -91,7 +90,6 @@ func NewService( pxyCfgs: pxyCfgs, visitorCfgs: visitorCfgs, ctx: context.Background(), - exit: 0, } } @@ -106,8 +104,6 @@ func (svr *Service) Run(ctx context.Context) error { svr.ctx = xlog.NewContext(ctx, xlog.New()) svr.cancel = cancel - xl := xlog.FromContextSafe(svr.ctx) - // set custom DNSServer if svr.cfg.DNSServer != "" { dnsAddr := svr.cfg.DNSServer @@ -124,26 +120,9 @@ func (svr *Service) Run(ctx context.Context) error { } // login to frps - for { - conn, cm, err := svr.login() - if err != nil { - xl.Warn("login to server failed: %v", err) - - // if login_fail_exit is true, just exit this program - // otherwise sleep a while and try again to connect to server - if lo.FromPtr(svr.cfg.LoginFailExit) { - return err - } - util.RandomSleep(5*time.Second, 0.9, 1.1) - } else { - // login success - ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) - ctl.Run() - svr.ctlMu.Lock() - svr.ctl = ctl - svr.ctlMu.Unlock() - break - } + svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.cfg.LoginFailExit)) + if svr.ctl == nil { + return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled") } go svr.keepControllerWorking() @@ -160,80 +139,35 @@ func (svr *Service) Run(ctx context.Context) error { log.Info("admin server listen on %s:%d", svr.cfg.WebServer.Addr, svr.cfg.WebServer.Port) } <-svr.ctx.Done() - // service context may not be canceled by svr.Close(), we should call it here to release resources - if atomic.LoadUint32(&svr.exit) == 0 { - svr.Close() - } + svr.stop() return nil } func (svr *Service) keepControllerWorking() { - xl := xlog.FromContextSafe(svr.ctx) - maxDelayTime := 20 * time.Second - delayTime := time.Second - - // if frpc reconnect frps, we need to limit retry times in 1min - // current retry logic is sleep 0s, 0s, 0s, 1s, 2s, 4s, 8s, ... - // when exceed 1min, we will reset delay and counts - cutoffTime := time.Now().Add(time.Minute) - reconnectDelay := time.Second - reconnectCounts := 1 - - for { - <-svr.ctl.ClosedDoneCh() - if atomic.LoadUint32(&svr.exit) != 0 { - return - } - - // the first three attempts with a low delay - if reconnectCounts > 3 { - util.RandomSleep(reconnectDelay, 0.9, 1.1) - xl.Info("wait %v to reconnect", reconnectDelay) - reconnectDelay *= 2 - } else { - util.RandomSleep(time.Second, 0, 0.5) - } - reconnectCounts++ - - now := time.Now() - if now.After(cutoffTime) { - // reset - cutoffTime = now.Add(time.Minute) - reconnectDelay = time.Second - reconnectCounts = 1 - } - - for { - if atomic.LoadUint32(&svr.exit) != 0 { - return - } - - xl.Info("try to reconnect to server...") - conn, cm, err := svr.login() - if err != nil { - xl.Warn("reconnect to server error: %v, wait %v for another retry", err, delayTime) - util.RandomSleep(delayTime, 0.9, 1.1) - - delayTime *= 2 - if delayTime > maxDelayTime { - delayTime = maxDelayTime - } - continue - } - // reconnect success, init delayTime - delayTime = time.Second - - ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) - ctl.Run() - svr.ctlMu.Lock() - if svr.ctl != nil { - svr.ctl.Close() - } - svr.ctl = ctl - svr.ctlMu.Unlock() - break - } - } + <-svr.ctl.Done() + + // There is a situation where the login is successful but due to certain reasons, + // the control immediately exits. It is necessary to limit the frequency of reconnection in this case. + // The interval for the first three retries in 1 minute will be very short, and then it will increase exponentially. + // The maximum interval is 20 seconds. + wait.BackoffUntil(func() error { + // loopLoginUntilSuccess is another layer of loop that will continuously attempt to + // login to the server until successful. + svr.loopLoginUntilSuccess(20*time.Second, false) + <-svr.ctl.Done() + return errors.New("control is closed and try another loop") + }, wait.NewFastBackoffManager( + wait.FastBackoffOptions{ + Duration: time.Second, + Factor: 2, + Jitter: 0.1, + MaxDuration: 20 * time.Second, + FastRetryCount: 3, + FastRetryDelay: 200 * time.Millisecond, + FastRetryWindow: time.Minute, + FastRetryJitter: 0.5, + }, + ), true, svr.ctx.Done()) } // login creates a connection to frps and registers it self as a client @@ -299,6 +233,54 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) { return } +func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) { + xl := xlog.FromContextSafe(svr.ctx) + successCh := make(chan struct{}) + + loginFunc := func() error { + xl.Info("try to connect to server...") + conn, cm, err := svr.login() + if err != nil { + xl.Warn("connect to server error: %v", err) + if firstLoginExit { + svr.cancel() + } + return err + } + + ctl, err := NewControl(svr.ctx, svr.runID, conn, cm, + svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) + if err != nil { + conn.Close() + xl.Error("NewControl error: %v", err) + return err + } + + ctl.Run() + // close and replace previous control + svr.ctlMu.Lock() + if svr.ctl != nil { + svr.ctl.Close() + } + svr.ctl = ctl + svr.ctlMu.Unlock() + + close(successCh) + return nil + } + + // try to reconnect to server until success + wait.BackoffUntil(loginFunc, wait.NewFastBackoffManager( + wait.FastBackoffOptions{ + Duration: time.Second, + Factor: 2, + Jitter: 0.1, + MaxDuration: maxInterval, + }), + true, + wait.MergeAndCloseOnAnyStopChannel(svr.ctx.Done(), successCh)) +} + func (svr *Service) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error { svr.cfgMu.Lock() svr.pxyCfgs = pxyCfgs @@ -320,20 +302,20 @@ func (svr *Service) Close() { } func (svr *Service) GracefulClose(d time.Duration) { - atomic.StoreUint32(&svr.exit, 1) + svr.gracefulDuration = d + svr.cancel() +} - svr.ctlMu.RLock() +func (svr *Service) stop() { + svr.ctlMu.Lock() + defer svr.ctlMu.Unlock() if svr.ctl != nil { - svr.ctl.GracefulClose(d) + svr.ctl.GracefulClose(svr.gracefulDuration) svr.ctl = nil } - svr.ctlMu.RUnlock() - - if svr.cancel != nil { - svr.cancel() - } } +// ConnectionManager is a wrapper for establishing connections to the server. type ConnectionManager struct { ctx context.Context cfg *v1.ClientCommonConfig @@ -349,6 +331,10 @@ func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *Conn } } +// OpenConnection opens a underlying connection to the server. +// The underlying connection is either a TCP connection or a QUIC connection. +// After the underlying connection is established, you can call Connect() to get a stream. +// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect(). func (cm *ConnectionManager) OpenConnection() error { xl := xlog.FromContextSafe(cm.ctx) @@ -411,6 +397,7 @@ func (cm *ConnectionManager) OpenConnection() error { return nil } +// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled. func (cm *ConnectionManager) Connect() (net.Conn, error) { if cm.quicConn != nil { stream, err := cm.quicConn.OpenStreamSync(context.Background()) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 696496a..12c388a 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -1,3 +1,17 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package metrics import ( diff --git a/pkg/msg/handler.go b/pkg/msg/handler.go new file mode 100644 index 0000000..cb1eb15 --- /dev/null +++ b/pkg/msg/handler.go @@ -0,0 +1,103 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "io" + "reflect" +) + +func AsyncHandler(f func(Message)) func(Message) { + return func(m Message) { + go f(m) + } +} + +// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn. +type Dispatcher struct { + rw io.ReadWriter + + sendCh chan Message + doneCh chan struct{} + msgHandlers map[reflect.Type]func(Message) + defaultHandler func(Message) +} + +func NewDispatcher(rw io.ReadWriter) *Dispatcher { + return &Dispatcher{ + rw: rw, + sendCh: make(chan Message, 100), + doneCh: make(chan struct{}), + msgHandlers: make(map[reflect.Type]func(Message)), + } +} + +// Run will block until io.EOF or some error occurs. +func (d *Dispatcher) Run() { + go d.sendLoop() + go d.readLoop() +} + +func (d *Dispatcher) sendLoop() { + for { + select { + case <-d.doneCh: + return + case m := <-d.sendCh: + _ = WriteMsg(d.rw, m) + } + } +} + +func (d *Dispatcher) readLoop() { + for { + m, err := ReadMsg(d.rw) + if err != nil { + close(d.doneCh) + return + } + + if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok { + handler(m) + } else if d.defaultHandler != nil { + d.defaultHandler(m) + } + } +} + +func (d *Dispatcher) Send(m Message) error { + select { + case <-d.doneCh: + return io.EOF + case d.sendCh <- m: + return nil + } +} + +func (d *Dispatcher) SendChannel() chan Message { + return d.sendCh +} + +func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) { + d.msgHandlers[reflect.TypeOf(msg)] = handler +} + +func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) { + d.defaultHandler = handler +} + +func (d *Dispatcher) Done() chan struct{} { + return d.doneCh +} diff --git a/pkg/transport/message.go b/pkg/transport/message.go index 6bcd8ce..7163a8a 100644 --- a/pkg/transport/message.go +++ b/pkg/transport/message.go @@ -29,7 +29,9 @@ type MessageTransporter interface { // Recv(ctx context.Context, laneKey string, msgType string) (Message, error) // Do will first send msg, then recv msg with the same laneKey and specified msgType. Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error) + // Dispatch will dispatch message to releated channel registered in Do function by its message type and laneKey. Dispatch(m msg.Message, laneKey string) bool + // Same with Dispatch but with specified message type. DispatchWithType(m msg.Message, msgType, laneKey string) bool } diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go index fb2ff67..a5bbe73 100644 --- a/pkg/util/net/conn.go +++ b/pkg/util/net/conn.go @@ -22,6 +22,7 @@ import ( "sync/atomic" "time" + "github.com/fatedier/golib/crypto" quic "github.com/quic-go/quic-go" "github.com/fatedier/frp/pkg/util/xlog" @@ -216,3 +217,18 @@ func (conn *wrapQuicStream) Close() error { conn.Stream.CancelRead(0) return conn.Stream.Close() } + +func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) { + encReader := crypto.NewReader(rw, key) + encWriter, err := crypto.NewWriter(rw, key) + if err != nil { + return nil, err + } + return struct { + io.Reader + io.Writer + }{ + Reader: encReader, + Writer: encWriter, + }, nil +} diff --git a/pkg/util/wait/backoff.go b/pkg/util/wait/backoff.go new file mode 100644 index 0000000..45e0ab6 --- /dev/null +++ b/pkg/util/wait/backoff.go @@ -0,0 +1,197 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wait + +import ( + "math/rand" + "time" + + "github.com/samber/lo" + + "github.com/fatedier/frp/pkg/util/util" +) + +type BackoffFunc func(previousDuration time.Duration, previousConditionError bool) time.Duration + +func (f BackoffFunc) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration { + return f(previousDuration, previousConditionError) +} + +type BackoffManager interface { + Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration +} + +type FastBackoffOptions struct { + Duration time.Duration + Factor float64 + Jitter float64 + MaxDuration time.Duration + InitDurationIfFail time.Duration + + // If FastRetryCount > 0, then within the FastRetryWindow time window, + // the retry will be performed with a delay of FastRetryDelay for the first FastRetryCount calls. + FastRetryCount int + FastRetryDelay time.Duration + FastRetryJitter float64 + FastRetryWindow time.Duration +} + +type fastBackoffImpl struct { + options FastBackoffOptions + + lastCalledTime time.Time + consecutiveErrCount int + + fastRetryCutoffTime time.Time + countsInFastRetryWindow int +} + +func NewFastBackoffManager(options FastBackoffOptions) BackoffManager { + return &fastBackoffImpl{ + options: options, + countsInFastRetryWindow: 1, + } +} + +func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration { + if f.lastCalledTime.IsZero() { + f.lastCalledTime = time.Now() + return f.options.Duration + } + now := time.Now() + f.lastCalledTime = now + + if previousConditionError { + f.consecutiveErrCount++ + } else { + f.consecutiveErrCount = 0 + } + + if f.options.FastRetryCount > 0 && previousConditionError { + f.countsInFastRetryWindow++ + if f.countsInFastRetryWindow <= f.options.FastRetryCount { + return Jitter(f.options.FastRetryDelay, f.options.FastRetryJitter) + } + if now.After(f.fastRetryCutoffTime) { + // reset + f.fastRetryCutoffTime = now.Add(f.options.FastRetryWindow) + f.countsInFastRetryWindow = 0 + } + } + + if previousConditionError { + var duration time.Duration + if f.consecutiveErrCount == 1 { + duration = util.EmptyOr(f.options.InitDurationIfFail, previousDuration) + } else { + duration = previousDuration + } + + duration = util.EmptyOr(duration, time.Second) + if f.options.Factor != 0 { + duration = time.Duration(float64(duration) * f.options.Factor) + } + if f.options.Jitter > 0 { + duration = Jitter(duration, f.options.Jitter) + } + if f.options.MaxDuration > 0 && duration > f.options.MaxDuration { + duration = f.options.MaxDuration + } + return duration + } + return f.options.Duration +} + +func BackoffUntil(f func() error, backoff BackoffManager, sliding bool, stopCh <-chan struct{}) { + var delay time.Duration + previousError := false + + ticker := time.NewTicker(backoff.Backoff(delay, previousError)) + defer ticker.Stop() + + for { + select { + case <-stopCh: + return + default: + } + + if !sliding { + delay = backoff.Backoff(delay, previousError) + } + + if err := f(); err != nil { + previousError = true + } else { + previousError = false + } + + if sliding { + delay = backoff.Backoff(delay, previousError) + } + + ticker.Reset(delay) + select { + case <-stopCh: + return + default: + } + + select { + case <-stopCh: + return + case <-ticker.C: + } + } +} + +// Jitter returns a time.Duration between duration and duration + maxFactor * +// duration. +// +// This allows clients to avoid converging on periodic behavior. If maxFactor +// is 0.0, a suggested default value will be chosen. +func Jitter(duration time.Duration, maxFactor float64) time.Duration { + if maxFactor <= 0.0 { + maxFactor = 1.0 + } + wait := duration + time.Duration(rand.Float64()*maxFactor*float64(duration)) + return wait +} + +func Until(f func(), period time.Duration, stopCh <-chan struct{}) { + ff := func() error { + f() + return nil + } + BackoffUntil(ff, BackoffFunc(func(time.Duration, bool) time.Duration { + return period + }), true, stopCh) +} + +func MergeAndCloseOnAnyStopChannel[T any](upstreams ...<-chan T) <-chan T { + out := make(chan T) + + for _, upstream := range upstreams { + ch := upstream + go lo.Try0(func() { + select { + case <-ch: + close(out) + case <-out: + } + }) + } + return out +} diff --git a/server/control.go b/server/control.go index f2eaaa5..e651a97 100644 --- a/server/control.go +++ b/server/control.go @@ -17,15 +17,12 @@ package server import ( "context" "fmt" - "io" "net" "runtime/debug" "sync" + "sync/atomic" "time" - "github.com/fatedier/golib/control/shutdown" - "github.com/fatedier/golib/crypto" - "github.com/fatedier/golib/errors" "github.com/samber/lo" "github.com/fatedier/frp/pkg/auth" @@ -35,8 +32,10 @@ import ( "github.com/fatedier/frp/pkg/msg" plugin "github.com/fatedier/frp/pkg/plugin/server" "github.com/fatedier/frp/pkg/transport" + utilnet "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/version" + "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/server/controller" "github.com/fatedier/frp/server/metrics" @@ -111,18 +110,16 @@ type Control struct { // other components can use this to communicate with client msgTransporter transport.MessageTransporter + // msgDispatcher is a wrapper for control connection. + // It provides a channel for sending messages, and you can register handlers to process messages based on their respective types. + msgDispatcher *msg.Dispatcher + // login message loginMsg *msg.Login // control connection conn net.Conn - // put a message in this channel to send it over control connection to client - sendCh chan (msg.Message) - - // read from this channel to get the next message sent by client - readCh chan (msg.Message) - // work connections workConnCh chan net.Conn @@ -136,27 +133,21 @@ type Control struct { portsUsedNum int // last time got the Ping message - lastPing time.Time + lastPing atomic.Value // A new run id will be generated when a new client login. // If run id got from login message has same run id, it means it's the same client, so we can // replace old controller instantly. runID string - readerShutdown *shutdown.Shutdown - writerShutdown *shutdown.Shutdown - managerShutdown *shutdown.Shutdown - allShutdown *shutdown.Shutdown - - started bool - mu sync.RWMutex // Server configuration information serverCfg *v1.ServerConfig - xl *xlog.Logger - ctx context.Context + xl *xlog.Logger + ctx context.Context + doneCh chan struct{} } func NewControl( @@ -168,36 +159,38 @@ func NewControl( ctlConn net.Conn, loginMsg *msg.Login, serverCfg *v1.ServerConfig, -) *Control { +) (*Control, error) { poolCount := loginMsg.PoolCount if poolCount > int(serverCfg.Transport.MaxPoolCount) { poolCount = int(serverCfg.Transport.MaxPoolCount) } ctl := &Control{ - rc: rc, - pxyManager: pxyManager, - pluginManager: pluginManager, - authVerifier: authVerifier, - conn: ctlConn, - loginMsg: loginMsg, - sendCh: make(chan msg.Message, 10), - readCh: make(chan msg.Message, 10), - workConnCh: make(chan net.Conn, poolCount+10), - proxies: make(map[string]proxy.Proxy), - poolCount: poolCount, - portsUsedNum: 0, - lastPing: time.Now(), - runID: loginMsg.RunID, - readerShutdown: shutdown.New(), - writerShutdown: shutdown.New(), - managerShutdown: shutdown.New(), - allShutdown: shutdown.New(), - serverCfg: serverCfg, - xl: xlog.FromContextSafe(ctx), - ctx: ctx, + rc: rc, + pxyManager: pxyManager, + pluginManager: pluginManager, + authVerifier: authVerifier, + conn: ctlConn, + loginMsg: loginMsg, + workConnCh: make(chan net.Conn, poolCount+10), + proxies: make(map[string]proxy.Proxy), + poolCount: poolCount, + portsUsedNum: 0, + runID: loginMsg.RunID, + serverCfg: serverCfg, + xl: xlog.FromContextSafe(ctx), + ctx: ctx, + doneCh: make(chan struct{}), + } + ctl.lastPing.Store(time.Now()) + + cryptoRW, err := utilnet.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) + if err != nil { + return nil, err } - ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) - return ctl + ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) + ctl.registerMsgHandlers() + ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel()) + return ctl, nil } // Start send a login success message to client and start working. @@ -208,27 +201,18 @@ func (ctl *Control) Start() { Error: "", } _ = msg.WriteMsg(ctl.conn, loginRespMsg) - ctl.mu.Lock() - ctl.started = true - ctl.mu.Unlock() - go ctl.writer() go func() { for i := 0; i < ctl.poolCount; i++ { // ignore error here, that means that this control is closed - _ = errors.PanicToError(func() { - ctl.sendCh <- &msg.ReqWorkConn{} - }) + _ = ctl.msgDispatcher.Send(&msg.ReqWorkConn{}) } }() - - go ctl.manager() - go ctl.reader() - go ctl.stoper() + go ctl.worker() } func (ctl *Control) Close() error { - ctl.allShutdown.Start() + ctl.conn.Close() return nil } @@ -236,7 +220,7 @@ func (ctl *Control) Replaced(newCtl *Control) { xl := ctl.xl xl.Info("Replaced by client [%s]", newCtl.runID) ctl.runID = "" - ctl.allShutdown.Start() + ctl.conn.Close() } func (ctl *Control) RegisterWorkConn(conn net.Conn) error { @@ -282,9 +266,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { xl.Debug("get work connection from pool") default: // no work connections available in the poll, send message to frpc to get more - if err = errors.PanicToError(func() { - ctl.sendCh <- &msg.ReqWorkConn{} - }); err != nil { + if err := ctl.msgDispatcher.Send(&msg.ReqWorkConn{}); err != nil { return nil, fmt.Errorf("control is already closed") } @@ -304,92 +286,39 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { } // When we get a work connection from pool, replace it with a new one. - _ = errors.PanicToError(func() { - ctl.sendCh <- &msg.ReqWorkConn{} - }) + _ = ctl.msgDispatcher.Send(&msg.ReqWorkConn{}) return } -func (ctl *Control) writer() { +func (ctl *Control) heartbeatWorker() { xl := ctl.xl - defer func() { - if err := recover(); err != nil { - xl.Error("panic error: %v", err) - xl.Error(string(debug.Stack())) - } - }() - - defer ctl.allShutdown.Start() - defer ctl.writerShutdown.Done() - - encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) - if err != nil { - xl.Error("crypto new writer error: %v", err) - ctl.allShutdown.Start() - return - } - for { - m, ok := <-ctl.sendCh - if !ok { - xl.Info("control writer is closing") - return - } - if err := msg.WriteMsg(encWriter, m); err != nil { - xl.Warn("write message to control connection error: %v", err) - return - } - } -} - -func (ctl *Control) reader() { - xl := ctl.xl - defer func() { - if err := recover(); err != nil { - xl.Error("panic error: %v", err) - xl.Error(string(debug.Stack())) - } - }() - - defer ctl.allShutdown.Start() - defer ctl.readerShutdown.Done() - - encReader := crypto.NewReader(ctl.conn, []byte(ctl.serverCfg.Auth.Token)) - for { - m, err := msg.ReadMsg(encReader) - if err != nil { - if err == io.EOF { - xl.Debug("control connection closed") + // Don't need application heartbeat if TCPMux is enabled, + // yamux will do same thing. + // TODO(fatedier): let default HeartbeatTimeout to -1 if TCPMux is enabled. Users can still set it to positive value to enable it. + if !lo.FromPtr(ctl.serverCfg.Transport.TCPMux) && ctl.serverCfg.Transport.HeartbeatTimeout > 0 { + go wait.Until(func() { + if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second { + xl.Warn("heartbeat timeout") return } - xl.Warn("read error: %v", err) - ctl.conn.Close() - return - } - - ctl.readCh <- m + }, time.Second, ctl.doneCh) } } -func (ctl *Control) stoper() { +// block until Control closed +func (ctl *Control) WaitClosed() { + <-ctl.doneCh +} + +func (ctl *Control) worker() { xl := ctl.xl - defer func() { - if err := recover(); err != nil { - xl.Error("panic error: %v", err) - xl.Error(string(debug.Stack())) - } - }() - ctl.allShutdown.WaitStart() + go ctl.heartbeatWorker() + go ctl.msgDispatcher.Run() + <-ctl.msgDispatcher.Done() ctl.conn.Close() - ctl.readerShutdown.WaitDone() - - close(ctl.readCh) - ctl.managerShutdown.WaitDone() - - close(ctl.sendCh) - ctl.writerShutdown.WaitDone() ctl.mu.Lock() defer ctl.mu.Unlock() @@ -419,136 +348,104 @@ func (ctl *Control) stoper() { }() } - ctl.allShutdown.Done() - xl.Info("client exit success") metrics.Server.CloseClient() + xl.Info("client exit success") + close(ctl.doneCh) } -// block until Control closed -func (ctl *Control) WaitClosed() { - ctl.mu.RLock() - started := ctl.started - ctl.mu.RUnlock() - - if !started { - ctl.allShutdown.Done() - return - } - ctl.allShutdown.WaitDone() +func (ctl *Control) registerMsgHandlers() { + ctl.msgDispatcher.RegisterHandler(&msg.NewProxy{}, ctl.handleNewProxy) + ctl.msgDispatcher.RegisterHandler(&msg.Ping{}, ctl.handlePing) + ctl.msgDispatcher.RegisterHandler(&msg.NatHoleVisitor{}, msg.AsyncHandler(ctl.handleNatHoleVisitor)) + ctl.msgDispatcher.RegisterHandler(&msg.NatHoleClient{}, msg.AsyncHandler(ctl.handleNatHoleClient)) + ctl.msgDispatcher.RegisterHandler(&msg.NatHoleReport{}, msg.AsyncHandler(ctl.handleNatHoleReport)) + ctl.msgDispatcher.RegisterHandler(&msg.CloseProxy{}, ctl.handleCloseProxy) } -func (ctl *Control) manager() { +func (ctl *Control) handleNewProxy(m msg.Message) { xl := ctl.xl - defer func() { - if err := recover(); err != nil { - xl.Error("panic error: %v", err) - xl.Error(string(debug.Stack())) - } - }() + inMsg := m.(*msg.NewProxy) - defer ctl.allShutdown.Start() - defer ctl.managerShutdown.Done() + content := &plugin.NewProxyContent{ + User: plugin.UserInfo{ + User: ctl.loginMsg.User, + Metas: ctl.loginMsg.Metas, + RunID: ctl.loginMsg.RunID, + }, + NewProxy: *inMsg, + } + var remoteAddr string + retContent, err := ctl.pluginManager.NewProxy(content) + if err == nil { + inMsg = &retContent.NewProxy + remoteAddr, err = ctl.RegisterProxy(inMsg) + } - var heartbeatCh <-chan time.Time - // Don't need application heartbeat if TCPMux is enabled, - // yamux will do same thing. - if !lo.FromPtr(ctl.serverCfg.Transport.TCPMux) && ctl.serverCfg.Transport.HeartbeatTimeout > 0 { - heartbeat := time.NewTicker(time.Second) - defer heartbeat.Stop() - heartbeatCh = heartbeat.C + // register proxy in this control + resp := &msg.NewProxyResp{ + ProxyName: inMsg.ProxyName, + } + if err != nil { + xl.Warn("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err) + resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName), + err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)) + } else { + resp.RemoteAddr = remoteAddr + xl.Info("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType) + metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType) } + _ = ctl.msgDispatcher.Send(resp) +} - for { - select { - case <-heartbeatCh: - if time.Since(ctl.lastPing) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second { - xl.Warn("heartbeat timeout") - return - } - case rawMsg, ok := <-ctl.readCh: - if !ok { - return - } +func (ctl *Control) handlePing(m msg.Message) { + xl := ctl.xl + inMsg := m.(*msg.Ping) - switch m := rawMsg.(type) { - case *msg.NewProxy: - content := &plugin.NewProxyContent{ - User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, - }, - NewProxy: *m, - } - var remoteAddr string - retContent, err := ctl.pluginManager.NewProxy(content) - if err == nil { - m = &retContent.NewProxy - remoteAddr, err = ctl.RegisterProxy(m) - } - - // register proxy in this control - resp := &msg.NewProxyResp{ - ProxyName: m.ProxyName, - } - if err != nil { - xl.Warn("new proxy [%s] type [%s] error: %v", m.ProxyName, m.ProxyType, err) - resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", m.ProxyName), - err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)) - } else { - resp.RemoteAddr = remoteAddr - xl.Info("new proxy [%s] type [%s] success", m.ProxyName, m.ProxyType) - metrics.Server.NewProxy(m.ProxyName, m.ProxyType) - } - ctl.sendCh <- resp - case *msg.NatHoleVisitor: - go ctl.HandleNatHoleVisitor(m) - case *msg.NatHoleClient: - go ctl.HandleNatHoleClient(m) - case *msg.NatHoleReport: - go ctl.HandleNatHoleReport(m) - case *msg.CloseProxy: - _ = ctl.CloseProxy(m) - xl.Info("close proxy [%s] success", m.ProxyName) - case *msg.Ping: - content := &plugin.PingContent{ - User: plugin.UserInfo{ - User: ctl.loginMsg.User, - Metas: ctl.loginMsg.Metas, - RunID: ctl.loginMsg.RunID, - }, - Ping: *m, - } - retContent, err := ctl.pluginManager.Ping(content) - if err == nil { - m = &retContent.Ping - err = ctl.authVerifier.VerifyPing(m) - } - if err != nil { - xl.Warn("received invalid ping: %v", err) - ctl.sendCh <- &msg.Pong{ - Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)), - } - return - } - ctl.lastPing = time.Now() - xl.Debug("receive heartbeat") - ctl.sendCh <- &msg.Pong{} - } - } + content := &plugin.PingContent{ + User: plugin.UserInfo{ + User: ctl.loginMsg.User, + Metas: ctl.loginMsg.Metas, + RunID: ctl.loginMsg.RunID, + }, + Ping: *inMsg, + } + retContent, err := ctl.pluginManager.Ping(content) + if err == nil { + inMsg = &retContent.Ping + err = ctl.authVerifier.VerifyPing(inMsg) } + if err != nil { + xl.Warn("received invalid ping: %v", err) + _ = ctl.msgDispatcher.Send(&msg.Pong{ + Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)), + }) + return + } + ctl.lastPing.Store(time.Now()) + xl.Debug("receive heartbeat") + _ = ctl.msgDispatcher.Send(&msg.Pong{}) } -func (ctl *Control) HandleNatHoleVisitor(m *msg.NatHoleVisitor) { - ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter, ctl.loginMsg.User) +func (ctl *Control) handleNatHoleVisitor(m msg.Message) { + inMsg := m.(*msg.NatHoleVisitor) + ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User) } -func (ctl *Control) HandleNatHoleClient(m *msg.NatHoleClient) { - ctl.rc.NatHoleController.HandleClient(m, ctl.msgTransporter) +func (ctl *Control) handleNatHoleClient(m msg.Message) { + inMsg := m.(*msg.NatHoleClient) + ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter) } -func (ctl *Control) HandleNatHoleReport(m *msg.NatHoleReport) { - ctl.rc.NatHoleController.HandleReport(m) +func (ctl *Control) handleNatHoleReport(m msg.Message) { + inMsg := m.(*msg.NatHoleReport) + ctl.rc.NatHoleController.HandleReport(inMsg) +} + +func (ctl *Control) handleCloseProxy(m msg.Message) { + xl := ctl.xl + inMsg := m.(*msg.CloseProxy) + _ = ctl.CloseProxy(inMsg) + xl.Info("close proxy [%s] success", inMsg.ProxyName) } func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { diff --git a/server/service.go b/server/service.go index 9deffa0..2629b34 100644 --- a/server/service.go +++ b/server/service.go @@ -516,13 +516,14 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) { } } -func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err error) { +func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error { // If client's RunID is empty, it's a new client, we just create a new controller. // Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one. + var err error if loginMsg.RunID == "" { loginMsg.RunID, err = util.RandID() if err != nil { - return + return err } } @@ -534,11 +535,16 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch) // Check auth. - if err = svr.authVerifier.VerifyLogin(loginMsg); err != nil { - return + if err := svr.authVerifier.VerifyLogin(loginMsg); err != nil { + return err } - ctl := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg) + ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg) + if err != nil { + xl.Warn("create new controller error: %v", err) + // don't return detailed errors to client + return fmt.Errorf("unexpect error when creating new controller") + } if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil { oldCtl.WaitClosed() } @@ -553,7 +559,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err ctl.WaitClosed() svr.ctlManager.Del(loginMsg.RunID, ctl) }() - return + return nil } // RegisterWorkConn register a new work connection to control and proxies need it.