diff --git a/client/proxy.go b/client/proxy.go index a5da19eb..109be9f8 100644 --- a/client/proxy.go +++ b/client/proxy.go @@ -412,6 +412,7 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin. err error ) remote = workConn + defer remote.Close() if baseInfo.UseEncryption { remote, err = frpIo.WithEncryption(remote, encKey) if err != nil { @@ -433,7 +434,6 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin. localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIp, localInfo.LocalPort)) if err != nil { workConn.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIp, localInfo.LocalPort, err) - remote.Close() return } diff --git a/server/proxy.go b/server/proxy.go index bfb9793a..554e8181 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -287,9 +287,18 @@ func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) { rwc = frpIo.WithCompression(rwc) } workConn = frpNet.WrapReadWriteCloserToConn(rwc, tmpConn) + workConn = frpNet.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn) + StatsOpenConnection(pxy.GetName()) return } +func (pxy *HttpProxy) updateStatsAfterClosedConn(totalRead, totalWrite int64) { + name := pxy.GetName() + StatsCloseConnection(name) + StatsAddTrafficIn(name, totalWrite) + StatsAddTrafficOut(name, totalRead) +} + func (pxy *HttpProxy) Close() { pxy.BaseProxy.Close() for _, closeFn := range pxy.closeFuncs { diff --git a/utils/net/conn.go b/utils/net/conn.go index c1f6f462..78319cc7 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -174,3 +174,38 @@ func (sc *SharedConn) WriteBuff(buffer []byte) (err error) { _, err = sc.buf.Write(buffer) return err } + +type StatsConn struct { + Conn + + totalRead int64 + totalWrite int64 + statsFunc func(totalRead, totalWrite int64) +} + +func WrapStatsConn(conn Conn, statsFunc func(total, totalWrite int64)) *StatsConn { + return &StatsConn{ + Conn: conn, + statsFunc: statsFunc, + } +} + +func (statsConn *StatsConn) Read(p []byte) (n int, err error) { + n, err = statsConn.Conn.Read(p) + statsConn.totalRead += int64(n) + return +} + +func (statsConn *StatsConn) Write(p []byte) (n int, err error) { + n, err = statsConn.Conn.Write(p) + statsConn.totalWrite += int64(n) + return +} + +func (statsConn *StatsConn) Close() (err error) { + err = statsConn.Conn.Close() + if statsConn.statsFunc != nil { + statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite) + } + return +}