mirror of https://github.com/ehang-io/nps
Udp 多路复用 优化
parent
7d8b1d02e1
commit
05e66af647
|
@ -35,7 +35,7 @@ type Tunnel struct {
|
|||
signalList map[string]*list //通信
|
||||
tunnelList map[string]*list //隧道
|
||||
lock sync.Mutex
|
||||
tunnelLock sync.Mutex
|
||||
tunnelLock sync.Mutex
|
||||
}
|
||||
|
||||
func newTunnel(tunnelPort int) *Tunnel {
|
||||
|
@ -181,6 +181,7 @@ func (s *Tunnel) ReturnSignal(conn *Conn, cFlag string) {
|
|||
//重回slice 复用
|
||||
func (s *Tunnel) ReturnTunnel(conn *Conn, cFlag string) {
|
||||
if v, ok := s.tunnelList[cFlag]; ok {
|
||||
FlushConn(conn.conn)
|
||||
v.Add(conn)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -123,6 +123,7 @@ re:
|
|||
relay(c.conn, server, en, crypt, mux)
|
||||
end:
|
||||
if mux {
|
||||
FlushConn(conn)
|
||||
goto re
|
||||
} else {
|
||||
c.Close()
|
||||
|
|
12
lib/crypt.go
12
lib/crypt.go
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"github.com/pkg/errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
@ -38,9 +39,9 @@ func AesDecrypt(crypted, key []byte) ([]byte, error) {
|
|||
origData := make([]byte, len(crypted))
|
||||
// origData := crypted
|
||||
blockMode.CryptBlocks(origData, crypted)
|
||||
origData = PKCS5UnPadding(origData)
|
||||
err, origData = PKCS5UnPadding(origData)
|
||||
// origData = ZeroUnPadding(origData)
|
||||
return origData, nil
|
||||
return origData, err
|
||||
}
|
||||
|
||||
//补全
|
||||
|
@ -51,11 +52,14 @@ func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
|
|||
}
|
||||
|
||||
//去补
|
||||
func PKCS5UnPadding(origData []byte) []byte {
|
||||
func PKCS5UnPadding(origData []byte) (error, []byte) {
|
||||
length := len(origData)
|
||||
// 去掉最后一个字节 unpadding 次
|
||||
unpadding := int(origData[length-1])
|
||||
return origData[:(length - unpadding)]
|
||||
if (length - unpadding) < 0 {
|
||||
return errors.New("len error"), nil
|
||||
}
|
||||
return nil, origData[:(length - unpadding)]
|
||||
}
|
||||
|
||||
//生成32位md5字串
|
||||
|
|
|
@ -190,6 +190,8 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m
|
|||
defer func() {
|
||||
if cnf.Mux {
|
||||
s.bridge.ReturnTunnel(link, getverifyval(cnf.VerifyKey))
|
||||
} else {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
|
@ -212,8 +214,6 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m
|
|||
}
|
||||
go relay(link.conn, c.conn, cnf.CompressEncode, cnf.Crypt, cnf.Mux)
|
||||
relay(c.conn, link.conn, cnf.CompressDecode, cnf.Crypt, cnf.Mux)
|
||||
} else {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
35
lib/udp.go
35
lib/udp.go
|
@ -5,7 +5,6 @@ import (
|
|||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UdpModeServer struct {
|
||||
|
@ -54,25 +53,25 @@ func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) {
|
|||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt)
|
||||
if flag, err := conn.ReadFlag(); err == nil {
|
||||
if flag == CONN_SUCCESS {
|
||||
go func(addr *net.UDPAddr, conn *Conn) {
|
||||
defer func() {
|
||||
if s.config.Mux {
|
||||
s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey))
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3)))
|
||||
n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt)
|
||||
if err != nil || err == io.EOF {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
s.listener.WriteToUDP(buf[:n], addr)
|
||||
defer func() {
|
||||
if s.config.Mux {
|
||||
s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey))
|
||||
} else {
|
||||
conn.Close()
|
||||
}(addr, conn)
|
||||
}
|
||||
}()
|
||||
if flag == CONN_SUCCESS {
|
||||
conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt)
|
||||
buf := make([]byte, 1024)
|
||||
//conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3)))
|
||||
n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt)
|
||||
if err != nil || err == io.EOF {
|
||||
log.Println("revieve error:", err)
|
||||
return
|
||||
}
|
||||
s.listener.WriteToUDP(buf[:n], addr)
|
||||
conn.WriteTo([]byte(IO_EOF), s.config.CompressEncode, s.config.Crypt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
13
lib/util.go
13
lib/util.go
|
@ -17,6 +17,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -315,3 +316,15 @@ func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
|||
}
|
||||
return written, err
|
||||
}
|
||||
|
||||
//连接重置 清空缓存区
|
||||
func FlushConn(c net.Conn) {
|
||||
c.SetReadDeadline(time.Now().Add(time.Second * 3))
|
||||
buf := bufPool.Get().([]byte)
|
||||
for {
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue