diff --git a/lib/bridge.go b/lib/bridge.go index dcca556..9c368f7 100755 --- a/lib/bridge.go +++ b/lib/bridge.go @@ -35,7 +35,7 @@ type Tunnel struct { signalList map[string]*list //通信 tunnelList map[string]*list //隧道 lock sync.Mutex - tunnelLock sync.Mutex + tunnelLock sync.Mutex } func newTunnel(tunnelPort int) *Tunnel { @@ -181,6 +181,7 @@ func (s *Tunnel) ReturnSignal(conn *Conn, cFlag string) { //重回slice 复用 func (s *Tunnel) ReturnTunnel(conn *Conn, cFlag string) { if v, ok := s.tunnelList[cFlag]; ok { + FlushConn(conn.conn) v.Add(conn) } } diff --git a/lib/client.go b/lib/client.go index fdb2dd3..5aa4d81 100755 --- a/lib/client.go +++ b/lib/client.go @@ -123,6 +123,7 @@ re: relay(c.conn, server, en, crypt, mux) end: if mux { + FlushConn(conn) goto re } else { c.Close() diff --git a/lib/crypt.go b/lib/crypt.go index 4d98224..1cfde74 100644 --- a/lib/crypt.go +++ b/lib/crypt.go @@ -6,6 +6,7 @@ import ( "crypto/cipher" "crypto/md5" "encoding/hex" + "github.com/pkg/errors" "math/rand" "time" ) @@ -38,9 +39,9 @@ func AesDecrypt(crypted, key []byte) ([]byte, error) { origData := make([]byte, len(crypted)) // origData := crypted blockMode.CryptBlocks(origData, crypted) - origData = PKCS5UnPadding(origData) + err, origData = PKCS5UnPadding(origData) // origData = ZeroUnPadding(origData) - return origData, nil + return origData, err } //补全 @@ -51,11 +52,14 @@ func PKCS5Padding(ciphertext []byte, blockSize int) []byte { } //去补 -func PKCS5UnPadding(origData []byte) []byte { +func PKCS5UnPadding(origData []byte) (error, []byte) { length := len(origData) // 去掉最后一个字节 unpadding 次 unpadding := int(origData[length-1]) - return origData[:(length - unpadding)] + if (length - unpadding) < 0 { + return errors.New("len error"), nil + } + return nil, origData[:(length - unpadding)] } //生成32位md5字串 diff --git a/lib/tcp.go b/lib/tcp.go index 468c339..17a40bd 100755 --- a/lib/tcp.go +++ b/lib/tcp.go @@ -190,6 +190,8 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m defer func() { if cnf.Mux { s.bridge.ReturnTunnel(link, getverifyval(cnf.VerifyKey)) + } else { + c.Close() } }() if err != nil { @@ -212,8 +214,6 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m } go relay(link.conn, c.conn, cnf.CompressEncode, cnf.Crypt, cnf.Mux) relay(c.conn, link.conn, cnf.CompressDecode, cnf.Crypt, cnf.Mux) - } else { - c.Close() } } return nil diff --git a/lib/udp.go b/lib/udp.go index 2d3baf7..d4a1b49 100755 --- a/lib/udp.go +++ b/lib/udp.go @@ -5,7 +5,6 @@ import ( "log" "net" "strings" - "time" ) type UdpModeServer struct { @@ -54,25 +53,25 @@ func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) { conn.Close() return } - conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt) if flag, err := conn.ReadFlag(); err == nil { - if flag == CONN_SUCCESS { - go func(addr *net.UDPAddr, conn *Conn) { - defer func() { - if s.config.Mux { - s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey)) - } - }() - buf := make([]byte, 1024) - conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3))) - n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt) - if err != nil || err == io.EOF { - conn.Close() - return - } - s.listener.WriteToUDP(buf[:n], addr) + defer func() { + if s.config.Mux { + s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey)) + } else { conn.Close() - }(addr, conn) + } + }() + if flag == CONN_SUCCESS { + conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt) + buf := make([]byte, 1024) + //conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3))) + n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt) + if err != nil || err == io.EOF { + log.Println("revieve error:", err) + return + } + s.listener.WriteToUDP(buf[:n], addr) + conn.WriteTo([]byte(IO_EOF), s.config.CompressEncode, s.config.Crypt) } } } diff --git a/lib/util.go b/lib/util.go index 82fc11b..5ef3c8b 100755 --- a/lib/util.go +++ b/lib/util.go @@ -17,6 +17,7 @@ import ( "strconv" "strings" "sync" + "time" ) var ( @@ -315,3 +316,15 @@ func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { } return written, err } + +//连接重置 清空缓存区 +func FlushConn(c net.Conn) { + c.SetReadDeadline(time.Now().Add(time.Second * 3)) + buf := bufPool.Get().([]byte) + for { + if _, err := c.Read(buf); err != nil { + break + } + } + c.SetReadDeadline(time.Time{}) +}