diff --git a/.gitignore b/.gitignore index fab4548..e237cc4 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,6 @@ _testmain.go # Self bin/ +# Cache +*.swp +*.swo diff --git a/cmd/frpc/config.go b/cmd/frpc/config.go index 5374222..ff6f1ad 100644 --- a/cmd/frpc/config.go +++ b/cmd/frpc/config.go @@ -11,11 +11,12 @@ import ( // common config var ( - ServerAddr string = "0.0.0.0" - ServerPort int64 = 7000 - LogFile string = "./frpc.log" - LogLevel string = "warn" - LogWay string = "file" + ServerAddr string = "0.0.0.0" + ServerPort int64 = 7000 + LogFile string = "./frpc.log" + LogLevel string = "warn" + LogWay string = "file" + HeartBeatInterval int64 = 5 ) var ProxyClients map[string]*models.ProxyClient = make(map[string]*models.ProxyClient) diff --git a/cmd/frpc/control.go b/cmd/frpc/control.go index 313bfcb..57fce55 100644 --- a/cmd/frpc/control.go +++ b/cmd/frpc/control.go @@ -4,59 +4,47 @@ import ( "encoding/json" "io" "sync" + "time" "github.com/fatedier/frp/pkg/models" "github.com/fatedier/frp/pkg/utils/conn" "github.com/fatedier/frp/pkg/utils/log" ) +var isHeartBeatContinue bool = true + func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) { defer wait.Done() - c := &conn.Conn{} - err := c.ConnectServer(ServerAddr, ServerPort) - if err != nil { - log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, ServerAddr, ServerPort, err) + c := loginToServer(cli) + if c == nil { + log.Error("ProxyName [%s], connect to server failed!", cli.Name) return } defer c.Close() - req := &models.ClientCtlReq{ - Type: models.ControlConn, - ProxyName: cli.Name, - Passwd: cli.Passwd, - } - buf, _ := json.Marshal(req) - err = c.Write(string(buf) + "\n") - if err != nil { - log.Error("ProxyName [%s], write to server error, %v", cli.Name, err) - return - } - - res, err := c.ReadLine() - if err != nil { - log.Error("ProxyName [%s], read from server error, %v", cli.Name, err) - return - } - log.Debug("ProxyName [%s], read [%s]", cli.Name, res) - - clientCtlRes := &models.ClientCtlRes{} - if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil { - log.Error("ProxyName [%s], format server response error, %v", cli.Name, err) - return - } - - if clientCtlRes.Code != 0 { - log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg) - return - } - for { // ignore response content now _, err := c.ReadLine() if err == io.EOF { + isHeartBeatContinue = false log.Debug("ProxyName [%s], server close this control conn", cli.Name) - break + var sleepTime time.Duration = 1 + for { + log.Debug("ProxyName [%s], try to reconnect to server[%s:%d]...", cli.Name, ServerAddr, ServerPort) + tmpConn := loginToServer(cli) + if tmpConn != nil { + c.Close() + c = tmpConn + break + } + + if sleepTime < 60 { + sleepTime++ + } + time.Sleep(sleepTime * time.Second) + } + continue } else if err != nil { log.Warn("ProxyName [%s], read from server error, %v", cli.Name, err) continue @@ -65,3 +53,72 @@ func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) { cli.StartTunnel(ServerAddr, ServerPort) } } + +func loginToServer(cli *models.ProxyClient) (connection *conn.Conn) { + c := &conn.Conn{} + + connection = nil + for i := 0; i < 1; i++ { + err := c.ConnectServer(ServerAddr, ServerPort) + if err != nil { + log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, ServerAddr, ServerPort, err) + break + } + + req := &models.ClientCtlReq{ + Type: models.ControlConn, + ProxyName: cli.Name, + Passwd: cli.Passwd, + } + buf, _ := json.Marshal(req) + err = c.Write(string(buf) + "\n") + if err != nil { + log.Error("ProxyName [%s], write to server error, %v", cli.Name, err) + break + } + + res, err := c.ReadLine() + if err != nil { + log.Error("ProxyName [%s], read from server error, %v", cli.Name, err) + break + } + log.Debug("ProxyName [%s], read [%s]", cli.Name, res) + + clientCtlRes := &models.ClientCtlRes{} + if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil { + log.Error("ProxyName [%s], format server response error, %v", cli.Name, err) + break + } + + if clientCtlRes.Code != 0 { + log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg) + break + } + + connection = c + go startHeartBeat(connection) + log.Debug("ProxyName [%s], connect to server[%s:%d] success!", cli.Name, ServerAddr, ServerPort) + } + + if connection == nil { + c.Close() + } + + return +} + +func startHeartBeat(con *conn.Conn) { + isHeartBeatContinue = true + log.Debug("Start to send heartbeat") + for { + time.Sleep(time.Duration(HeartBeatInterval) * time.Second) + if isHeartBeatContinue { + err := con.Write("\n") + if err != nil { + log.Error("Send hearbeat to server failed! Err:%s", err.Error()) + } + } else { + break + } + } +} diff --git a/cmd/frps/config.go b/cmd/frps/config.go index b7564c2..af523e0 100644 --- a/cmd/frps/config.go +++ b/cmd/frps/config.go @@ -11,11 +11,12 @@ import ( // common config var ( - BindAddr string = "0.0.0.0" - BindPort int64 = 9527 - LogFile string = "./frps.log" - LogLevel string = "warn" - LogWay string = "file" + BindAddr string = "0.0.0.0" + BindPort int64 = 9527 + LogFile string = "./frps.log" + LogLevel string = "warn" + LogWay string = "file" + HeartBeatTimeout int64 = 30 ) var ProxyServers map[string]*models.ProxyServer = make(map[string]*models.ProxyServer) diff --git a/cmd/frps/control.go b/cmd/frps/control.go index 4e58738..609b25a 100644 --- a/cmd/frps/control.go +++ b/cmd/frps/control.go @@ -3,6 +3,8 @@ package main import ( "encoding/json" "fmt" + "io" + "time" "github.com/fatedier/frp/pkg/models" "github.com/fatedier/frp/pkg/utils/conn" @@ -17,7 +19,7 @@ func ProcessControlConn(l *conn.Listener) { } } -// control connection from every client and server +// connection from every client and server func controlWorker(c *conn.Conn) { // the first message is from client to server // if error, close connection @@ -43,17 +45,21 @@ func controlWorker(c *conn.Conn) { } if needRes { + // control conn + defer c.Close() + buf, _ := json.Marshal(clientCtlRes) err = c.Write(string(buf) + "\n") if err != nil { log.Warn("Write error, %v", err) + time.Sleep(1 * time.Second) + return } } else { // work conn, just return return } - defer c.Close() // others is from server to client server, ok := ProxyServers[clientCtlReq.ProxyName] if !ok { @@ -61,10 +67,16 @@ func controlWorker(c *conn.Conn) { return } + // read control msg from client + go readControlMsgFromClient(server, c) + serverCtlReq := &models.ClientCtlReq{} serverCtlReq.Type = models.WorkConn for { - server.WaitUserConn() + _, isStop := server.WaitUserConn() + if isStop { + break + } buf, _ := json.Marshal(serverCtlReq) err = c.Write(string(buf) + "\n") if err != nil { @@ -76,6 +88,7 @@ func controlWorker(c *conn.Conn) { log.Debug("ProxyName [%s], write to client to add work conn success", server.Name) } + log.Error("ProxyName [%s], I'm dead!", server.Name) return } @@ -124,11 +137,38 @@ func checkProxy(req *models.ClientCtlReq, c *conn.Conn) (succ bool, msg string, server.CliConnChan <- c } else { - msg = fmt.Sprintf("ProxyName [%s], type [%d] unsupport", req.ProxyName) - log.Warn(msg) + log.Warn("ProxyName [%s], type [%d] unsupport", req.ProxyName, req.Type) return } succ = true return } + +func readControlMsgFromClient(server *models.ProxyServer, c *conn.Conn) { + isContinueRead := true + f := func() { + isContinueRead = false + server.StopWaitUserConn() + } + timer := time.AfterFunc(time.Duration(HeartBeatTimeout)*time.Second, f) + defer timer.Stop() + + for isContinueRead { + content, err := c.ReadLine() + //log.Debug("Receive msg from client! content:%s", content) + if err != nil { + if err == io.EOF { + log.Warn("Server detect client[%s] is dead!", server.Name) + server.StopWaitUserConn() + break + } + log.Error("ProxyName [%s], read error:%s", server.Name, err.Error()) + continue + } + + if content == "\n" { + timer.Reset(time.Duration(HeartBeatTimeout) * time.Second) + } + } +} diff --git a/conf/frpc.ini b/conf/frpc.ini index d2ba710..f6df4b6 100644 --- a/conf/frpc.ini +++ b/conf/frpc.ini @@ -4,9 +4,9 @@ server_addr = 127.0.0.1 bind_port = 7000 log_file = ./frpc.log # debug, info, warn, error -log_level = info +log_level = debug # file, console -log_way = file +log_way = console # test1即为name [test1] diff --git a/conf/frps.ini b/conf/frps.ini index f6a6995..0c44cb1 100644 --- a/conf/frps.ini +++ b/conf/frps.ini @@ -4,9 +4,9 @@ bind_addr = 0.0.0.0 bind_port = 7000 log_file = ./frps.log # debug, info, warn, error -log_level = info +log_level = debug # file, console -log_way = file +log_way = console # test1即为name [test1] diff --git a/pkg/models/client.go b/pkg/models/client.go index 38fa9be..3fc5f57 100644 --- a/pkg/models/client.go +++ b/pkg/models/client.go @@ -63,6 +63,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro return } + // l means local, r means remote log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(), remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) go conn.Join(localConn, remoteConn) diff --git a/pkg/models/server.go b/pkg/models/server.go index b6bff36..7f58e5e 100644 --- a/pkg/models/server.go +++ b/pkg/models/server.go @@ -19,17 +19,19 @@ type ProxyServer struct { BindAddr string ListenPort int64 - Status int64 - Listener *conn.Listener // accept new connection from remote users - CtlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel - CliConnChan chan *conn.Conn // get client conns from control goroutine - UserConnList *list.List // store user conns - Mutex sync.Mutex + Status int64 + Listener *conn.Listener // accept new connection from remote users + CtlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel + StopBlockChan chan int64 // put any number to the channel, if you want to stop wait user conn + CliConnChan chan *conn.Conn // get client conns from control goroutine + UserConnList *list.List // store user conns + Mutex sync.Mutex } func (p *ProxyServer) Init() { p.Status = Idle p.CtlMsgChan = make(chan int64) + p.StopBlockChan = make(chan int64) p.CliConnChan = make(chan *conn.Conn) p.UserConnList = list.New() } @@ -87,11 +89,13 @@ func (p *ProxyServer) Start() (err error) { p.UserConnList.Remove(element) } else { cliConn.Close() + p.Unlock() continue } p.Unlock() // msg will transfer to another without modifying + // l means local, r means remote log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", cliConn.GetLocalAddr(), cliConn.GetRemoteAddr(), userConn.GetLocalAddr(), userConn.GetRemoteAddr()) go conn.Join(cliConn, userConn) @@ -110,7 +114,15 @@ func (p *ProxyServer) Close() { p.Unlock() } -func (p *ProxyServer) WaitUserConn() (res int64) { - res = <-p.CtlMsgChan - return +func (p *ProxyServer) WaitUserConn() (res int64, isStop bool) { + select { + case res = <-p.CtlMsgChan: + return res, false + case <-p.StopBlockChan: + return 0, true + } +} + +func (p *ProxyServer) StopWaitUserConn() { + p.StopBlockChan <- 1 } diff --git a/pkg/utils/conn/conn.go b/pkg/utils/conn/conn.go index f8e352f..5f65329 100644 --- a/pkg/utils/conn/conn.go +++ b/pkg/utils/conn/conn.go @@ -59,7 +59,9 @@ func (c *Conn) Write(content string) (err error) { } func (c *Conn) Close() { - c.TcpConn.Close() + if c.TcpConn != nil { + c.TcpConn.Close() + } } func Listen(bindAddr string, bindPort int64) (l *Listener, err error) { diff --git a/pkg/utils/pcrypto/pcrypto_test.go b/pkg/utils/pcrypto/pcrypto_test.go index f83c003..73377e3 100644 --- a/pkg/utils/pcrypto/pcrypto_test.go +++ b/pkg/utils/pcrypto/pcrypto_test.go @@ -6,7 +6,7 @@ import ( "testing" ) -func Test_Encrypto(t *testing.T) { +func TestEncrypto(t *testing.T) { pp := new(Pcrypto) pp.Init([]byte("Hana")) res, err := pp.Encrypto([]byte("Just One Test!")) @@ -17,7 +17,7 @@ func Test_Encrypto(t *testing.T) { fmt.Printf("[%x]\n", res) } -func Test_Decrypto(t *testing.T) { +func TestDecrypto(t *testing.T) { pp := new(Pcrypto) pp.Init([]byte("Hana")) res, err := pp.Encrypto([]byte("Just One Test!")) @@ -33,13 +33,13 @@ func Test_Decrypto(t *testing.T) { fmt.Printf("[%s]\n", string(res)) } -func Test_PKCS7Padding(t *testing.T) { +func TestPKCS7Padding(t *testing.T) { ltt := []byte("Test_PKCS7Padding") ltt = PKCS7Padding(ltt, aes.BlockSize) fmt.Printf("[%x]\n", (ltt)) } -func Test_PKCS7UnPadding(t *testing.T) { +func TestPKCS7UnPadding(t *testing.T) { ltt := []byte("Test_PKCS7Padding") ltt = PKCS7Padding(ltt, aes.BlockSize) ltt = PKCS7UnPadding(ltt)