mirror of https://github.com/ehang-io/nps
parent
2e8af6f120
commit
59d789d253
21
README.md
21
README.md
|
@ -45,7 +45,7 @@ go语言编写,无第三方依赖,各个平台都已经编译在release中
|
|||
* [与nginx配合](#与nginx配合)
|
||||
* [关闭http|https代理](#关闭代理)
|
||||
* [将nps安装到系统](#将nps安装到系统)
|
||||
* 单隧道模式及介绍
|
||||
* 单隧道模式及介绍(即将移除)
|
||||
* [tcp隧道模式](#tcp隧道模式)
|
||||
* [udp隧道模式](#udp隧道模式)
|
||||
* [socks5代理模式](#socks5代理模式)
|
||||
|
@ -62,6 +62,7 @@ go语言编写,无第三方依赖,各个平台都已经编译在release中
|
|||
* [带宽限制](#带宽限制)
|
||||
* [负载均衡](#负载均衡)
|
||||
* [守护进程](#守护进程)
|
||||
* [KCP协议支持](#KCP协议支持)
|
||||
* [相关说明](#相关说明)
|
||||
* [流量统计](#流量统计)
|
||||
* [热更新支持](#热更新支持)
|
||||
|
@ -138,12 +139,13 @@ go语言编写,无第三方依赖,各个平台都已经编译在release中
|
|||
---|---
|
||||
httpport | web管理端口
|
||||
password | web界面管理密码
|
||||
tcpport | 服务端客户端通信端口
|
||||
bridePort | 服务端客户端通信端口
|
||||
pemPath | ssl certFile绝对路径
|
||||
keyPath | ssl keyFile绝对路径
|
||||
httpsProxyPort | 域名代理https代理监听端口
|
||||
httpProxyPort | 域名代理http代理监听端口
|
||||
authip|web api免验证IP地址
|
||||
bridgeType|客户端与服务端连接方式kcp或tcp
|
||||
|
||||
### 详细说明
|
||||
|
||||
|
@ -539,12 +541,23 @@ authip | 免验证ip,适用于web api
|
|||
### 守护进程
|
||||
本代理支持守护进程,使用示例如下,服务端客户端所有模式通用,支持linux,darwin,windows。
|
||||
```
|
||||
./(nps|npc) start|stop|restart|status xxxxxx
|
||||
./(nps|npc) start|stop|restart|status 若有其他参数可加其他参数
|
||||
```
|
||||
```
|
||||
(nps|npc).exe start|stop|restart|status xxxxxx
|
||||
(nps|npc).exe start|stop|restart|status 若有其他参数可加其他参数
|
||||
```
|
||||
|
||||
### KCP协议支持
|
||||
KCP 是一个快速可靠协议,能以比 TCP浪费10%-20%的带宽的代价,换取平均延迟降低 30%-40%,在弱网环境下对性能能有一定的提升。可在app.conf中修改bridgeType为kcp
|
||||
,设置后本代理将开启udp端口(bridgePort)
|
||||
|
||||
注意:当服务端为kcp时,客户端连接时也需要加上参数
|
||||
|
||||
```
|
||||
-type=kcp
|
||||
```
|
||||
|
||||
|
||||
## 相关说明
|
||||
|
||||
### 获取用户真实ip
|
||||
|
|
172
bridge/bridge.go
172
bridge/bridge.go
|
@ -2,71 +2,103 @@ package bridge
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/kcp"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/lib/pool"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
tunnel *lib.Conn
|
||||
signal *lib.Conn
|
||||
linkMap map[int]*lib.Link
|
||||
tunnel *conn.Conn
|
||||
signal *conn.Conn
|
||||
linkMap map[int]*conn.Link
|
||||
linkStatusMap map[int]bool
|
||||
stop chan bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type Bridge struct {
|
||||
TunnelPort int //通信隧道端口
|
||||
listener *net.TCPListener //server端监听
|
||||
Client map[int]*Client
|
||||
RunList map[int]interface{} //运行中的任务
|
||||
lock sync.Mutex
|
||||
tunnelLock sync.Mutex
|
||||
clientLock sync.Mutex
|
||||
func NewClient(t *conn.Conn, s *conn.Conn) *Client {
|
||||
return &Client{
|
||||
linkMap: make(map[int]*conn.Link),
|
||||
stop: make(chan bool),
|
||||
linkStatusMap: make(map[int]bool),
|
||||
signal: s,
|
||||
tunnel: t,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTunnel(tunnelPort int, runList map[int]interface{}) *Bridge {
|
||||
type Bridge struct {
|
||||
TunnelPort int //通信隧道端口
|
||||
tcpListener *net.TCPListener //server端监听
|
||||
kcpListener *kcp.Listener //server端监听
|
||||
Client map[int]*Client
|
||||
RunList map[int]interface{} //运行中的任务
|
||||
tunnelType string //bridge type kcp or tcp
|
||||
lock sync.Mutex
|
||||
tunnelLock sync.Mutex
|
||||
clientLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewTunnel(tunnelPort int, runList map[int]interface{}, tunnelType string) *Bridge {
|
||||
t := new(Bridge)
|
||||
t.TunnelPort = tunnelPort
|
||||
t.Client = make(map[int]*Client)
|
||||
t.RunList = runList
|
||||
t.tunnelType = tunnelType
|
||||
return t
|
||||
}
|
||||
|
||||
func (s *Bridge) StartTunnel() error {
|
||||
var err error
|
||||
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.TunnelPort, ""})
|
||||
if err != nil {
|
||||
return err
|
||||
if s.tunnelType == "kcp" {
|
||||
s.kcpListener, err = kcp.ListenWithOptions(":"+strconv.Itoa(s.TunnelPort), nil, 150, 3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
c, err := s.kcpListener.AcceptKCP()
|
||||
conn.SetUdpSession(c)
|
||||
if err != nil {
|
||||
lg.Println(err)
|
||||
continue
|
||||
}
|
||||
go s.cliProcess(conn.NewConn(c))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
s.tcpListener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.TunnelPort, ""})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
c, err := s.tcpListener.Accept()
|
||||
if err != nil {
|
||||
lg.Println(err)
|
||||
continue
|
||||
}
|
||||
go s.cliProcess(conn.NewConn(c))
|
||||
}
|
||||
}()
|
||||
}
|
||||
go s.tunnelProcess()
|
||||
return nil
|
||||
}
|
||||
|
||||
//tcp server
|
||||
func (s *Bridge) tunnelProcess() error {
|
||||
var err error
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
lib.Println(err)
|
||||
continue
|
||||
}
|
||||
go s.cliProcess(lib.NewConn(conn))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//验证失败,返回错误验证flag,并且关闭连接
|
||||
func (s *Bridge) verifyError(c *lib.Conn) {
|
||||
c.Write([]byte(lib.VERIFY_EER))
|
||||
func (s *Bridge) verifyError(c *conn.Conn) {
|
||||
c.Write([]byte(common.VERIFY_EER))
|
||||
c.Conn.Close()
|
||||
}
|
||||
|
||||
func (s *Bridge) cliProcess(c *lib.Conn) {
|
||||
c.SetReadDeadline(5)
|
||||
func (s *Bridge) cliProcess(c *conn.Conn) {
|
||||
c.SetReadDeadline(5, s.tunnelType)
|
||||
var buf []byte
|
||||
var err error
|
||||
if buf, err = c.ReadLen(32); err != nil {
|
||||
|
@ -74,9 +106,9 @@ func (s *Bridge) cliProcess(c *lib.Conn) {
|
|||
return
|
||||
}
|
||||
//验证
|
||||
id, err := lib.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String())
|
||||
id, err := file.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
lib.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr())
|
||||
lg.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr())
|
||||
s.verifyError(c)
|
||||
return
|
||||
}
|
||||
|
@ -97,40 +129,39 @@ func (s *Bridge) closeClient(id int) {
|
|||
}
|
||||
|
||||
//tcp连接类型区分
|
||||
func (s *Bridge) typeDeal(typeVal string, c *lib.Conn, id int) {
|
||||
func (s *Bridge) typeDeal(typeVal string, c *conn.Conn, id int) {
|
||||
switch typeVal {
|
||||
case lib.WORK_MAIN:
|
||||
case common.WORK_MAIN:
|
||||
//客户端已经存在,下线
|
||||
s.clientLock.Lock()
|
||||
if _, ok := s.Client[id]; ok {
|
||||
s.clientLock.Unlock()
|
||||
s.closeClient(id)
|
||||
} else {
|
||||
s.clientLock.Unlock()
|
||||
}
|
||||
s.clientLock.Lock()
|
||||
|
||||
s.Client[id] = &Client{
|
||||
linkMap: make(map[int]*lib.Link),
|
||||
stop: make(chan bool),
|
||||
linkStatusMap: make(map[int]bool),
|
||||
}
|
||||
lib.Printf("客户端%d连接成功,地址为:%s", id, c.Conn.RemoteAddr())
|
||||
s.Client[id].signal = c
|
||||
s.clientLock.Unlock()
|
||||
go s.GetStatus(id)
|
||||
case lib.WORK_CHAN:
|
||||
s.clientLock.Lock()
|
||||
if v, ok := s.Client[id]; ok {
|
||||
s.clientLock.Unlock()
|
||||
v.tunnel = c
|
||||
if v.signal != nil {
|
||||
v.signal.WriteClose()
|
||||
}
|
||||
v.Lock()
|
||||
v.signal = c
|
||||
v.Unlock()
|
||||
} else {
|
||||
s.Client[id] = NewClient(nil, c)
|
||||
s.clientLock.Unlock()
|
||||
}
|
||||
lg.Printf("客户端%d连接成功,地址为:%s", id, c.Conn.RemoteAddr())
|
||||
go s.GetStatus(id)
|
||||
case common.WORK_CHAN:
|
||||
s.clientLock.Lock()
|
||||
if v, ok := s.Client[id]; ok {
|
||||
s.clientLock.Unlock()
|
||||
v.Lock()
|
||||
v.tunnel = c
|
||||
v.Unlock()
|
||||
} else {
|
||||
s.Client[id] = NewClient(c, nil)
|
||||
s.clientLock.Unlock()
|
||||
return
|
||||
}
|
||||
go s.clientCopy(id)
|
||||
}
|
||||
c.SetAlive()
|
||||
c.SetAlive(s.tunnelType)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -161,13 +192,13 @@ func (s *Bridge) waitStatus(clientId, id int) (bool) {
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *Bridge) SendLinkInfo(clientId int, link *lib.Link) (tunnel *lib.Conn, err error) {
|
||||
func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link) (tunnel *conn.Conn, err error) {
|
||||
s.clientLock.Lock()
|
||||
if v, ok := s.Client[clientId]; ok {
|
||||
s.clientLock.Unlock()
|
||||
v.signal.SendLinkInfo(link)
|
||||
if err != nil {
|
||||
lib.Println("send error:", err, link.Id)
|
||||
lg.Println("send error:", err, link.Id)
|
||||
s.DelClient(clientId)
|
||||
return
|
||||
}
|
||||
|
@ -192,7 +223,7 @@ func (s *Bridge) SendLinkInfo(clientId int, link *lib.Link) (tunnel *lib.Conn, e
|
|||
}
|
||||
|
||||
//得到一个tcp隧道
|
||||
func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *lib.Conn, err error) {
|
||||
func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *conn.Conn, err error) {
|
||||
s.clientLock.Lock()
|
||||
defer s.clientLock.Unlock()
|
||||
if v, ok := s.Client[id]; !ok {
|
||||
|
@ -204,7 +235,7 @@ func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *lib.Conn,
|
|||
}
|
||||
|
||||
//得到一个通信通道
|
||||
func (s *Bridge) GetSignal(id int) (conn *lib.Conn, err error) {
|
||||
func (s *Bridge) GetSignal(id int) (conn *conn.Conn, err error) {
|
||||
s.clientLock.Lock()
|
||||
defer s.clientLock.Unlock()
|
||||
if v, ok := s.Client[id]; !ok {
|
||||
|
@ -257,19 +288,19 @@ func (s *Bridge) clientCopy(clientId int) {
|
|||
for {
|
||||
if id, err := client.tunnel.GetLen(); err != nil {
|
||||
s.closeClient(clientId)
|
||||
lib.Println("读取msg id 错误", err, id)
|
||||
lg.Println("读取msg id 错误", err, id)
|
||||
break
|
||||
} else {
|
||||
client.Lock()
|
||||
if link, ok := client.linkMap[id]; ok {
|
||||
client.Unlock()
|
||||
if content, err := client.tunnel.GetMsgContent(link); err != nil {
|
||||
lib.PutBufPoolCopy(content)
|
||||
pool.PutBufPoolCopy(content)
|
||||
s.closeClient(clientId)
|
||||
lib.Println("read msg content error", err, "close client")
|
||||
lg.Println("read msg content error", err, "close client")
|
||||
break
|
||||
} else {
|
||||
if len(content) == len(lib.IO_EOF) && string(content) == lib.IO_EOF {
|
||||
if len(content) == len(common.IO_EOF) && string(content) == common.IO_EOF {
|
||||
if link.Conn != nil {
|
||||
link.Conn.Close()
|
||||
}
|
||||
|
@ -281,7 +312,7 @@ func (s *Bridge) clientCopy(clientId int) {
|
|||
}
|
||||
link.Flow.Add(0, len(content))
|
||||
}
|
||||
lib.PutBufPoolCopy(content)
|
||||
pool.PutBufPoolCopy(content)
|
||||
}
|
||||
} else {
|
||||
client.Unlock()
|
||||
|
@ -289,5 +320,4 @@ func (s *Bridge) clientCopy(clientId int) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
122
client/client.go
122
client/client.go
|
@ -1,30 +1,35 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/kcp"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/lib/pool"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TRPClient struct {
|
||||
svrAddr string
|
||||
linkMap map[int]*lib.Link
|
||||
stop chan bool
|
||||
tunnel *lib.Conn
|
||||
svrAddr string
|
||||
linkMap map[int]*conn.Link
|
||||
stop chan bool
|
||||
tunnel *conn.Conn
|
||||
bridgeConnType string
|
||||
sync.Mutex
|
||||
vKey string
|
||||
}
|
||||
|
||||
//new client
|
||||
func NewRPClient(svraddr string, vKey string) *TRPClient {
|
||||
func NewRPClient(svraddr string, vKey string, bridgeConnType string) *TRPClient {
|
||||
return &TRPClient{
|
||||
svrAddr: svraddr,
|
||||
linkMap: make(map[int]*lib.Link),
|
||||
stop: make(chan bool),
|
||||
tunnel: nil,
|
||||
Mutex: sync.Mutex{},
|
||||
vKey: vKey,
|
||||
svrAddr: svraddr,
|
||||
linkMap: make(map[int]*conn.Link),
|
||||
stop: make(chan bool),
|
||||
Mutex: sync.Mutex{},
|
||||
vKey: vKey,
|
||||
bridgeConnType: bridgeConnType,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,37 +41,44 @@ func (s *TRPClient) Start() error {
|
|||
|
||||
//新建
|
||||
func (s *TRPClient) NewConn() {
|
||||
var err error
|
||||
var c net.Conn
|
||||
retry:
|
||||
conn, err := net.Dial("tcp", s.svrAddr)
|
||||
if s.bridgeConnType == "tcp" {
|
||||
c, err = net.Dial("tcp", s.svrAddr)
|
||||
} else {
|
||||
var sess *kcp.UDPSession
|
||||
sess, err = kcp.DialWithOptions(s.svrAddr, nil, 150, 3)
|
||||
conn.SetUdpSession(sess)
|
||||
c = sess
|
||||
}
|
||||
if err != nil {
|
||||
lib.Println("连接服务端失败,五秒后将重连")
|
||||
lg.Println("连接服务端失败,五秒后将重连")
|
||||
time.Sleep(time.Second * 5)
|
||||
goto retry
|
||||
return
|
||||
}
|
||||
s.processor(lib.NewConn(conn))
|
||||
s.processor(conn.NewConn(c))
|
||||
}
|
||||
|
||||
//处理
|
||||
func (s *TRPClient) processor(c *lib.Conn) {
|
||||
c.SetAlive()
|
||||
if _, err := c.Write([]byte(lib.Getverifyval(s.vKey))); err != nil {
|
||||
func (s *TRPClient) processor(c *conn.Conn) {
|
||||
c.SetAlive(s.bridgeConnType)
|
||||
if _, err := c.Write([]byte(common.Getverifyval(s.vKey))); err != nil {
|
||||
return
|
||||
}
|
||||
c.WriteMain()
|
||||
|
||||
go s.dealChan()
|
||||
|
||||
for {
|
||||
flags, err := c.ReadFlag()
|
||||
if err != nil {
|
||||
lib.Println("服务端断开,正在重新连接")
|
||||
lg.Println("服务端断开,正在重新连接")
|
||||
break
|
||||
}
|
||||
switch flags {
|
||||
case lib.VERIFY_EER:
|
||||
lib.Fatalf("vKey:%s不正确,服务端拒绝连接,请检查", s.vKey)
|
||||
case lib.NEW_CONN:
|
||||
case common.VERIFY_EER:
|
||||
lg.Fatalf("vKey:%s不正确,服务端拒绝连接,请检查", s.vKey)
|
||||
case common.NEW_CONN:
|
||||
if link, err := c.GetLinkInfo(); err != nil {
|
||||
break
|
||||
} else {
|
||||
|
@ -75,54 +87,46 @@ func (s *TRPClient) processor(c *lib.Conn) {
|
|||
s.Unlock()
|
||||
go s.linkProcess(link, c)
|
||||
}
|
||||
case lib.RES_CLOSE:
|
||||
lib.Fatalln("该vkey被另一客户连接")
|
||||
case lib.RES_MSG:
|
||||
lib.Println("服务端返回错误,重新连接")
|
||||
case common.RES_CLOSE:
|
||||
lg.Fatalln("该vkey被另一客户连接")
|
||||
case common.RES_MSG:
|
||||
lg.Println("服务端返回错误,重新连接")
|
||||
break
|
||||
default:
|
||||
lib.Println("无法解析该错误,重新连接")
|
||||
lg.Println("无法解析该错误,重新连接")
|
||||
break
|
||||
}
|
||||
}
|
||||
s.stop <- true
|
||||
s.linkMap = make(map[int]*lib.Link)
|
||||
s.linkMap = make(map[int]*conn.Link)
|
||||
go s.NewConn()
|
||||
}
|
||||
func (s *TRPClient) linkProcess(link *lib.Link, c *lib.Conn) {
|
||||
func (s *TRPClient) linkProcess(link *conn.Link, c *conn.Conn) {
|
||||
//与目标建立连接
|
||||
server, err := net.DialTimeout(link.ConnType, link.Host, time.Second*3)
|
||||
|
||||
if err != nil {
|
||||
c.WriteFail(link.Id)
|
||||
lib.Println("connect to ", link.Host, "error:", err)
|
||||
lg.Println("connect to ", link.Host, "error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteSuccess(link.Id)
|
||||
|
||||
link.Conn = lib.NewConn(server)
|
||||
|
||||
link.Conn = conn.NewConn(server)
|
||||
buf := pool.BufPoolCopy.Get().([]byte)
|
||||
for {
|
||||
buf := lib.BufPoolCopy.Get().([]byte)
|
||||
if n, err := server.Read(buf); err != nil {
|
||||
lib.PutBufPoolCopy(buf)
|
||||
s.tunnel.SendMsg([]byte(lib.IO_EOF), link)
|
||||
s.tunnel.SendMsg([]byte(common.IO_EOF), link)
|
||||
break
|
||||
} else {
|
||||
if _, err := s.tunnel.SendMsg(buf[:n], link); err != nil {
|
||||
lib.PutBufPoolCopy(buf)
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
lib.PutBufPoolCopy(buf)
|
||||
//if link.ConnType == utils.CONN_UDP {
|
||||
// c.Close()
|
||||
// break
|
||||
//}
|
||||
}
|
||||
}
|
||||
|
||||
pool.PutBufPoolCopy(buf)
|
||||
s.Lock()
|
||||
delete(s.linkMap, link.Id)
|
||||
s.Unlock()
|
||||
|
@ -131,41 +135,50 @@ func (s *TRPClient) linkProcess(link *lib.Link, c *lib.Conn) {
|
|||
//隧道模式处理
|
||||
func (s *TRPClient) dealChan() {
|
||||
var err error
|
||||
//创建一个tcp连接
|
||||
conn, err := net.Dial("tcp", s.svrAddr)
|
||||
var c net.Conn
|
||||
var sess *kcp.UDPSession
|
||||
if s.bridgeConnType == "tcp" {
|
||||
c, err = net.Dial("tcp", s.svrAddr)
|
||||
} else {
|
||||
sess, err = kcp.DialWithOptions(s.svrAddr, nil, 10, 3)
|
||||
conn.SetUdpSession(sess)
|
||||
c = sess
|
||||
}
|
||||
if err != nil {
|
||||
lib.Println("connect to ", s.svrAddr, "error:", err)
|
||||
lg.Println("connect to ", s.svrAddr, "error:", err)
|
||||
return
|
||||
}
|
||||
//验证
|
||||
if _, err := conn.Write([]byte(lib.Getverifyval(s.vKey))); err != nil {
|
||||
lib.Println("connect to ", s.svrAddr, "error:", err)
|
||||
if _, err := c.Write([]byte(common.Getverifyval(s.vKey))); err != nil {
|
||||
lg.Println("connect to ", s.svrAddr, "error:", err)
|
||||
return
|
||||
}
|
||||
//默认长连接保持
|
||||
s.tunnel = lib.NewConn(conn)
|
||||
s.tunnel.SetAlive()
|
||||
s.tunnel = conn.NewConn(c)
|
||||
s.tunnel.SetAlive(s.bridgeConnType)
|
||||
//写标志
|
||||
s.tunnel.WriteChan()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if id, err := s.tunnel.GetLen(); err != nil {
|
||||
lib.Println("get msg id error")
|
||||
lg.Println("get msg id error")
|
||||
break
|
||||
} else {
|
||||
s.Lock()
|
||||
if v, ok := s.linkMap[id]; ok {
|
||||
s.Unlock()
|
||||
if content, err := s.tunnel.GetMsgContent(v); err != nil {
|
||||
lib.Println("get msg content error:", err, id)
|
||||
lg.Println("get msg content error:", err, id)
|
||||
pool.PutBufPoolCopy(content)
|
||||
break
|
||||
} else {
|
||||
if len(content) == len(lib.IO_EOF) && string(content) == lib.IO_EOF {
|
||||
if len(content) == len(common.IO_EOF) && string(content) == common.IO_EOF {
|
||||
v.Conn.Close()
|
||||
} else if v.Conn != nil {
|
||||
v.Conn.Write(content)
|
||||
}
|
||||
pool.PutBufPoolCopy(content)
|
||||
}
|
||||
} else {
|
||||
s.Unlock()
|
||||
|
@ -175,5 +188,6 @@ func (s *TRPClient) dealChan() {
|
|||
}()
|
||||
select {
|
||||
case <-s.stop:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,9 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"github.com/cnlh/nps/client"
|
||||
"github.com/cnlh/nps/lib"
|
||||
_ "github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/daemon"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -14,20 +15,21 @@ var (
|
|||
serverAddr = flag.String("server", "", "服务器地址ip:端口")
|
||||
verifyKey = flag.String("vkey", "", "验证密钥")
|
||||
logType = flag.String("log", "stdout", "日志输出方式(stdout|file)")
|
||||
connType = flag.String("type", "tcp", "与服务端建立连接方式(kcp|tcp)")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
lib.InitDaemon("npc")
|
||||
daemon.InitDaemon("npc", common.GetRunPath(), common.GetPidPath())
|
||||
if *logType == "stdout" {
|
||||
lib.InitLogFile("npc", true)
|
||||
lg.InitLogFile("npc", true, common.GetLogPath())
|
||||
} else {
|
||||
lib.InitLogFile("npc", false)
|
||||
lg.InitLogFile("npc", false, common.GetLogPath())
|
||||
}
|
||||
stop := make(chan int)
|
||||
for _, v := range strings.Split(*verifyKey, ",") {
|
||||
lib.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v)
|
||||
go client.NewRPClient(*serverAddr, v).Start()
|
||||
lg.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v)
|
||||
go client.NewRPClient(*serverAddr, v, *connType).Start()
|
||||
}
|
||||
<-stop
|
||||
}
|
||||
|
|
|
@ -2,8 +2,12 @@ package main
|
|||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/daemon"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/install"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/server"
|
||||
_ "github.com/cnlh/nps/web/routers"
|
||||
"log"
|
||||
|
@ -28,58 +32,65 @@ var (
|
|||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
if len(os.Args) > 1 && os.Args[1] == "test" {
|
||||
server.TestServerConfig()
|
||||
log.Println("test ok, no error")
|
||||
return
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "test":
|
||||
server.TestServerConfig()
|
||||
log.Println("test ok, no error")
|
||||
return
|
||||
case "start", "restart", "stop", "status":
|
||||
daemon.InitDaemon("nps", common.GetRunPath(), common.GetPidPath())
|
||||
case "install":
|
||||
install.InstallNps()
|
||||
return
|
||||
}
|
||||
}
|
||||
lib.InitDaemon("nps")
|
||||
if *logType == "stdout" {
|
||||
lib.InitLogFile("nps", true)
|
||||
lg.InitLogFile("nps", true, common.GetLogPath())
|
||||
} else {
|
||||
lib.InitLogFile("nps", false)
|
||||
lg.InitLogFile("nps", false, common.GetLogPath())
|
||||
}
|
||||
task := &lib.Tunnel{
|
||||
task := &file.Tunnel{
|
||||
TcpPort: *httpPort,
|
||||
Mode: *rpMode,
|
||||
Target: *tunnelTarget,
|
||||
Config: &lib.Config{
|
||||
Config: &file.Config{
|
||||
U: *u,
|
||||
P: *p,
|
||||
Compress: *compress,
|
||||
Crypt: lib.GetBoolByStr(*crypt),
|
||||
Crypt: common.GetBoolByStr(*crypt),
|
||||
},
|
||||
Flow: &lib.Flow{},
|
||||
Flow: &file.Flow{},
|
||||
UseClientCnf: false,
|
||||
}
|
||||
if *VerifyKey != "" {
|
||||
c := &lib.Client{
|
||||
c := &file.Client{
|
||||
Id: 0,
|
||||
VerifyKey: *VerifyKey,
|
||||
Addr: "",
|
||||
Remark: "",
|
||||
Status: true,
|
||||
IsConnect: false,
|
||||
Cnf: &lib.Config{},
|
||||
Flow: &lib.Flow{},
|
||||
Cnf: &file.Config{},
|
||||
Flow: &file.Flow{},
|
||||
}
|
||||
c.Cnf.CompressDecode, c.Cnf.CompressEncode = lib.GetCompressType(c.Cnf.Compress)
|
||||
lib.GetCsvDb().Clients[0] = c
|
||||
c.Cnf.CompressDecode, c.Cnf.CompressEncode = common.GetCompressType(c.Cnf.Compress)
|
||||
file.GetCsvDb().Clients[0] = c
|
||||
task.Client = c
|
||||
}
|
||||
if *TcpPort == 0 {
|
||||
p, err := beego.AppConfig.Int("tcpport")
|
||||
p, err := beego.AppConfig.Int("bridgePort")
|
||||
if err == nil && *rpMode == "webServer" {
|
||||
*TcpPort = p
|
||||
} else {
|
||||
*TcpPort = 8284
|
||||
}
|
||||
}
|
||||
lib.Println("服务端启动,监听tcp服务端端口:", *TcpPort)
|
||||
task.Config.CompressDecode, task.Config.CompressEncode = lib.GetCompressType(task.Config.Compress)
|
||||
lg.Printf("服务端启动,监听%s服务端口:%d", beego.AppConfig.String("bridgeType"), *TcpPort)
|
||||
task.Config.CompressDecode, task.Config.CompressEncode = common.GetCompressType(task.Config.Compress)
|
||||
if *rpMode != "webServer" {
|
||||
lib.GetCsvDb().Tasks[0] = task
|
||||
file.GetCsvDb().Tasks[0] = task
|
||||
}
|
||||
beego.LoadAppConfig("ini", filepath.Join(lib.GetRunPath(), "conf", "app.conf"))
|
||||
server.StartNewServer(*TcpPort, task)
|
||||
beego.LoadAppConfig("ini", filepath.Join(common.GetRunPath(), "conf", "app.conf"))
|
||||
server.StartNewServer(*TcpPort, task, beego.AppConfig.String("bridgeType"))
|
||||
}
|
||||
|
|
|
@ -1,28 +1,33 @@
|
|||
appname = nps
|
||||
|
||||
#web管理端口
|
||||
httpport = 8081
|
||||
#Web Management Port
|
||||
httpport = 8080
|
||||
|
||||
#启动模式dev|pro
|
||||
#Boot mode(dev|pro)
|
||||
runmode = dev
|
||||
|
||||
#web管理密码
|
||||
#Web Management Password
|
||||
password=123
|
||||
|
||||
##客户端与服务端通信端口
|
||||
tcpport=8284
|
||||
##Communication Port between Client and Server
|
||||
##If the data transfer mode is tcp, it is TCP port
|
||||
##If the data transfer mode is kcp, it is UDP port
|
||||
bridgePort=8284
|
||||
|
||||
#web api免验证IP地址
|
||||
#Web API unauthenticated IP address
|
||||
authip=127.0.0.1
|
||||
|
||||
##http代理端口,为空则不启动
|
||||
##HTTP proxy port, no startup if empty
|
||||
httpProxyPort=80
|
||||
|
||||
##https代理端口,为空则不启动
|
||||
##HTTPS proxy port, no startup if empty
|
||||
httpsProxyPort=
|
||||
|
||||
##certFile绝对路径
|
||||
##certFile absolute path
|
||||
pemPath=/etc/nginx/certificate.crt
|
||||
|
||||
##keyFile绝对路径
|
||||
keyPath=/etc/nginx/private.key
|
||||
##KeyFile absolute path
|
||||
keyPath=/etc/nginx/private.key
|
||||
|
||||
##Data transmission mode(kcp or tcp)
|
||||
bridgeType=tcp
|
|
@ -1 +1 @@
|
|||
1,ydiigrm4ghu7mym1,,true,,,0,,0,0
|
||||
1,ydiigrm4ghu7mym1,测试,true,,,0,,0,0
|
||||
|
|
|
|
@ -1 +1,2 @@
|
|||
a.o.com,127.0.0.1:8081,1,,,测试
|
||||
a.o.com,127.0.0.1:8080,1,,,测试
|
||||
b.o.com,127.0.0.1:8082,1,,,
|
||||
|
|
|
|
@ -1,4 +1,4 @@
|
|||
9001,tunnelServer,123.206.77.88:22,,,,1,0,0,0,1,1,true,测试tcp
|
||||
53,udpServer,114.114.114.114:53,,,,1,0,0,0,2,1,true,udp
|
||||
0,socks5Server,,,,,1,0,0,0,3,1,true,socks5
|
||||
9005,httpProxyServer,,,,,1,0,0,0,4,1,true,
|
||||
9002,socks5Server,,,,,1,0,0,0,3,1,true,socks5
|
||||
9001,tunnelServer,127.0.0.1:8082,,,,1,0,0,0,1,1,true,测试tcp
|
||||
|
|
|
|
@ -0,0 +1,28 @@
|
|||
package common
|
||||
|
||||
const (
|
||||
COMPRESS_NONE_ENCODE = iota
|
||||
COMPRESS_NONE_DECODE
|
||||
COMPRESS_SNAPY_ENCODE
|
||||
COMPRESS_SNAPY_DECODE
|
||||
VERIFY_EER = "vkey"
|
||||
WORK_MAIN = "main"
|
||||
WORK_CHAN = "chan"
|
||||
RES_SIGN = "sign"
|
||||
RES_MSG = "msg0"
|
||||
RES_CLOSE = "clse"
|
||||
NEW_CONN = "conn" //新连接标志
|
||||
NEW_TASK = "task" //新连接标志
|
||||
CONN_SUCCESS = "sucs"
|
||||
CONN_TCP = "tcp"
|
||||
CONN_UDP = "udp"
|
||||
UnauthorizedBytes = `HTTP/1.1 401 Unauthorized
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
WWW-Authenticate: Basic realm="easyProxy"
|
||||
|
||||
401 Unauthorized`
|
||||
IO_EOF = "PROXYEOF"
|
||||
ConnectionFailBytes = `HTTP/1.1 404 Not Found
|
||||
|
||||
`
|
||||
)
|
|
@ -0,0 +1,67 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
//Get the currently selected configuration file directory
|
||||
//For non-Windows systems, select the /etc/nps as config directory if exist, or select ./
|
||||
//windows system, select the C:\Program Files\nps as config directory if exist, or select ./
|
||||
func GetRunPath() string {
|
||||
var path string
|
||||
if path = GetInstallPath(); !FileExists(path) {
|
||||
return "./"
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
//Different systems get different installation paths
|
||||
func GetInstallPath() string {
|
||||
var path string
|
||||
if IsWindows() {
|
||||
path = `C:\Program Files\nps`
|
||||
} else {
|
||||
path = "/etc/nps"
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
//Get the absolute path to the running directory
|
||||
func GetAppPath() string {
|
||||
if path, err := filepath.Abs(filepath.Dir(os.Args[0])); err == nil {
|
||||
return path
|
||||
}
|
||||
return os.Args[0]
|
||||
}
|
||||
|
||||
//Determine whether the current system is a Windows system?
|
||||
func IsWindows() bool {
|
||||
if runtime.GOOS == "windows" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//interface log file path
|
||||
func GetLogPath() string {
|
||||
var path string
|
||||
if IsWindows() {
|
||||
path = "./"
|
||||
} else {
|
||||
path = "/tmp"
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
//interface pid file path
|
||||
func GetPidPath() string {
|
||||
var path string
|
||||
if IsWindows() {
|
||||
path = "./"
|
||||
} else {
|
||||
path = "/tmp"
|
||||
}
|
||||
return path
|
||||
}
|
|
@ -1,45 +1,21 @@
|
|||
package lib
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"github.com/cnlh/nps/lib/crypt"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
COMPRESS_NONE_ENCODE = iota
|
||||
COMPRESS_NONE_DECODE
|
||||
COMPRESS_SNAPY_ENCODE
|
||||
COMPRESS_SNAPY_DECODE
|
||||
VERIFY_EER = "vkey"
|
||||
WORK_MAIN = "main"
|
||||
WORK_CHAN = "chan"
|
||||
RES_SIGN = "sign"
|
||||
RES_MSG = "msg0"
|
||||
RES_CLOSE = "clse"
|
||||
NEW_CONN = "conn" //新连接标志
|
||||
CONN_SUCCESS = "sucs"
|
||||
CONN_TCP = "tcp"
|
||||
CONN_UDP = "udp"
|
||||
UnauthorizedBytes = `HTTP/1.1 401 Unauthorized
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
WWW-Authenticate: Basic realm="easyProxy"
|
||||
|
||||
401 Unauthorized`
|
||||
IO_EOF = "PROXYEOF"
|
||||
ConnectionFailBytes = `HTTP/1.1 404 Not Found
|
||||
|
||||
`
|
||||
)
|
||||
|
||||
//判断压缩方式
|
||||
//Judging Compression Mode
|
||||
func GetCompressType(compress string) (int, int) {
|
||||
switch compress {
|
||||
case "":
|
||||
|
@ -47,12 +23,12 @@ func GetCompressType(compress string) (int, int) {
|
|||
case "snappy":
|
||||
return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
|
||||
default:
|
||||
Fatalln("数据压缩格式错误")
|
||||
lg.Fatalln("数据压缩格式错误")
|
||||
}
|
||||
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
|
||||
}
|
||||
|
||||
//通过host获取对应的ip地址
|
||||
//Get the corresponding IP address through domain name
|
||||
func GetHostByName(hostname string) string {
|
||||
if !DomainCheck(hostname) {
|
||||
return hostname
|
||||
|
@ -68,7 +44,7 @@ func GetHostByName(hostname string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
//检查是否是域名
|
||||
//Check the legality of domain
|
||||
func DomainCheck(domain string) bool {
|
||||
var match bool
|
||||
IsLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}(/)"
|
||||
|
@ -80,7 +56,7 @@ func DomainCheck(domain string) bool {
|
|||
return match
|
||||
}
|
||||
|
||||
//检查basic认证
|
||||
//Check if the Request request is validated
|
||||
func CheckAuth(r *http.Request, user, passwd string) bool {
|
||||
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
|
||||
if len(s) != 2 {
|
||||
|
@ -122,11 +98,12 @@ func GetIntNoErrByStr(str string) int {
|
|||
return i
|
||||
}
|
||||
|
||||
//简单的一个校验值
|
||||
//Get verify value
|
||||
func Getverifyval(vkey string) string {
|
||||
return Md5(vkey)
|
||||
return crypt.Md5(vkey)
|
||||
}
|
||||
|
||||
//Change headers and host of request
|
||||
func ChangeHostAndHeader(r *http.Request, host string, header string, addr string) {
|
||||
if host != "" {
|
||||
r.Host = host
|
||||
|
@ -145,6 +122,7 @@ func ChangeHostAndHeader(r *http.Request, host string, header string, addr strin
|
|||
r.Header.Set("X-Real-IP", addr)
|
||||
}
|
||||
|
||||
//Read file content by file path
|
||||
func ReadAllFromFile(filePath string) ([]byte, error) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
|
@ -163,53 +141,7 @@ func FileExists(name string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func GetRunPath() string {
|
||||
var path string
|
||||
if path = GetInstallPath(); !FileExists(path) {
|
||||
return "./"
|
||||
}
|
||||
return path
|
||||
}
|
||||
func GetInstallPath() string {
|
||||
var path string
|
||||
if IsWindows() {
|
||||
path = `C:\Program Files\nps`
|
||||
} else {
|
||||
path = "/etc/nps"
|
||||
}
|
||||
return path
|
||||
}
|
||||
func GetAppPath() string {
|
||||
if path, err := filepath.Abs(filepath.Dir(os.Args[0])); err == nil {
|
||||
return path
|
||||
}
|
||||
return os.Args[0]
|
||||
}
|
||||
func IsWindows() bool {
|
||||
if runtime.GOOS == "windows" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
func GetLogPath() string {
|
||||
var path string
|
||||
if IsWindows() {
|
||||
path = "./"
|
||||
} else {
|
||||
path = "/tmp"
|
||||
}
|
||||
return path
|
||||
}
|
||||
func GetPidPath() string {
|
||||
var path string
|
||||
if IsWindows() {
|
||||
path = "./"
|
||||
} else {
|
||||
path = "/tmp"
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
//Judge whether the TCP port can open normally
|
||||
func TestTcpPort(port int) bool {
|
||||
l, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), port, ""})
|
||||
defer l.Close()
|
||||
|
@ -218,3 +150,27 @@ func TestTcpPort(port int) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//Judge whether the UDP port can open normally
|
||||
func TestUdpPort(port int) bool {
|
||||
l, err := net.ListenUDP("udp", &net.UDPAddr{net.ParseIP("0.0.0.0"), port, ""})
|
||||
defer l.Close()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//Write length and individual byte data
|
||||
//Length prevents sticking
|
||||
//# Characters are used to separate data
|
||||
func BinaryWrite(raw *bytes.Buffer, v ...string) {
|
||||
buffer := new(bytes.Buffer)
|
||||
var l int32
|
||||
for _, v := range v {
|
||||
l += int32(len([]byte(v))) + int32(len([]byte("#")))
|
||||
binary.Write(buffer, binary.LittleEndian, []byte(v))
|
||||
binary.Write(buffer, binary.LittleEndian, []byte("#"))
|
||||
}
|
||||
binary.Write(raw, binary.LittleEndian, buffer.Bytes())
|
||||
}
|
|
@ -1,11 +1,15 @@
|
|||
package lib
|
||||
package conn
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/golang/snappy"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/kcp"
|
||||
"github.com/cnlh/nps/lib/pool"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -18,126 +22,6 @@ import (
|
|||
|
||||
const cryptKey = "1234567812345678"
|
||||
|
||||
type CryptConn struct {
|
||||
conn net.Conn
|
||||
crypt bool
|
||||
rate *Rate
|
||||
}
|
||||
|
||||
func NewCryptConn(conn net.Conn, crypt bool, rate *Rate) *CryptConn {
|
||||
c := new(CryptConn)
|
||||
c.conn = conn
|
||||
c.crypt = crypt
|
||||
c.rate = rate
|
||||
return c
|
||||
}
|
||||
|
||||
//加密写
|
||||
func (s *CryptConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if s.crypt {
|
||||
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if b, err = GetLenBytes(b); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = s.conn.Write(b)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//解密读
|
||||
func (s *CryptConn) Read(b []byte) (n int, err error) {
|
||||
var lens int
|
||||
var buf []byte
|
||||
var rb []byte
|
||||
c := NewConn(s.conn)
|
||||
if lens, err = c.GetLen(); err != nil {
|
||||
return
|
||||
}
|
||||
if buf, err = c.ReadLen(lens); err != nil {
|
||||
return
|
||||
}
|
||||
if s.crypt {
|
||||
if rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
rb = buf
|
||||
}
|
||||
copy(b, rb)
|
||||
n = len(rb)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type SnappyConn struct {
|
||||
w *snappy.Writer
|
||||
r *snappy.Reader
|
||||
crypt bool
|
||||
rate *Rate
|
||||
}
|
||||
|
||||
func NewSnappyConn(conn net.Conn, crypt bool, rate *Rate) *SnappyConn {
|
||||
c := new(SnappyConn)
|
||||
c.w = snappy.NewBufferedWriter(conn)
|
||||
c.r = snappy.NewReader(conn)
|
||||
c.crypt = crypt
|
||||
c.rate = rate
|
||||
return c
|
||||
}
|
||||
|
||||
//snappy压缩写 包含加密
|
||||
func (s *SnappyConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if s.crypt {
|
||||
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
|
||||
Println("encode crypt error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if _, err = s.w.Write(b); err != nil {
|
||||
return
|
||||
}
|
||||
if err = s.w.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//snappy压缩读 包含解密
|
||||
func (s *SnappyConn) Read(b []byte) (n int, err error) {
|
||||
buf := BufPool.Get().([]byte)
|
||||
defer BufPool.Put(buf)
|
||||
if n, err = s.r.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
var bs []byte
|
||||
if s.crypt {
|
||||
if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
|
||||
Println("decode crypt error:", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
bs = buf[:n]
|
||||
}
|
||||
n = len(bs)
|
||||
copy(b, bs)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Conn net.Conn
|
||||
sync.Mutex
|
||||
|
@ -186,16 +70,16 @@ func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.
|
|||
|
||||
//读取指定长度内容
|
||||
func (s *Conn) ReadLen(cLen int) ([]byte, error) {
|
||||
if cLen > poolSize {
|
||||
if cLen > pool.PoolSize {
|
||||
return nil, errors.New("长度错误" + strconv.Itoa(cLen))
|
||||
}
|
||||
var buf []byte
|
||||
if cLen <= poolSizeSmall {
|
||||
buf = BufPoolSmall.Get().([]byte)[:cLen]
|
||||
defer BufPoolSmall.Put(buf)
|
||||
if cLen <= pool.PoolSizeSmall {
|
||||
buf = pool.BufPoolSmall.Get().([]byte)[:cLen]
|
||||
defer pool.BufPoolSmall.Put(buf)
|
||||
} else {
|
||||
buf = BufPoolMax.Get().([]byte)[:cLen]
|
||||
defer BufPoolMax.Put(buf)
|
||||
buf = pool.BufPoolMax.Get().([]byte)[:cLen]
|
||||
defer pool.BufPoolMax.Put(buf)
|
||||
}
|
||||
if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
|
||||
return buf, errors.New("读取指定长度错误" + err.Error())
|
||||
|
@ -231,35 +115,64 @@ func (s *Conn) GetConnStatus() (id int, status bool, err error) {
|
|||
if b, err = s.ReadLen(1); err != nil {
|
||||
return
|
||||
} else {
|
||||
status = GetBoolByStr(string(b[0]))
|
||||
status = common.GetBoolByStr(string(b[0]))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//设置连接为长连接
|
||||
func (s *Conn) SetAlive() {
|
||||
func (s *Conn) SetAlive(tp string) {
|
||||
if tp == "kcp" {
|
||||
s.setKcpAlive()
|
||||
} else {
|
||||
s.setTcpAlive()
|
||||
}
|
||||
}
|
||||
|
||||
//设置连接为长连接
|
||||
func (s *Conn) setTcpAlive() {
|
||||
conn := s.Conn.(*net.TCPConn)
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
conn.SetKeepAlive(true)
|
||||
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
|
||||
}
|
||||
|
||||
//设置连接为长连接
|
||||
func (s *Conn) setKcpAlive() {
|
||||
conn := s.Conn.(*kcp.UDPSession)
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
//设置连接为长连接
|
||||
func (s *Conn) SetReadDeadline(t time.Duration, tp string) {
|
||||
if tp == "kcp" {
|
||||
s.SetKcpReadDeadline(t)
|
||||
} else {
|
||||
s.SetTcpReadDeadline(t)
|
||||
}
|
||||
}
|
||||
|
||||
//set read dead time
|
||||
func (s *Conn) SetReadDeadline(t time.Duration) {
|
||||
func (s *Conn) SetTcpReadDeadline(t time.Duration) {
|
||||
s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
|
||||
}
|
||||
|
||||
//set read dead time
|
||||
func (s *Conn) SetKcpReadDeadline(t time.Duration) {
|
||||
s.Conn.(*kcp.UDPSession).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
|
||||
}
|
||||
|
||||
//单独读(加密|压缩)
|
||||
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *Rate) (int, error) {
|
||||
if COMPRESS_SNAPY_DECODE == compress {
|
||||
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *rate.Rate) (int, error) {
|
||||
if common.COMPRESS_SNAPY_DECODE == compress {
|
||||
return NewSnappyConn(s.Conn, crypt, rate).Read(b)
|
||||
}
|
||||
return NewCryptConn(s.Conn, crypt, rate).Read(b)
|
||||
}
|
||||
|
||||
//单独写(加密|压缩)
|
||||
func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, err error) {
|
||||
if COMPRESS_SNAPY_ENCODE == compress {
|
||||
func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *rate.Rate) (n int, err error) {
|
||||
if common.COMPRESS_SNAPY_ENCODE == compress {
|
||||
return NewSnappyConn(s.Conn, crypt, rate).Write(b)
|
||||
}
|
||||
return NewCryptConn(s.Conn, crypt, rate).Write(b)
|
||||
|
@ -292,7 +205,7 @@ func (s *Conn) SendMsg(content []byte, link *Link) (n int, err error) {
|
|||
func (s *Conn) GetMsgContent(link *Link) (content []byte, err error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
buf := BufPoolCopy.Get().([]byte)
|
||||
buf := pool.BufPoolCopy.Get().([]byte)
|
||||
if n, err := s.ReadFrom(buf, link.De, link.Crypt, link.Rate); err == nil && n > 4 {
|
||||
content = buf[:n]
|
||||
}
|
||||
|
@ -310,7 +223,7 @@ func (s *Conn) SendLinkInfo(link *Link) (int, error) {
|
|||
+----------+------+----------+------+----+----+------+
|
||||
*/
|
||||
raw := bytes.NewBuffer([]byte{})
|
||||
binary.Write(raw, binary.LittleEndian, []byte(NEW_CONN))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(common.NEW_CONN))
|
||||
binary.Write(raw, binary.LittleEndian, int32(14+len(link.Host)))
|
||||
binary.Write(raw, binary.LittleEndian, int32(link.Id))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(link.ConnType))
|
||||
|
@ -318,13 +231,13 @@ func (s *Conn) SendLinkInfo(link *Link) (int, error) {
|
|||
binary.Write(raw, binary.LittleEndian, []byte(link.Host))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.En)))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.De)))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(GetStrByBool(link.Crypt)))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(common.GetStrByBool(link.Crypt)))
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.Write(raw.Bytes())
|
||||
}
|
||||
|
||||
func (s *Conn) GetLinkInfo() (link *Link, err error) {
|
||||
func (s *Conn) GetLinkInfo() (lk *Link, err error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
var hostLen, n int
|
||||
|
@ -332,21 +245,69 @@ func (s *Conn) GetLinkInfo() (link *Link, err error) {
|
|||
if n, err = s.GetLen(); err != nil {
|
||||
return
|
||||
}
|
||||
link = new(Link)
|
||||
lk = new(Link)
|
||||
if buf, err = s.ReadLen(n); err != nil {
|
||||
return
|
||||
}
|
||||
if link.Id, err = GetLenByBytes(buf[:4]); err != nil {
|
||||
if lk.Id, err = GetLenByBytes(buf[:4]); err != nil {
|
||||
return
|
||||
}
|
||||
link.ConnType = string(buf[4:7])
|
||||
lk.ConnType = string(buf[4:7])
|
||||
if hostLen, err = GetLenByBytes(buf[7:11]); err != nil {
|
||||
return
|
||||
} else {
|
||||
link.Host = string(buf[11 : 11+hostLen])
|
||||
link.En = GetIntNoErrByStr(string(buf[11+hostLen]))
|
||||
link.De = GetIntNoErrByStr(string(buf[12+hostLen]))
|
||||
link.Crypt = GetBoolByStr(string(buf[13+hostLen]))
|
||||
lk.Host = string(buf[11 : 11+hostLen])
|
||||
lk.En = common.GetIntNoErrByStr(string(buf[11+hostLen]))
|
||||
lk.De = common.GetIntNoErrByStr(string(buf[12+hostLen]))
|
||||
lk.Crypt = common.GetBoolByStr(string(buf[13+hostLen]))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//send task info
|
||||
func (s *Conn) SendTaskInfo(t *file.Tunnel) (int, error) {
|
||||
/*
|
||||
The task info is formed as follows:
|
||||
+----+-----+---------+
|
||||
|type| len | content |
|
||||
+----+---------------+
|
||||
| 4 | 4 | ... |
|
||||
+----+---------------+
|
||||
*/
|
||||
raw := bytes.NewBuffer([]byte{})
|
||||
binary.Write(raw, binary.LittleEndian, common.NEW_TASK)
|
||||
common.BinaryWrite(raw, t.Mode, string(t.TcpPort), string(t.Target), string(t.Config.U), string(t.Config.P), common.GetStrByBool(t.Config.Crypt), t.Config.Compress, t.Remark)
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.Write(raw.Bytes())
|
||||
}
|
||||
|
||||
//get task info
|
||||
func (s *Conn) GetTaskInfo() (t *file.Tunnel, err error) {
|
||||
var l int
|
||||
var b []byte
|
||||
if l, err = s.GetLen(); err != nil {
|
||||
return
|
||||
} else if b, err = s.ReadLen(l); err != nil {
|
||||
return
|
||||
} else {
|
||||
arr := strings.Split(string(b), "#")
|
||||
t.Mode = arr[0]
|
||||
t.TcpPort, _ = strconv.Atoi(arr[1])
|
||||
t.Target = arr[2]
|
||||
t.Config = new(file.Config)
|
||||
t.Config.U = arr[3]
|
||||
t.Config.P = arr[4]
|
||||
t.Config.Compress = arr[5]
|
||||
t.Config.CompressDecode, t.Config.CompressDecode = common.GetCompressType(arr[5])
|
||||
t.Id = file.GetCsvDb().GetTaskId()
|
||||
t.Status = true
|
||||
if t.Client, err = file.GetCsvDb().GetClient(0); err != nil {
|
||||
return
|
||||
}
|
||||
t.Flow = new(file.Flow)
|
||||
t.Remark = arr[6]
|
||||
t.UseClientCnf = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -388,31 +349,31 @@ func (s *Conn) Read(b []byte) (int, error) {
|
|||
|
||||
//write error
|
||||
func (s *Conn) WriteError() (int, error) {
|
||||
return s.Write([]byte(RES_MSG))
|
||||
return s.Write([]byte(common.RES_MSG))
|
||||
}
|
||||
|
||||
//write sign flag
|
||||
func (s *Conn) WriteSign() (int, error) {
|
||||
return s.Write([]byte(RES_SIGN))
|
||||
return s.Write([]byte(common.RES_SIGN))
|
||||
}
|
||||
|
||||
//write sign flag
|
||||
func (s *Conn) WriteClose() (int, error) {
|
||||
return s.Write([]byte(RES_CLOSE))
|
||||
return s.Write([]byte(common.RES_CLOSE))
|
||||
}
|
||||
|
||||
//write main
|
||||
func (s *Conn) WriteMain() (int, error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.Write([]byte(WORK_MAIN))
|
||||
return s.Write([]byte(common.WORK_MAIN))
|
||||
}
|
||||
|
||||
//write chan
|
||||
func (s *Conn) WriteChan() (int, error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.Write([]byte(WORK_CHAN))
|
||||
return s.Write([]byte(common.WORK_CHAN))
|
||||
}
|
||||
|
||||
//获取长度+内容
|
||||
|
@ -436,3 +397,13 @@ func GetLenByBytes(buf []byte) (int, error) {
|
|||
}
|
||||
return int(nlen), nil
|
||||
}
|
||||
|
||||
func SetUdpSession(sess *kcp.UDPSession) {
|
||||
sess.SetStreamMode(true)
|
||||
sess.SetWindowSize(1024, 1024)
|
||||
sess.SetReadBuffer(64 * 1024)
|
||||
sess.SetWriteBuffer(64 * 1024)
|
||||
sess.SetNoDelay(1, 10, 2, 1)
|
||||
sess.SetMtu(1600)
|
||||
sess.SetACKNoDelay(true)
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package conn
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Link struct {
|
||||
Id int //id
|
||||
ConnType string //连接类型
|
||||
Host string //目标
|
||||
En int //加密
|
||||
De int //解密
|
||||
Crypt bool //加密
|
||||
Conn *Conn
|
||||
Flow *file.Flow
|
||||
UdpListener *net.UDPConn
|
||||
Rate *rate.Rate
|
||||
UdpRemoteAddr *net.UDPAddr
|
||||
}
|
||||
|
||||
func NewLink(id int, connType string, host string, en, de int, crypt bool, c *Conn, flow *file.Flow, udpListener *net.UDPConn, rate *rate.Rate, UdpRemoteAddr *net.UDPAddr) *Link {
|
||||
return &Link{
|
||||
Id: id,
|
||||
ConnType: connType,
|
||||
Host: host,
|
||||
En: en,
|
||||
De: de,
|
||||
Crypt: crypt,
|
||||
Conn: c,
|
||||
Flow: flow,
|
||||
UdpListener: udpListener,
|
||||
Rate: rate,
|
||||
UdpRemoteAddr: UdpRemoteAddr,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package conn
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib/crypt"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"net"
|
||||
)
|
||||
|
||||
type CryptConn struct {
|
||||
conn net.Conn
|
||||
crypt bool
|
||||
rate *rate.Rate
|
||||
}
|
||||
|
||||
func NewCryptConn(conn net.Conn, crypt bool, rate *rate.Rate) *CryptConn {
|
||||
c := new(CryptConn)
|
||||
c.conn = conn
|
||||
c.crypt = crypt
|
||||
c.rate = rate
|
||||
return c
|
||||
}
|
||||
|
||||
//加密写
|
||||
func (s *CryptConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if s.crypt {
|
||||
if b, err = crypt.AesEncrypt(b, []byte(cryptKey)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if b, err = GetLenBytes(b); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = s.conn.Write(b)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//解密读
|
||||
func (s *CryptConn) Read(b []byte) (n int, err error) {
|
||||
var lens int
|
||||
var buf []byte
|
||||
var rb []byte
|
||||
c := NewConn(s.conn)
|
||||
if lens, err = c.GetLen(); err != nil {
|
||||
return
|
||||
}
|
||||
if buf, err = c.ReadLen(lens); err != nil {
|
||||
return
|
||||
}
|
||||
if s.crypt {
|
||||
if rb, err = crypt.AesDecrypt(buf, []byte(cryptKey)); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
rb = buf
|
||||
}
|
||||
copy(b, rb)
|
||||
n = len(rb)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package conn
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib/crypt"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/lib/pool"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"github.com/cnlh/nps/lib/snappy"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
type SnappyConn struct {
|
||||
w *snappy.Writer
|
||||
r *snappy.Reader
|
||||
crypt bool
|
||||
rate *rate.Rate
|
||||
}
|
||||
|
||||
func NewSnappyConn(conn net.Conn, crypt bool, rate *rate.Rate) *SnappyConn {
|
||||
c := new(SnappyConn)
|
||||
c.w = snappy.NewBufferedWriter(conn)
|
||||
c.r = snappy.NewReader(conn)
|
||||
c.crypt = crypt
|
||||
c.rate = rate
|
||||
return c
|
||||
}
|
||||
|
||||
//snappy压缩写 包含加密
|
||||
func (s *SnappyConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if s.crypt {
|
||||
if b, err = crypt.AesEncrypt(b, []byte(cryptKey)); err != nil {
|
||||
lg.Println("encode crypt error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if _, err = s.w.Write(b); err != nil {
|
||||
return
|
||||
}
|
||||
if err = s.w.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//snappy压缩读 包含解密
|
||||
func (s *SnappyConn) Read(b []byte) (n int, err error) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf)
|
||||
if n, err = s.r.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
var bs []byte
|
||||
if s.crypt {
|
||||
if bs, err = crypt.AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
|
||||
log.Println("decode crypt error:", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
bs = buf[:n]
|
||||
}
|
||||
n = len(bs)
|
||||
copy(b, bs)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package lib
|
||||
package crypt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -37,21 +37,19 @@ func AesDecrypt(crypted, key []byte) ([]byte, error) {
|
|||
blockSize := block.BlockSize()
|
||||
blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
|
||||
origData := make([]byte, len(crypted))
|
||||
// origData := crypted
|
||||
blockMode.CryptBlocks(origData, crypted)
|
||||
err, origData = PKCS5UnPadding(origData)
|
||||
// origData = ZeroUnPadding(origData)
|
||||
return origData, err
|
||||
}
|
||||
|
||||
//补全
|
||||
//Completion when the length is insufficient
|
||||
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(ciphertext)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padtext...)
|
||||
}
|
||||
|
||||
//去补
|
||||
//Remove excess
|
||||
func PKCS5UnPadding(origData []byte) (error, []byte) {
|
||||
length := len(origData)
|
||||
// 去掉最后一个字节 unpadding 次
|
||||
|
@ -62,14 +60,14 @@ func PKCS5UnPadding(origData []byte) (error, []byte) {
|
|||
return nil, origData[:(length - unpadding)]
|
||||
}
|
||||
|
||||
//生成32位md5字串
|
||||
//Generate 32-bit MD5 strings
|
||||
func Md5(s string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(s))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
//生成随机验证密钥
|
||||
//Generating Random Verification Key
|
||||
func GetRandomString(l int) string {
|
||||
str := "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
bytes := []byte(str)
|
|
@ -1,6 +1,7 @@
|
|||
package lib
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
@ -10,7 +11,7 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
func InitDaemon(f string) {
|
||||
func InitDaemon(f string, runPath string, pidPath string) {
|
||||
if len(os.Args) < 2 {
|
||||
return
|
||||
}
|
||||
|
@ -22,22 +23,17 @@ func InitDaemon(f string) {
|
|||
args = append(args, "-log=file")
|
||||
switch os.Args[1] {
|
||||
case "start":
|
||||
start(args, f)
|
||||
start(args, f, pidPath, runPath)
|
||||
os.Exit(0)
|
||||
case "stop":
|
||||
stop(f, args[0])
|
||||
stop(f, args[0], pidPath)
|
||||
os.Exit(0)
|
||||
case "restart":
|
||||
stop(f, args[0])
|
||||
start(args, f)
|
||||
os.Exit(0)
|
||||
case "install":
|
||||
if f == "nps" {
|
||||
InstallNps()
|
||||
}
|
||||
stop(f, args[0], pidPath)
|
||||
start(args, f, pidPath, runPath)
|
||||
os.Exit(0)
|
||||
case "status":
|
||||
if status(f) {
|
||||
if status(f, pidPath) {
|
||||
log.Printf("%s is running", f)
|
||||
} else {
|
||||
log.Printf("%s is not running", f)
|
||||
|
@ -46,11 +42,11 @@ func InitDaemon(f string) {
|
|||
}
|
||||
}
|
||||
|
||||
func status(f string) bool {
|
||||
func status(f string, pidPath string) bool {
|
||||
var cmd *exec.Cmd
|
||||
b, err := ioutil.ReadFile(filepath.Join(GetPidPath(), f+".pid"))
|
||||
b, err := ioutil.ReadFile(filepath.Join(pidPath, f+".pid"))
|
||||
if err == nil {
|
||||
if !IsWindows() {
|
||||
if !common.IsWindows() {
|
||||
cmd = exec.Command("/bin/sh", "-c", "ps -ax | awk '{ print $1 }' | grep "+string(b))
|
||||
} else {
|
||||
cmd = exec.Command("tasklist", )
|
||||
|
@ -63,38 +59,38 @@ func status(f string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func start(osArgs []string, f string) {
|
||||
if status(f) {
|
||||
func start(osArgs []string, f string, pidPath, runPath string) {
|
||||
if status(f, pidPath) {
|
||||
log.Printf(" %s is running", f)
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(osArgs[0], osArgs[1:]...)
|
||||
cmd.Start()
|
||||
if cmd.Process.Pid > 0 {
|
||||
log.Println("start ok , pid:", cmd.Process.Pid, "config path:", GetRunPath())
|
||||
log.Println("start ok , pid:", cmd.Process.Pid, "config path:", runPath)
|
||||
d1 := []byte(strconv.Itoa(cmd.Process.Pid))
|
||||
ioutil.WriteFile(filepath.Join(GetPidPath(), f+".pid"), d1, 0600)
|
||||
ioutil.WriteFile(filepath.Join(pidPath, f+".pid"), d1, 0600)
|
||||
} else {
|
||||
log.Println("start error")
|
||||
}
|
||||
}
|
||||
|
||||
func stop(f string, p string) {
|
||||
if !status(f) {
|
||||
func stop(f string, p string, pidPath string) {
|
||||
if !status(f, pidPath) {
|
||||
log.Printf(" %s is not running", f)
|
||||
return
|
||||
}
|
||||
var c *exec.Cmd
|
||||
var err error
|
||||
if IsWindows() {
|
||||
if common.IsWindows() {
|
||||
p := strings.Split(p, `\`)
|
||||
c = exec.Command("taskkill", "/F", "/IM", p[len(p)-1])
|
||||
} else {
|
||||
b, err := ioutil.ReadFile(filepath.Join(GetPidPath(), f+".pid"))
|
||||
b, err := ioutil.ReadFile(filepath.Join(pidPath, f+".pid"))
|
||||
if err == nil {
|
||||
c = exec.Command("/bin/bash", "-c", `kill -9 `+string(b))
|
||||
} else {
|
||||
log.Fatalln("stop error,PID file does not exist")
|
||||
log.Fatalln("stop error,pid file does not exist")
|
||||
}
|
||||
}
|
||||
err = c.Run()
|
|
@ -0,0 +1,19 @@
|
|||
package file
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
CsvDb *Csv
|
||||
once sync.Once
|
||||
)
|
||||
//init csv from file
|
||||
func GetCsvDb() *Csv {
|
||||
once.Do(func() {
|
||||
CsvDb = NewCsv(common.GetRunPath())
|
||||
CsvDb.Init()
|
||||
})
|
||||
return CsvDb
|
||||
}
|
|
@ -1,8 +1,11 @@
|
|||
package lib
|
||||
package file
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
@ -10,13 +13,10 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
CsvDb *Csv
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func NewCsv() *Csv {
|
||||
return new(Csv)
|
||||
func NewCsv(runPath string) *Csv {
|
||||
return &Csv{
|
||||
RunPath: runPath,
|
||||
}
|
||||
}
|
||||
|
||||
type Csv struct {
|
||||
|
@ -24,6 +24,7 @@ type Csv struct {
|
|||
Path string
|
||||
Hosts []*Host //域名列表
|
||||
Clients []*Client //客户端
|
||||
RunPath string //存储根目录
|
||||
ClientIncreaseId int //客户端id
|
||||
TaskIncreaseId int //任务自增ID
|
||||
sync.Mutex
|
||||
|
@ -37,9 +38,9 @@ func (s *Csv) Init() {
|
|||
|
||||
func (s *Csv) StoreTasksToCsv() {
|
||||
// 创建文件
|
||||
csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "tasks.csv"))
|
||||
csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
|
||||
if err != nil {
|
||||
Fatalf(err.Error())
|
||||
lg.Fatalf(err.Error())
|
||||
}
|
||||
defer csvFile.Close()
|
||||
writer := csv.NewWriter(csvFile)
|
||||
|
@ -51,8 +52,8 @@ func (s *Csv) StoreTasksToCsv() {
|
|||
task.Config.U,
|
||||
task.Config.P,
|
||||
task.Config.Compress,
|
||||
GetStrByBool(task.Status),
|
||||
GetStrByBool(task.Config.Crypt),
|
||||
common.GetStrByBool(task.Status),
|
||||
common.GetStrByBool(task.Config.Crypt),
|
||||
strconv.Itoa(task.Config.CompressEncode),
|
||||
strconv.Itoa(task.Config.CompressDecode),
|
||||
strconv.Itoa(task.Id),
|
||||
|
@ -62,7 +63,7 @@ func (s *Csv) StoreTasksToCsv() {
|
|||
}
|
||||
err := writer.Write(record)
|
||||
if err != nil {
|
||||
Fatalf(err.Error())
|
||||
lg.Fatalf(err.Error())
|
||||
}
|
||||
}
|
||||
writer.Flush()
|
||||
|
@ -87,33 +88,33 @@ func (s *Csv) openFile(path string) ([][]string, error) {
|
|||
}
|
||||
|
||||
func (s *Csv) LoadTaskFromCsv() {
|
||||
path := filepath.Join(GetRunPath(), "conf", "tasks.csv")
|
||||
path := filepath.Join(s.RunPath, "conf", "tasks.csv")
|
||||
records, err := s.openFile(path)
|
||||
if err != nil {
|
||||
Fatalln("配置文件打开错误:", path)
|
||||
lg.Fatalln("配置文件打开错误:", path)
|
||||
}
|
||||
var tasks []*Tunnel
|
||||
// 将每一行数据保存到内存slice中
|
||||
for _, item := range records {
|
||||
post := &Tunnel{
|
||||
TcpPort: GetIntNoErrByStr(item[0]),
|
||||
TcpPort: common.GetIntNoErrByStr(item[0]),
|
||||
Mode: item[1],
|
||||
Target: item[2],
|
||||
Config: &Config{
|
||||
U: item[3],
|
||||
P: item[4],
|
||||
Compress: item[5],
|
||||
Crypt: GetBoolByStr(item[7]),
|
||||
CompressEncode: GetIntNoErrByStr(item[8]),
|
||||
CompressDecode: GetIntNoErrByStr(item[9]),
|
||||
Crypt: common.GetBoolByStr(item[7]),
|
||||
CompressEncode: common.GetIntNoErrByStr(item[8]),
|
||||
CompressDecode: common.GetIntNoErrByStr(item[9]),
|
||||
},
|
||||
Status: GetBoolByStr(item[6]),
|
||||
Id: GetIntNoErrByStr(item[10]),
|
||||
UseClientCnf: GetBoolByStr(item[12]),
|
||||
Status: common.GetBoolByStr(item[6]),
|
||||
Id: common.GetIntNoErrByStr(item[10]),
|
||||
UseClientCnf: common.GetBoolByStr(item[12]),
|
||||
Remark: item[13],
|
||||
}
|
||||
post.Flow = new(Flow)
|
||||
if post.Client, err = s.GetClient(GetIntNoErrByStr(item[11])); err != nil {
|
||||
if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[11])); err != nil {
|
||||
continue
|
||||
}
|
||||
tasks = append(tasks, post)
|
||||
|
@ -135,7 +136,7 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
|
|||
s.Lock()
|
||||
defer s.Unlock()
|
||||
for _, v := range s.Clients {
|
||||
if Getverifyval(v.VerifyKey) == vKey && v.Status {
|
||||
if common.Getverifyval(v.VerifyKey) == vKey && v.Status {
|
||||
if arr := strings.Split(addr, ":"); len(arr) > 0 {
|
||||
v.Addr = arr[0]
|
||||
}
|
||||
|
@ -186,7 +187,7 @@ func (s *Csv) GetTask(id int) (v *Tunnel, err error) {
|
|||
|
||||
func (s *Csv) StoreHostToCsv() {
|
||||
// 创建文件
|
||||
csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "hosts.csv"))
|
||||
csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -214,24 +215,24 @@ func (s *Csv) StoreHostToCsv() {
|
|||
}
|
||||
|
||||
func (s *Csv) LoadClientFromCsv() {
|
||||
path := filepath.Join(GetRunPath(), "conf", "clients.csv")
|
||||
path := filepath.Join(s.RunPath, "conf", "clients.csv")
|
||||
records, err := s.openFile(path)
|
||||
if err != nil {
|
||||
Fatalln("配置文件打开错误:", path)
|
||||
lg.Fatalln("配置文件打开错误:", path)
|
||||
}
|
||||
var clients []*Client
|
||||
// 将每一行数据保存到内存slice中
|
||||
for _, item := range records {
|
||||
post := &Client{
|
||||
Id: GetIntNoErrByStr(item[0]),
|
||||
Id: common.GetIntNoErrByStr(item[0]),
|
||||
VerifyKey: item[1],
|
||||
Remark: item[2],
|
||||
Status: GetBoolByStr(item[3]),
|
||||
RateLimit: GetIntNoErrByStr(item[8]),
|
||||
Status: common.GetBoolByStr(item[3]),
|
||||
RateLimit: common.GetIntNoErrByStr(item[8]),
|
||||
Cnf: &Config{
|
||||
U: item[4],
|
||||
P: item[5],
|
||||
Crypt: GetBoolByStr(item[6]),
|
||||
Crypt: common.GetBoolByStr(item[6]),
|
||||
Compress: item[7],
|
||||
},
|
||||
}
|
||||
|
@ -239,21 +240,21 @@ func (s *Csv) LoadClientFromCsv() {
|
|||
s.ClientIncreaseId = post.Id
|
||||
}
|
||||
if post.RateLimit > 0 {
|
||||
post.Rate = NewRate(int64(post.RateLimit * 1024))
|
||||
post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
|
||||
post.Rate.Start()
|
||||
}
|
||||
post.Flow = new(Flow)
|
||||
post.Flow.FlowLimit = int64(GetIntNoErrByStr(item[9]))
|
||||
post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
|
||||
clients = append(clients, post)
|
||||
}
|
||||
s.Clients = clients
|
||||
}
|
||||
|
||||
func (s *Csv) LoadHostFromCsv() {
|
||||
path := filepath.Join(GetRunPath(), "conf", "hosts.csv")
|
||||
path := filepath.Join(s.RunPath, "conf", "hosts.csv")
|
||||
records, err := s.openFile(path)
|
||||
if err != nil {
|
||||
Fatalln("配置文件打开错误:", path)
|
||||
lg.Fatalln("配置文件打开错误:", path)
|
||||
}
|
||||
var hosts []*Host
|
||||
// 将每一行数据保存到内存slice中
|
||||
|
@ -265,7 +266,7 @@ func (s *Csv) LoadHostFromCsv() {
|
|||
HostChange: item[4],
|
||||
Remark: item[5],
|
||||
}
|
||||
if post.Client, err = s.GetClient(GetIntNoErrByStr(item[2])); err != nil {
|
||||
if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
|
||||
continue
|
||||
}
|
||||
post.Flow = new(Flow)
|
||||
|
@ -387,11 +388,12 @@ func (s *Csv) GetClient(id int) (v *Client, err error) {
|
|||
err = errors.New("未找到")
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Csv) StoreClientsToCsv() {
|
||||
// 创建文件
|
||||
csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "clients.csv"))
|
||||
csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
|
||||
if err != nil {
|
||||
Fatalln(err.Error())
|
||||
lg.Fatalln(err.Error())
|
||||
}
|
||||
defer csvFile.Close()
|
||||
writer := csv.NewWriter(csvFile)
|
||||
|
@ -403,24 +405,15 @@ func (s *Csv) StoreClientsToCsv() {
|
|||
strconv.FormatBool(client.Status),
|
||||
client.Cnf.U,
|
||||
client.Cnf.P,
|
||||
GetStrByBool(client.Cnf.Crypt),
|
||||
common.GetStrByBool(client.Cnf.Crypt),
|
||||
client.Cnf.Compress,
|
||||
strconv.Itoa(client.RateLimit),
|
||||
strconv.Itoa(int(client.Flow.FlowLimit)),
|
||||
}
|
||||
err := writer.Write(record)
|
||||
if err != nil {
|
||||
Fatalln(err.Error())
|
||||
lg.Fatalln(err.Error())
|
||||
}
|
||||
}
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
//init csv from file
|
||||
func GetCsvDb() *Csv {
|
||||
once.Do(func() {
|
||||
CsvDb = NewCsv()
|
||||
CsvDb.Init()
|
||||
})
|
||||
return CsvDb
|
||||
}
|
|
@ -1,41 +1,11 @@
|
|||
package lib
|
||||
package file
|
||||
|
||||
import (
|
||||
"net"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Link struct {
|
||||
Id int //id
|
||||
ConnType string //连接类型
|
||||
Host string //目标
|
||||
En int //加密
|
||||
De int //解密
|
||||
Crypt bool //加密
|
||||
Conn *Conn
|
||||
Flow *Flow
|
||||
UdpListener *net.UDPConn
|
||||
Rate *Rate
|
||||
UdpRemoteAddr *net.UDPAddr
|
||||
}
|
||||
|
||||
func NewLink(id int, connType string, host string, en, de int, crypt bool, conn *Conn, flow *Flow, udpListener *net.UDPConn, rate *Rate, UdpRemoteAddr *net.UDPAddr) *Link {
|
||||
return &Link{
|
||||
Id: id,
|
||||
ConnType: connType,
|
||||
Host: host,
|
||||
En: en,
|
||||
De: de,
|
||||
Crypt: crypt,
|
||||
Conn: conn,
|
||||
Flow: flow,
|
||||
UdpListener: udpListener,
|
||||
Rate: rate,
|
||||
UdpRemoteAddr: UdpRemoteAddr,
|
||||
}
|
||||
}
|
||||
|
||||
type Flow struct {
|
||||
ExportFlow int64 //出口流量
|
||||
InletFlow int64 //入口流量
|
||||
|
@ -52,15 +22,15 @@ func (s *Flow) Add(in, out int) {
|
|||
|
||||
type Client struct {
|
||||
Cnf *Config
|
||||
Id int //id
|
||||
VerifyKey string //验证密钥
|
||||
Addr string //客户端ip地址
|
||||
Remark string //备注
|
||||
Status bool //是否开启
|
||||
IsConnect bool //是否连接
|
||||
RateLimit int //速度限制 /kb
|
||||
Flow *Flow //流量
|
||||
Rate *Rate //速度控制
|
||||
Id int //id
|
||||
VerifyKey string //验证密钥
|
||||
Addr string //客户端ip地址
|
||||
Remark string //备注
|
||||
Status bool //是否开启
|
||||
IsConnect bool //是否连接
|
||||
RateLimit int //速度限制 /kb
|
||||
Flow *Flow //流量
|
||||
Rate *rate.Rate //速度控制
|
||||
id int
|
||||
sync.RWMutex
|
||||
}
|
||||
|
@ -74,7 +44,7 @@ func (s *Client) GetId() int {
|
|||
|
||||
type Tunnel struct {
|
||||
Id int //Id
|
||||
TcpPort int //服务端与客户端通信端口
|
||||
TcpPort int //服务端监听端口
|
||||
Mode string //启动方式
|
||||
Target string //目标
|
||||
Status bool //是否开启
|
|
@ -1,8 +1,9 @@
|
|||
package lib
|
||||
package install
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
@ -11,22 +12,22 @@ import (
|
|||
)
|
||||
|
||||
func InstallNps() {
|
||||
path := GetInstallPath()
|
||||
path := common.GetInstallPath()
|
||||
MkidrDirAll(path, "conf", "web/static", "web/views")
|
||||
//复制文件到对应目录
|
||||
if err := CopyDir(filepath.Join(GetAppPath(), "web", "views"), filepath.Join(path, "web", "views")); err != nil {
|
||||
if err := CopyDir(filepath.Join(common.GetAppPath(), "web", "views"), filepath.Join(path, "web", "views")); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
if err := CopyDir(filepath.Join(GetAppPath(), "web", "static"), filepath.Join(path, "web", "static")); err != nil {
|
||||
if err := CopyDir(filepath.Join(common.GetAppPath(), "web", "static"), filepath.Join(path, "web", "static")); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
if err := CopyDir(filepath.Join(GetAppPath(), "conf"), filepath.Join(path, "conf")); err != nil {
|
||||
if err := CopyDir(filepath.Join(common.GetAppPath(), "conf"), filepath.Join(path, "conf")); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
if !IsWindows() {
|
||||
if _, err := copyFile(filepath.Join(GetAppPath(), "nps"), "/usr/bin/nps"); err != nil {
|
||||
if _, err := copyFile(filepath.Join(GetAppPath(), "nps"), "/usr/local/bin/nps"); err != nil {
|
||||
if !common.IsWindows() {
|
||||
if _, err := copyFile(filepath.Join(common.GetAppPath(), "nps"), "/usr/bin/nps"); err != nil {
|
||||
if _, err := copyFile(filepath.Join(common.GetAppPath(), "nps"), "/usr/local/bin/nps"); err != nil {
|
||||
log.Fatalln(err)
|
||||
} else {
|
||||
os.Chmod("/usr/local/bin/nps", 0777)
|
||||
|
@ -41,7 +42,7 @@ func InstallNps() {
|
|||
log.Println("install ok!")
|
||||
log.Println("Static files and configuration files in the current directory will be useless")
|
||||
log.Println("The new configuration file is located in", path, "you can edit them")
|
||||
if !IsWindows() {
|
||||
if !common.IsWindows() {
|
||||
log.Println("You can start with nps test|start|stop|restart|status anywhere")
|
||||
} else {
|
||||
log.Println("You can copy executable files to any directory and start working with nps.exe test|start|stop|restart|status")
|
|
@ -0,0 +1,785 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/sha1"
|
||||
|
||||
"github.com/templexxx/xor"
|
||||
"github.com/tjfoc/gmsm/sm4"
|
||||
|
||||
"golang.org/x/crypto/blowfish"
|
||||
"golang.org/x/crypto/cast5"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/salsa20"
|
||||
"golang.org/x/crypto/tea"
|
||||
"golang.org/x/crypto/twofish"
|
||||
"golang.org/x/crypto/xtea"
|
||||
)
|
||||
|
||||
var (
|
||||
initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107}
|
||||
saltxor = `sH3CIVoF#rWLtJo6`
|
||||
)
|
||||
|
||||
// BlockCrypt defines encryption/decryption methods for a given byte slice.
|
||||
// Notes on implementing: the data to be encrypted contains a builtin
|
||||
// nonce at the first 16 bytes
|
||||
type BlockCrypt interface {
|
||||
// Encrypt encrypts the whole block in src into dst.
|
||||
// Dst and src may point at the same memory.
|
||||
Encrypt(dst, src []byte)
|
||||
|
||||
// Decrypt decrypts the whole block in src into dst.
|
||||
// Dst and src may point at the same memory.
|
||||
Decrypt(dst, src []byte)
|
||||
}
|
||||
|
||||
type salsa20BlockCrypt struct {
|
||||
key [32]byte
|
||||
}
|
||||
|
||||
// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20
|
||||
func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(salsa20BlockCrypt)
|
||||
copy(c.key[:], key)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) {
|
||||
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
|
||||
copy(dst[:8], src[:8])
|
||||
}
|
||||
func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) {
|
||||
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
|
||||
copy(dst[:8], src[:8])
|
||||
}
|
||||
|
||||
type sm4BlockCrypt struct {
|
||||
encbuf [sm4.BlockSize]byte
|
||||
decbuf [2 * sm4.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4
|
||||
func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(sm4BlockCrypt)
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type twofishBlockCrypt struct {
|
||||
encbuf [twofish.BlockSize]byte
|
||||
decbuf [2 * twofish.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish
|
||||
func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(twofishBlockCrypt)
|
||||
block, err := twofish.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type tripleDESBlockCrypt struct {
|
||||
encbuf [des.BlockSize]byte
|
||||
decbuf [2 * des.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES
|
||||
func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(tripleDESBlockCrypt)
|
||||
block, err := des.NewTripleDESCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type cast5BlockCrypt struct {
|
||||
encbuf [cast5.BlockSize]byte
|
||||
decbuf [2 * cast5.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128
|
||||
func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(cast5BlockCrypt)
|
||||
block, err := cast5.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type blowfishBlockCrypt struct {
|
||||
encbuf [blowfish.BlockSize]byte
|
||||
decbuf [2 * blowfish.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher)
|
||||
func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(blowfishBlockCrypt)
|
||||
block, err := blowfish.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type aesBlockCrypt struct {
|
||||
encbuf [aes.BlockSize]byte
|
||||
decbuf [2 * aes.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
|
||||
func NewAESBlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(aesBlockCrypt)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type teaBlockCrypt struct {
|
||||
encbuf [tea.BlockSize]byte
|
||||
decbuf [2 * tea.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm
|
||||
func NewTEABlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(teaBlockCrypt)
|
||||
block, err := tea.NewCipherWithRounds(key, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type xteaBlockCrypt struct {
|
||||
encbuf [xtea.BlockSize]byte
|
||||
decbuf [2 * xtea.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA
|
||||
func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(xteaBlockCrypt)
|
||||
block, err := xtea.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.block = block
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
|
||||
func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
|
||||
|
||||
type simpleXORBlockCrypt struct {
|
||||
xortbl []byte
|
||||
}
|
||||
|
||||
// NewSimpleXORBlockCrypt simple xor with key expanding
|
||||
func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
c := new(simpleXORBlockCrypt)
|
||||
c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
|
||||
func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
|
||||
|
||||
type noneBlockCrypt struct{}
|
||||
|
||||
// NewNoneBlockCrypt does nothing but copying
|
||||
func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) {
|
||||
return new(noneBlockCrypt), nil
|
||||
}
|
||||
|
||||
func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) }
|
||||
func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) }
|
||||
|
||||
// packet encryption with local CFB mode
|
||||
func encrypt(block cipher.Block, dst, src, buf []byte) {
|
||||
switch block.BlockSize() {
|
||||
case 8:
|
||||
encrypt8(block, dst, src, buf)
|
||||
case 16:
|
||||
encrypt16(block, dst, src, buf)
|
||||
default:
|
||||
encryptVariant(block, dst, src, buf)
|
||||
}
|
||||
}
|
||||
|
||||
// optimized encryption for the ciphers which works in 8-bytes
|
||||
func encrypt8(block cipher.Block, dst, src, buf []byte) {
|
||||
tbl := buf[:8]
|
||||
block.Encrypt(tbl, initialVector)
|
||||
n := len(src) / 8
|
||||
base := 0
|
||||
repeat := n / 8
|
||||
left := n % 8
|
||||
for i := 0; i < repeat; i++ {
|
||||
s := src[base:][0:64]
|
||||
d := dst[base:][0:64]
|
||||
// 1
|
||||
xor.BytesSrc1(d[0:8], s[0:8], tbl)
|
||||
block.Encrypt(tbl, d[0:8])
|
||||
// 2
|
||||
xor.BytesSrc1(d[8:16], s[8:16], tbl)
|
||||
block.Encrypt(tbl, d[8:16])
|
||||
// 3
|
||||
xor.BytesSrc1(d[16:24], s[16:24], tbl)
|
||||
block.Encrypt(tbl, d[16:24])
|
||||
// 4
|
||||
xor.BytesSrc1(d[24:32], s[24:32], tbl)
|
||||
block.Encrypt(tbl, d[24:32])
|
||||
// 5
|
||||
xor.BytesSrc1(d[32:40], s[32:40], tbl)
|
||||
block.Encrypt(tbl, d[32:40])
|
||||
// 6
|
||||
xor.BytesSrc1(d[40:48], s[40:48], tbl)
|
||||
block.Encrypt(tbl, d[40:48])
|
||||
// 7
|
||||
xor.BytesSrc1(d[48:56], s[48:56], tbl)
|
||||
block.Encrypt(tbl, d[48:56])
|
||||
// 8
|
||||
xor.BytesSrc1(d[56:64], s[56:64], tbl)
|
||||
block.Encrypt(tbl, d[56:64])
|
||||
base += 64
|
||||
}
|
||||
|
||||
switch left {
|
||||
case 7:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 6:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 5:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 4:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 3:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 2:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 1:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 8
|
||||
fallthrough
|
||||
case 0:
|
||||
xor.BytesSrc0(dst[base:], src[base:], tbl)
|
||||
}
|
||||
}
|
||||
|
||||
// optimized encryption for the ciphers which works in 16-bytes
|
||||
func encrypt16(block cipher.Block, dst, src, buf []byte) {
|
||||
tbl := buf[:16]
|
||||
block.Encrypt(tbl, initialVector)
|
||||
n := len(src) / 16
|
||||
base := 0
|
||||
repeat := n / 8
|
||||
left := n % 8
|
||||
for i := 0; i < repeat; i++ {
|
||||
s := src[base:][0:128]
|
||||
d := dst[base:][0:128]
|
||||
// 1
|
||||
xor.BytesSrc1(d[0:16], s[0:16], tbl)
|
||||
block.Encrypt(tbl, d[0:16])
|
||||
// 2
|
||||
xor.BytesSrc1(d[16:32], s[16:32], tbl)
|
||||
block.Encrypt(tbl, d[16:32])
|
||||
// 3
|
||||
xor.BytesSrc1(d[32:48], s[32:48], tbl)
|
||||
block.Encrypt(tbl, d[32:48])
|
||||
// 4
|
||||
xor.BytesSrc1(d[48:64], s[48:64], tbl)
|
||||
block.Encrypt(tbl, d[48:64])
|
||||
// 5
|
||||
xor.BytesSrc1(d[64:80], s[64:80], tbl)
|
||||
block.Encrypt(tbl, d[64:80])
|
||||
// 6
|
||||
xor.BytesSrc1(d[80:96], s[80:96], tbl)
|
||||
block.Encrypt(tbl, d[80:96])
|
||||
// 7
|
||||
xor.BytesSrc1(d[96:112], s[96:112], tbl)
|
||||
block.Encrypt(tbl, d[96:112])
|
||||
// 8
|
||||
xor.BytesSrc1(d[112:128], s[112:128], tbl)
|
||||
block.Encrypt(tbl, d[112:128])
|
||||
base += 128
|
||||
}
|
||||
|
||||
switch left {
|
||||
case 7:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 6:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 5:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 4:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 3:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 2:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 1:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += 16
|
||||
fallthrough
|
||||
case 0:
|
||||
xor.BytesSrc0(dst[base:], src[base:], tbl)
|
||||
}
|
||||
}
|
||||
|
||||
func encryptVariant(block cipher.Block, dst, src, buf []byte) {
|
||||
blocksize := block.BlockSize()
|
||||
tbl := buf[:blocksize]
|
||||
block.Encrypt(tbl, initialVector)
|
||||
n := len(src) / blocksize
|
||||
base := 0
|
||||
repeat := n / 8
|
||||
left := n % 8
|
||||
for i := 0; i < repeat; i++ {
|
||||
// 1
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 2
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 3
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 4
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 5
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 6
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 7
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
|
||||
// 8
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
}
|
||||
|
||||
switch left {
|
||||
case 7:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 6:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 5:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 4:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 3:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 2:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 1:
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
block.Encrypt(tbl, dst[base:])
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 0:
|
||||
xor.BytesSrc0(dst[base:], src[base:], tbl)
|
||||
}
|
||||
}
|
||||
|
||||
// decryption
|
||||
func decrypt(block cipher.Block, dst, src, buf []byte) {
|
||||
switch block.BlockSize() {
|
||||
case 8:
|
||||
decrypt8(block, dst, src, buf)
|
||||
case 16:
|
||||
decrypt16(block, dst, src, buf)
|
||||
default:
|
||||
decryptVariant(block, dst, src, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func decrypt8(block cipher.Block, dst, src, buf []byte) {
|
||||
tbl := buf[0:8]
|
||||
next := buf[8:16]
|
||||
block.Encrypt(tbl, initialVector)
|
||||
n := len(src) / 8
|
||||
base := 0
|
||||
repeat := n / 8
|
||||
left := n % 8
|
||||
for i := 0; i < repeat; i++ {
|
||||
s := src[base:][0:64]
|
||||
d := dst[base:][0:64]
|
||||
// 1
|
||||
block.Encrypt(next, s[0:8])
|
||||
xor.BytesSrc1(d[0:8], s[0:8], tbl)
|
||||
// 2
|
||||
block.Encrypt(tbl, s[8:16])
|
||||
xor.BytesSrc1(d[8:16], s[8:16], next)
|
||||
// 3
|
||||
block.Encrypt(next, s[16:24])
|
||||
xor.BytesSrc1(d[16:24], s[16:24], tbl)
|
||||
// 4
|
||||
block.Encrypt(tbl, s[24:32])
|
||||
xor.BytesSrc1(d[24:32], s[24:32], next)
|
||||
// 5
|
||||
block.Encrypt(next, s[32:40])
|
||||
xor.BytesSrc1(d[32:40], s[32:40], tbl)
|
||||
// 6
|
||||
block.Encrypt(tbl, s[40:48])
|
||||
xor.BytesSrc1(d[40:48], s[40:48], next)
|
||||
// 7
|
||||
block.Encrypt(next, s[48:56])
|
||||
xor.BytesSrc1(d[48:56], s[48:56], tbl)
|
||||
// 8
|
||||
block.Encrypt(tbl, s[56:64])
|
||||
xor.BytesSrc1(d[56:64], s[56:64], next)
|
||||
base += 64
|
||||
}
|
||||
|
||||
switch left {
|
||||
case 7:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 6:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 5:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 4:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 3:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 2:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 1:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 8
|
||||
fallthrough
|
||||
case 0:
|
||||
xor.BytesSrc0(dst[base:], src[base:], tbl)
|
||||
}
|
||||
}
|
||||
|
||||
func decrypt16(block cipher.Block, dst, src, buf []byte) {
|
||||
tbl := buf[0:16]
|
||||
next := buf[16:32]
|
||||
block.Encrypt(tbl, initialVector)
|
||||
n := len(src) / 16
|
||||
base := 0
|
||||
repeat := n / 8
|
||||
left := n % 8
|
||||
for i := 0; i < repeat; i++ {
|
||||
s := src[base:][0:128]
|
||||
d := dst[base:][0:128]
|
||||
// 1
|
||||
block.Encrypt(next, s[0:16])
|
||||
xor.BytesSrc1(d[0:16], s[0:16], tbl)
|
||||
// 2
|
||||
block.Encrypt(tbl, s[16:32])
|
||||
xor.BytesSrc1(d[16:32], s[16:32], next)
|
||||
// 3
|
||||
block.Encrypt(next, s[32:48])
|
||||
xor.BytesSrc1(d[32:48], s[32:48], tbl)
|
||||
// 4
|
||||
block.Encrypt(tbl, s[48:64])
|
||||
xor.BytesSrc1(d[48:64], s[48:64], next)
|
||||
// 5
|
||||
block.Encrypt(next, s[64:80])
|
||||
xor.BytesSrc1(d[64:80], s[64:80], tbl)
|
||||
// 6
|
||||
block.Encrypt(tbl, s[80:96])
|
||||
xor.BytesSrc1(d[80:96], s[80:96], next)
|
||||
// 7
|
||||
block.Encrypt(next, s[96:112])
|
||||
xor.BytesSrc1(d[96:112], s[96:112], tbl)
|
||||
// 8
|
||||
block.Encrypt(tbl, s[112:128])
|
||||
xor.BytesSrc1(d[112:128], s[112:128], next)
|
||||
base += 128
|
||||
}
|
||||
|
||||
switch left {
|
||||
case 7:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 6:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 5:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 4:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 3:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 2:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 1:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += 16
|
||||
fallthrough
|
||||
case 0:
|
||||
xor.BytesSrc0(dst[base:], src[base:], tbl)
|
||||
}
|
||||
}
|
||||
|
||||
func decryptVariant(block cipher.Block, dst, src, buf []byte) {
|
||||
blocksize := block.BlockSize()
|
||||
tbl := buf[:blocksize]
|
||||
next := buf[blocksize:]
|
||||
block.Encrypt(tbl, initialVector)
|
||||
n := len(src) / blocksize
|
||||
base := 0
|
||||
repeat := n / 8
|
||||
left := n % 8
|
||||
for i := 0; i < repeat; i++ {
|
||||
// 1
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
base += blocksize
|
||||
|
||||
// 2
|
||||
block.Encrypt(tbl, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], next)
|
||||
base += blocksize
|
||||
|
||||
// 3
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
base += blocksize
|
||||
|
||||
// 4
|
||||
block.Encrypt(tbl, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], next)
|
||||
base += blocksize
|
||||
|
||||
// 5
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
base += blocksize
|
||||
|
||||
// 6
|
||||
block.Encrypt(tbl, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], next)
|
||||
base += blocksize
|
||||
|
||||
// 7
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
base += blocksize
|
||||
|
||||
// 8
|
||||
block.Encrypt(tbl, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], next)
|
||||
base += blocksize
|
||||
}
|
||||
|
||||
switch left {
|
||||
case 7:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 6:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 5:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 4:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 3:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 2:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 1:
|
||||
block.Encrypt(next, src[base:])
|
||||
xor.BytesSrc1(dst[base:], src[base:], tbl)
|
||||
tbl, next = next, tbl
|
||||
base += blocksize
|
||||
fallthrough
|
||||
case 0:
|
||||
xor.BytesSrc0(dst[base:], src[base:], tbl)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,289 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSM4(t *testing.T) {
|
||||
bc, err := NewSM4BlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestAES(t *testing.T) {
|
||||
bc, err := NewAESBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestTEA(t *testing.T) {
|
||||
bc, err := NewTEABlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestXOR(t *testing.T) {
|
||||
bc, err := NewSimpleXORBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestBlowfish(t *testing.T) {
|
||||
bc, err := NewBlowfishBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestNone(t *testing.T) {
|
||||
bc, err := NewNoneBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestCast5(t *testing.T) {
|
||||
bc, err := NewCast5BlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func Test3DES(t *testing.T) {
|
||||
bc, err := NewTripleDESBlockCrypt(pass[:24])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestTwofish(t *testing.T) {
|
||||
bc, err := NewTwofishBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestXTEA(t *testing.T) {
|
||||
bc, err := NewXTEABlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func TestSalsa20(t *testing.T) {
|
||||
bc, err := NewSalsa20BlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cryptTest(t, bc)
|
||||
}
|
||||
|
||||
func cryptTest(t *testing.T, bc BlockCrypt) {
|
||||
data := make([]byte, mtuLimit)
|
||||
io.ReadFull(rand.Reader, data)
|
||||
dec := make([]byte, mtuLimit)
|
||||
enc := make([]byte, mtuLimit)
|
||||
bc.Encrypt(enc, data)
|
||||
bc.Decrypt(dec, enc)
|
||||
if !bytes.Equal(data, dec) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSM4(b *testing.B) {
|
||||
bc, err := NewSM4BlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkAES128(b *testing.B) {
|
||||
bc, err := NewAESBlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkAES192(b *testing.B) {
|
||||
bc, err := NewAESBlockCrypt(pass[:24])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkAES256(b *testing.B) {
|
||||
bc, err := NewAESBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkTEA(b *testing.B) {
|
||||
bc, err := NewTEABlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkXOR(b *testing.B) {
|
||||
bc, err := NewSimpleXORBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkBlowfish(b *testing.B) {
|
||||
bc, err := NewBlowfishBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkNone(b *testing.B) {
|
||||
bc, err := NewNoneBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkCast5(b *testing.B) {
|
||||
bc, err := NewCast5BlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func Benchmark3DES(b *testing.B) {
|
||||
bc, err := NewTripleDESBlockCrypt(pass[:24])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkTwofish(b *testing.B) {
|
||||
bc, err := NewTwofishBlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkXTEA(b *testing.B) {
|
||||
bc, err := NewXTEABlockCrypt(pass[:16])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func BenchmarkSalsa20(b *testing.B) {
|
||||
bc, err := NewSalsa20BlockCrypt(pass[:32])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
benchCrypt(b, bc)
|
||||
}
|
||||
|
||||
func benchCrypt(b *testing.B, bc BlockCrypt) {
|
||||
data := make([]byte, mtuLimit)
|
||||
io.ReadFull(rand.Reader, data)
|
||||
dec := make([]byte, mtuLimit)
|
||||
enc := make([]byte, mtuLimit)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(enc) * 2))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bc.Encrypt(enc, data)
|
||||
bc.Decrypt(dec, enc)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCRC32(b *testing.B) {
|
||||
content := make([]byte, 1024)
|
||||
b.SetBytes(int64(len(content)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
crc32.ChecksumIEEE(content)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCsprngSystem(b *testing.B) {
|
||||
data := make([]byte, md5.Size)
|
||||
b.SetBytes(int64(len(data)))
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
io.ReadFull(rand.Reader, data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCsprngMD5(b *testing.B) {
|
||||
var data [md5.Size]byte
|
||||
b.SetBytes(md5.Size)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
data = md5.Sum(data[:])
|
||||
}
|
||||
}
|
||||
func BenchmarkCsprngSHA1(b *testing.B) {
|
||||
var data [sha1.Size]byte
|
||||
b.SetBytes(sha1.Size)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
data = sha1.Sum(data[:])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCsprngNonceMD5(b *testing.B) {
|
||||
var ng nonceMD5
|
||||
ng.Init()
|
||||
b.SetBytes(md5.Size)
|
||||
data := make([]byte, md5.Size)
|
||||
for i := 0; i < b.N; i++ {
|
||||
ng.Fill(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCsprngNonceAES128(b *testing.B) {
|
||||
var ng nonceAES128
|
||||
ng.Init()
|
||||
|
||||
b.SetBytes(aes.BlockSize)
|
||||
data := make([]byte, aes.BlockSize)
|
||||
for i := 0; i < b.N; i++ {
|
||||
ng.Fill(data)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Entropy defines a entropy source
|
||||
type Entropy interface {
|
||||
Init()
|
||||
Fill(nonce []byte)
|
||||
}
|
||||
|
||||
// nonceMD5 nonce generator for packet header
|
||||
type nonceMD5 struct {
|
||||
seed [md5.Size]byte
|
||||
}
|
||||
|
||||
func (n *nonceMD5) Init() { /*nothing required*/ }
|
||||
|
||||
func (n *nonceMD5) Fill(nonce []byte) {
|
||||
if n.seed[0] == 0 { // entropy update
|
||||
io.ReadFull(rand.Reader, n.seed[:])
|
||||
}
|
||||
n.seed = md5.Sum(n.seed[:])
|
||||
copy(nonce, n.seed[:])
|
||||
}
|
||||
|
||||
// nonceAES128 nonce generator for packet headers
|
||||
type nonceAES128 struct {
|
||||
seed [aes.BlockSize]byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
func (n *nonceAES128) Init() {
|
||||
var key [16]byte //aes-128
|
||||
io.ReadFull(rand.Reader, key[:])
|
||||
io.ReadFull(rand.Reader, n.seed[:])
|
||||
block, _ := aes.NewCipher(key[:])
|
||||
n.block = block
|
||||
}
|
||||
|
||||
func (n *nonceAES128) Fill(nonce []byte) {
|
||||
if n.seed[0] == 0 { // entropy update
|
||||
io.ReadFull(rand.Reader, n.seed[:])
|
||||
}
|
||||
n.block.Encrypt(n.seed[:], n.seed[:])
|
||||
copy(nonce, n.seed[:])
|
||||
}
|
|
@ -0,0 +1,311 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/klauspost/reedsolomon"
|
||||
)
|
||||
|
||||
const (
|
||||
fecHeaderSize = 6
|
||||
fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
|
||||
typeData = 0xf1
|
||||
typeFEC = 0xf2
|
||||
)
|
||||
|
||||
type (
|
||||
// fecPacket is a decoded FEC packet
|
||||
fecPacket struct {
|
||||
seqid uint32
|
||||
flag uint16
|
||||
data []byte
|
||||
}
|
||||
|
||||
// fecDecoder for decoding incoming packets
|
||||
fecDecoder struct {
|
||||
rxlimit int // queue size limit
|
||||
dataShards int
|
||||
parityShards int
|
||||
shardSize int
|
||||
rx []fecPacket // ordered receive queue
|
||||
|
||||
// caches
|
||||
decodeCache [][]byte
|
||||
flagCache []bool
|
||||
|
||||
// zeros
|
||||
zeros []byte
|
||||
|
||||
// RS decoder
|
||||
codec reedsolomon.Encoder
|
||||
}
|
||||
)
|
||||
|
||||
func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
|
||||
if dataShards <= 0 || parityShards <= 0 {
|
||||
return nil
|
||||
}
|
||||
if rxlimit < dataShards+parityShards {
|
||||
return nil
|
||||
}
|
||||
|
||||
dec := new(fecDecoder)
|
||||
dec.rxlimit = rxlimit
|
||||
dec.dataShards = dataShards
|
||||
dec.parityShards = parityShards
|
||||
dec.shardSize = dataShards + parityShards
|
||||
codec, err := reedsolomon.New(dataShards, parityShards)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
dec.codec = codec
|
||||
dec.decodeCache = make([][]byte, dec.shardSize)
|
||||
dec.flagCache = make([]bool, dec.shardSize)
|
||||
dec.zeros = make([]byte, mtuLimit)
|
||||
return dec
|
||||
}
|
||||
|
||||
// decodeBytes a fec packet
|
||||
func (dec *fecDecoder) decodeBytes(data []byte) fecPacket {
|
||||
var pkt fecPacket
|
||||
pkt.seqid = binary.LittleEndian.Uint32(data)
|
||||
pkt.flag = binary.LittleEndian.Uint16(data[4:])
|
||||
// allocate memory & copy
|
||||
buf := xmitBuf.Get().([]byte)[:len(data)-6]
|
||||
copy(buf, data[6:])
|
||||
pkt.data = buf
|
||||
return pkt
|
||||
}
|
||||
|
||||
// decode a fec packet
|
||||
func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
|
||||
// insertion
|
||||
n := len(dec.rx) - 1
|
||||
insertIdx := 0
|
||||
for i := n; i >= 0; i-- {
|
||||
if pkt.seqid == dec.rx[i].seqid { // de-duplicate
|
||||
xmitBuf.Put(pkt.data)
|
||||
return nil
|
||||
} else if _itimediff(pkt.seqid, dec.rx[i].seqid) > 0 { // insertion
|
||||
insertIdx = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// insert into ordered rx queue
|
||||
if insertIdx == n+1 {
|
||||
dec.rx = append(dec.rx, pkt)
|
||||
} else {
|
||||
dec.rx = append(dec.rx, fecPacket{})
|
||||
copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
|
||||
dec.rx[insertIdx] = pkt
|
||||
}
|
||||
|
||||
// shard range for current packet
|
||||
shardBegin := pkt.seqid - pkt.seqid%uint32(dec.shardSize)
|
||||
shardEnd := shardBegin + uint32(dec.shardSize) - 1
|
||||
|
||||
// max search range in ordered queue for current shard
|
||||
searchBegin := insertIdx - int(pkt.seqid%uint32(dec.shardSize))
|
||||
if searchBegin < 0 {
|
||||
searchBegin = 0
|
||||
}
|
||||
searchEnd := searchBegin + dec.shardSize - 1
|
||||
if searchEnd >= len(dec.rx) {
|
||||
searchEnd = len(dec.rx) - 1
|
||||
}
|
||||
|
||||
// re-construct datashards
|
||||
if searchEnd-searchBegin+1 >= dec.dataShards {
|
||||
var numshard, numDataShard, first, maxlen int
|
||||
|
||||
// zero caches
|
||||
shards := dec.decodeCache
|
||||
shardsflag := dec.flagCache
|
||||
for k := range dec.decodeCache {
|
||||
shards[k] = nil
|
||||
shardsflag[k] = false
|
||||
}
|
||||
|
||||
// shard assembly
|
||||
for i := searchBegin; i <= searchEnd; i++ {
|
||||
seqid := dec.rx[i].seqid
|
||||
if _itimediff(seqid, shardEnd) > 0 {
|
||||
break
|
||||
} else if _itimediff(seqid, shardBegin) >= 0 {
|
||||
shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data
|
||||
shardsflag[seqid%uint32(dec.shardSize)] = true
|
||||
numshard++
|
||||
if dec.rx[i].flag == typeData {
|
||||
numDataShard++
|
||||
}
|
||||
if numshard == 1 {
|
||||
first = i
|
||||
}
|
||||
if len(dec.rx[i].data) > maxlen {
|
||||
maxlen = len(dec.rx[i].data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if numDataShard == dec.dataShards {
|
||||
// case 1: no loss on data shards
|
||||
dec.rx = dec.freeRange(first, numshard, dec.rx)
|
||||
} else if numshard >= dec.dataShards {
|
||||
// case 2: loss on data shards, but it's recoverable from parity shards
|
||||
for k := range shards {
|
||||
if shards[k] != nil {
|
||||
dlen := len(shards[k])
|
||||
shards[k] = shards[k][:maxlen]
|
||||
copy(shards[k][dlen:], dec.zeros)
|
||||
}
|
||||
}
|
||||
if err := dec.codec.ReconstructData(shards); err == nil {
|
||||
for k := range shards[:dec.dataShards] {
|
||||
if !shardsflag[k] {
|
||||
recovered = append(recovered, shards[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
dec.rx = dec.freeRange(first, numshard, dec.rx)
|
||||
}
|
||||
}
|
||||
|
||||
// keep rxlimit
|
||||
if len(dec.rx) > dec.rxlimit {
|
||||
if dec.rx[0].flag == typeData { // track the unrecoverable data
|
||||
atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
|
||||
}
|
||||
dec.rx = dec.freeRange(0, 1, dec.rx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// free a range of fecPacket, and zero for GC recycling
|
||||
func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket {
|
||||
for i := first; i < first+n; i++ { // recycle buffer
|
||||
xmitBuf.Put(q[i].data)
|
||||
}
|
||||
copy(q[first:], q[first+n:])
|
||||
for i := 0; i < n; i++ { // dereference data
|
||||
q[len(q)-1-i].data = nil
|
||||
}
|
||||
return q[:len(q)-n]
|
||||
}
|
||||
|
||||
type (
|
||||
// fecEncoder for encoding outgoing packets
|
||||
fecEncoder struct {
|
||||
dataShards int
|
||||
parityShards int
|
||||
shardSize int
|
||||
paws uint32 // Protect Against Wrapped Sequence numbers
|
||||
next uint32 // next seqid
|
||||
|
||||
shardCount int // count the number of datashards collected
|
||||
maxSize int // track maximum data length in datashard
|
||||
|
||||
headerOffset int // FEC header offset
|
||||
payloadOffset int // FEC payload offset
|
||||
|
||||
// caches
|
||||
shardCache [][]byte
|
||||
encodeCache [][]byte
|
||||
|
||||
// zeros
|
||||
zeros []byte
|
||||
|
||||
// RS encoder
|
||||
codec reedsolomon.Encoder
|
||||
}
|
||||
)
|
||||
|
||||
func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
|
||||
if dataShards <= 0 || parityShards <= 0 {
|
||||
return nil
|
||||
}
|
||||
enc := new(fecEncoder)
|
||||
enc.dataShards = dataShards
|
||||
enc.parityShards = parityShards
|
||||
enc.shardSize = dataShards + parityShards
|
||||
enc.paws = (0xffffffff/uint32(enc.shardSize) - 1) * uint32(enc.shardSize)
|
||||
enc.headerOffset = offset
|
||||
enc.payloadOffset = enc.headerOffset + fecHeaderSize
|
||||
|
||||
codec, err := reedsolomon.New(dataShards, parityShards)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
enc.codec = codec
|
||||
|
||||
// caches
|
||||
enc.encodeCache = make([][]byte, enc.shardSize)
|
||||
enc.shardCache = make([][]byte, enc.shardSize)
|
||||
for k := range enc.shardCache {
|
||||
enc.shardCache[k] = make([]byte, mtuLimit)
|
||||
}
|
||||
enc.zeros = make([]byte, mtuLimit)
|
||||
return enc
|
||||
}
|
||||
|
||||
// encodes the packet, outputs parity shards if we have collected quorum datashards
|
||||
// notice: the contents of 'ps' will be re-written in successive calling
|
||||
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
|
||||
enc.markData(b[enc.headerOffset:])
|
||||
binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
|
||||
|
||||
// copy data to fec datashards
|
||||
sz := len(b)
|
||||
enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
|
||||
copy(enc.shardCache[enc.shardCount], b)
|
||||
enc.shardCount++
|
||||
|
||||
// track max datashard length
|
||||
if sz > enc.maxSize {
|
||||
enc.maxSize = sz
|
||||
}
|
||||
|
||||
// Generation of Reed-Solomon Erasure Code
|
||||
if enc.shardCount == enc.dataShards {
|
||||
// fill '0' into the tail of each datashard
|
||||
for i := 0; i < enc.dataShards; i++ {
|
||||
shard := enc.shardCache[i]
|
||||
slen := len(shard)
|
||||
copy(shard[slen:enc.maxSize], enc.zeros)
|
||||
}
|
||||
|
||||
// construct equal-sized slice with stripped header
|
||||
cache := enc.encodeCache
|
||||
for k := range cache {
|
||||
cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
|
||||
}
|
||||
|
||||
// encoding
|
||||
if err := enc.codec.Encode(cache); err == nil {
|
||||
ps = enc.shardCache[enc.dataShards:]
|
||||
for k := range ps {
|
||||
enc.markFEC(ps[k][enc.headerOffset:])
|
||||
ps[k] = ps[k][:enc.maxSize]
|
||||
}
|
||||
}
|
||||
|
||||
// counters resetting
|
||||
enc.shardCount = 0
|
||||
enc.maxSize = 0
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (enc *fecEncoder) markData(data []byte) {
|
||||
binary.LittleEndian.PutUint32(data, enc.next)
|
||||
binary.LittleEndian.PutUint16(data[4:], typeData)
|
||||
enc.next++
|
||||
}
|
||||
|
||||
func (enc *fecEncoder) markFEC(data []byte) {
|
||||
binary.LittleEndian.PutUint32(data, enc.next)
|
||||
binary.LittleEndian.PutUint16(data[4:], typeFEC)
|
||||
enc.next = (enc.next + 1) % enc.paws
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkFECDecode(b *testing.B) {
|
||||
const dataSize = 10
|
||||
const paritySize = 3
|
||||
const payLoad = 1500
|
||||
decoder := newFECDecoder(1024, dataSize, paritySize)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(payLoad)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if rand.Int()%(dataSize+paritySize) == 0 { // random loss
|
||||
continue
|
||||
}
|
||||
var pkt fecPacket
|
||||
pkt.seqid = uint32(i)
|
||||
if i%(dataSize+paritySize) >= dataSize {
|
||||
pkt.flag = typeFEC
|
||||
} else {
|
||||
pkt.flag = typeData
|
||||
}
|
||||
pkt.data = make([]byte, payLoad)
|
||||
decoder.decode(pkt)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFECEncode(b *testing.B) {
|
||||
const dataSize = 10
|
||||
const paritySize = 3
|
||||
const payLoad = 1500
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(payLoad)
|
||||
encoder := newFECEncoder(dataSize, paritySize, 0)
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, payLoad)
|
||||
encoder.encode(data)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,302 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func iclock() int32 {
|
||||
return int32(currentMs())
|
||||
}
|
||||
|
||||
type DelayPacket struct {
|
||||
_ptr []byte
|
||||
_size int
|
||||
_ts int32
|
||||
}
|
||||
|
||||
func (p *DelayPacket) Init(size int, src []byte) {
|
||||
p._ptr = make([]byte, size)
|
||||
p._size = size
|
||||
copy(p._ptr, src[:size])
|
||||
}
|
||||
|
||||
func (p *DelayPacket) ptr() []byte { return p._ptr }
|
||||
func (p *DelayPacket) size() int { return p._size }
|
||||
func (p *DelayPacket) ts() int32 { return p._ts }
|
||||
func (p *DelayPacket) setts(ts int32) { p._ts = ts }
|
||||
|
||||
type DelayTunnel struct{ *list.List }
|
||||
type LatencySimulator struct {
|
||||
current int32
|
||||
lostrate, rttmin, rttmax, nmax int
|
||||
p12 DelayTunnel
|
||||
p21 DelayTunnel
|
||||
r12 *rand.Rand
|
||||
r21 *rand.Rand
|
||||
}
|
||||
|
||||
// lostrate: 往返一周丢包率的百分比,默认 10%
|
||||
// rttmin:rtt最小值,默认 60
|
||||
// rttmax:rtt最大值,默认 125
|
||||
//func (p *LatencySimulator)Init(int lostrate = 10, int rttmin = 60, int rttmax = 125, int nmax = 1000):
|
||||
func (p *LatencySimulator) Init(lostrate, rttmin, rttmax, nmax int) {
|
||||
p.r12 = rand.New(rand.NewSource(9))
|
||||
p.r21 = rand.New(rand.NewSource(99))
|
||||
p.p12 = DelayTunnel{list.New()}
|
||||
p.p21 = DelayTunnel{list.New()}
|
||||
p.current = iclock()
|
||||
p.lostrate = lostrate / 2 // 上面数据是往返丢包率,单程除以2
|
||||
p.rttmin = rttmin / 2
|
||||
p.rttmax = rttmax / 2
|
||||
p.nmax = nmax
|
||||
}
|
||||
|
||||
// 发送数据
|
||||
// peer - 端点0/1,从0发送,从1接收;从1发送从0接收
|
||||
func (p *LatencySimulator) send(peer int, data []byte, size int) int {
|
||||
rnd := 0
|
||||
if peer == 0 {
|
||||
rnd = p.r12.Intn(100)
|
||||
} else {
|
||||
rnd = p.r21.Intn(100)
|
||||
}
|
||||
//println("!!!!!!!!!!!!!!!!!!!!", rnd, p.lostrate, peer)
|
||||
if rnd < p.lostrate {
|
||||
return 0
|
||||
}
|
||||
pkt := &DelayPacket{}
|
||||
pkt.Init(size, data)
|
||||
p.current = iclock()
|
||||
delay := p.rttmin
|
||||
if p.rttmax > p.rttmin {
|
||||
delay += rand.Int() % (p.rttmax - p.rttmin)
|
||||
}
|
||||
pkt.setts(p.current + int32(delay))
|
||||
if peer == 0 {
|
||||
p.p12.PushBack(pkt)
|
||||
} else {
|
||||
p.p21.PushBack(pkt)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// 接收数据
|
||||
func (p *LatencySimulator) recv(peer int, data []byte, maxsize int) int32 {
|
||||
var it *list.Element
|
||||
if peer == 0 {
|
||||
it = p.p21.Front()
|
||||
if p.p21.Len() == 0 {
|
||||
return -1
|
||||
}
|
||||
} else {
|
||||
it = p.p12.Front()
|
||||
if p.p12.Len() == 0 {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
pkt := it.Value.(*DelayPacket)
|
||||
p.current = iclock()
|
||||
if p.current < pkt.ts() {
|
||||
return -2
|
||||
}
|
||||
if maxsize < pkt.size() {
|
||||
return -3
|
||||
}
|
||||
if peer == 0 {
|
||||
p.p21.Remove(it)
|
||||
} else {
|
||||
p.p12.Remove(it)
|
||||
}
|
||||
maxsize = pkt.size()
|
||||
copy(data, pkt.ptr()[:maxsize])
|
||||
return int32(maxsize)
|
||||
}
|
||||
|
||||
//=====================================================================
|
||||
//=====================================================================
|
||||
|
||||
// 模拟网络
|
||||
var vnet *LatencySimulator
|
||||
|
||||
// 测试用例
|
||||
func test(mode int) {
|
||||
// 创建模拟网络:丢包率10%,Rtt 60ms~125ms
|
||||
vnet = &LatencySimulator{}
|
||||
vnet.Init(10, 60, 125, 1000)
|
||||
|
||||
// 创建两个端点的 kcp对象,第一个参数 conv是会话编号,同一个会话需要相同
|
||||
// 最后一个是 user参数,用来传递标识
|
||||
output1 := func(buf []byte, size int) {
|
||||
if vnet.send(0, buf, size) != 1 {
|
||||
}
|
||||
}
|
||||
output2 := func(buf []byte, size int) {
|
||||
if vnet.send(1, buf, size) != 1 {
|
||||
}
|
||||
}
|
||||
kcp1 := NewKCP(0x11223344, output1)
|
||||
kcp2 := NewKCP(0x11223344, output2)
|
||||
|
||||
current := uint32(iclock())
|
||||
slap := current + 20
|
||||
index := 0
|
||||
next := 0
|
||||
var sumrtt uint32
|
||||
count := 0
|
||||
maxrtt := 0
|
||||
|
||||
// 配置窗口大小:平均延迟200ms,每20ms发送一个包,
|
||||
// 而考虑到丢包重发,设置最大收发窗口为128
|
||||
kcp1.WndSize(128, 128)
|
||||
kcp2.WndSize(128, 128)
|
||||
|
||||
// 判断测试用例的模式
|
||||
if mode == 0 {
|
||||
// 默认模式
|
||||
kcp1.NoDelay(0, 10, 0, 0)
|
||||
kcp2.NoDelay(0, 10, 0, 0)
|
||||
} else if mode == 1 {
|
||||
// 普通模式,关闭流控等
|
||||
kcp1.NoDelay(0, 10, 0, 1)
|
||||
kcp2.NoDelay(0, 10, 0, 1)
|
||||
} else {
|
||||
// 启动快速模式
|
||||
// 第二个参数 nodelay-启用以后若干常规加速将启动
|
||||
// 第三个参数 interval为内部处理时钟,默认设置为 10ms
|
||||
// 第四个参数 resend为快速重传指标,设置为2
|
||||
// 第五个参数 为是否禁用常规流控,这里禁止
|
||||
kcp1.NoDelay(1, 10, 2, 1)
|
||||
kcp2.NoDelay(1, 10, 2, 1)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 2000)
|
||||
var hr int32
|
||||
|
||||
ts1 := iclock()
|
||||
|
||||
for {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
current = uint32(iclock())
|
||||
kcp1.Update()
|
||||
kcp2.Update()
|
||||
|
||||
// 每隔 20ms,kcp1发送数据
|
||||
for ; current >= slap; slap += 20 {
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(index))
|
||||
index++
|
||||
binary.Write(buf, binary.LittleEndian, uint32(current))
|
||||
// 发送上层协议包
|
||||
kcp1.Send(buf.Bytes())
|
||||
//println("now", iclock())
|
||||
}
|
||||
|
||||
// 处理虚拟网络:检测是否有udp包从p1->p2
|
||||
for {
|
||||
hr = vnet.recv(1, buffer, 2000)
|
||||
if hr < 0 {
|
||||
break
|
||||
}
|
||||
// 如果 p2收到udp,则作为下层协议输入到kcp2
|
||||
kcp2.Input(buffer[:hr], true, false)
|
||||
}
|
||||
|
||||
// 处理虚拟网络:检测是否有udp包从p2->p1
|
||||
for {
|
||||
hr = vnet.recv(0, buffer, 2000)
|
||||
if hr < 0 {
|
||||
break
|
||||
}
|
||||
// 如果 p1收到udp,则作为下层协议输入到kcp1
|
||||
kcp1.Input(buffer[:hr], true, false)
|
||||
//println("@@@@", hr, r)
|
||||
}
|
||||
|
||||
// kcp2接收到任何包都返回回去
|
||||
for {
|
||||
hr = int32(kcp2.Recv(buffer[:10]))
|
||||
// 没有收到包就退出
|
||||
if hr < 0 {
|
||||
break
|
||||
}
|
||||
// 如果收到包就回射
|
||||
buf := bytes.NewReader(buffer)
|
||||
var sn uint32
|
||||
binary.Read(buf, binary.LittleEndian, &sn)
|
||||
kcp2.Send(buffer[:hr])
|
||||
}
|
||||
|
||||
// kcp1收到kcp2的回射数据
|
||||
for {
|
||||
hr = int32(kcp1.Recv(buffer[:10]))
|
||||
buf := bytes.NewReader(buffer)
|
||||
// 没有收到包就退出
|
||||
if hr < 0 {
|
||||
break
|
||||
}
|
||||
var sn uint32
|
||||
var ts, rtt uint32
|
||||
binary.Read(buf, binary.LittleEndian, &sn)
|
||||
binary.Read(buf, binary.LittleEndian, &ts)
|
||||
rtt = uint32(current) - ts
|
||||
|
||||
if sn != uint32(next) {
|
||||
// 如果收到的包不连续
|
||||
//for i:=0;i<8 ;i++ {
|
||||
//println("---", i, buffer[i])
|
||||
//}
|
||||
println("ERROR sn ", count, "<->", next, sn)
|
||||
return
|
||||
}
|
||||
|
||||
next++
|
||||
sumrtt += rtt
|
||||
count++
|
||||
if rtt > uint32(maxrtt) {
|
||||
maxrtt = int(rtt)
|
||||
}
|
||||
|
||||
//println("[RECV] mode=", mode, " sn=", sn, " rtt=", rtt)
|
||||
}
|
||||
|
||||
if next > 100 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ts1 = iclock() - ts1
|
||||
|
||||
names := []string{"default", "normal", "fast"}
|
||||
fmt.Printf("%s mode result (%dms):\n", names[mode], ts1)
|
||||
fmt.Printf("avgrtt=%d maxrtt=%d\n", int(sumrtt/uint32(count)), maxrtt)
|
||||
}
|
||||
|
||||
func TestNetwork(t *testing.T) {
|
||||
test(0) // 默认模式,类似 TCP:正常模式,无快速重传,常规流控
|
||||
test(1) // 普通模式,关闭流控等
|
||||
test(2) // 快速模式,所有开关都打开,且关闭流控
|
||||
}
|
||||
|
||||
func BenchmarkFlush(b *testing.B) {
|
||||
kcp := NewKCP(1, func(buf []byte, size int) {})
|
||||
kcp.snd_buf = make([]segment, 1024)
|
||||
for k := range kcp.snd_buf {
|
||||
kcp.snd_buf[k].xmit = 1
|
||||
kcp.snd_buf[k].resendts = currentMs() + 10000
|
||||
}
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
var mu sync.Mutex
|
||||
for i := 0; i < b.N; i++ {
|
||||
mu.Lock()
|
||||
kcp.flush(false)
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,963 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type errTimeout struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (errTimeout) Timeout() bool { return true }
|
||||
func (errTimeout) Temporary() bool { return true }
|
||||
func (errTimeout) Error() string { return "i/o timeout" }
|
||||
|
||||
const (
|
||||
// 16-bytes nonce for each packet
|
||||
nonceSize = 16
|
||||
|
||||
// 4-bytes packet checksum
|
||||
crcSize = 4
|
||||
|
||||
// overall crypto header size
|
||||
cryptHeaderSize = nonceSize + crcSize
|
||||
|
||||
// maximum packet size
|
||||
mtuLimit = 1500
|
||||
|
||||
// FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory
|
||||
rxFECMulti = 3
|
||||
|
||||
// accept backlog
|
||||
acceptBacklog = 128
|
||||
)
|
||||
|
||||
const (
|
||||
errBrokenPipe = "broken pipe"
|
||||
errInvalidOperation = "invalid operation"
|
||||
)
|
||||
|
||||
var (
|
||||
// a system-wide packet buffer shared among sending, receiving and FEC
|
||||
// to mitigate high-frequency memory allocation for packets
|
||||
xmitBuf sync.Pool
|
||||
)
|
||||
|
||||
func init() {
|
||||
xmitBuf.New = func() interface{} {
|
||||
return make([]byte, mtuLimit)
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// UDPSession defines a KCP session implemented by UDP
|
||||
UDPSession struct {
|
||||
updaterIdx int // record slice index in updater
|
||||
conn net.PacketConn // the underlying packet connection
|
||||
kcp *KCP // KCP ARQ protocol
|
||||
l *Listener // pointing to the Listener object if it's been accepted by a Listener
|
||||
block BlockCrypt // block encryption object
|
||||
|
||||
// kcp receiving is based on packets
|
||||
// recvbuf turns packets into stream
|
||||
recvbuf []byte
|
||||
bufptr []byte
|
||||
// header extended output buffer, if has header
|
||||
ext []byte
|
||||
|
||||
// FEC codec
|
||||
fecDecoder *fecDecoder
|
||||
fecEncoder *fecEncoder
|
||||
|
||||
// settings
|
||||
remote net.Addr // remote peer address
|
||||
rd time.Time // read deadline
|
||||
wd time.Time // write deadline
|
||||
headerSize int // the header size additional to a KCP frame
|
||||
ackNoDelay bool // send ack immediately for each incoming packet(testing purpose)
|
||||
writeDelay bool // delay kcp.flush() for Write() for bulk transfer
|
||||
dup int // duplicate udp packets(testing purpose)
|
||||
|
||||
// notifications
|
||||
die chan struct{} // notify current session has Closed
|
||||
chReadEvent chan struct{} // notify Read() can be called without blocking
|
||||
chWriteEvent chan struct{} // notify Write() can be called without blocking
|
||||
chReadError chan error // notify PacketConn.Read() have an error
|
||||
chWriteError chan error // notify PacketConn.Write() have an error
|
||||
|
||||
// nonce generator
|
||||
nonce Entropy
|
||||
|
||||
isClosed bool // flag the session has Closed
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
setReadBuffer interface {
|
||||
SetReadBuffer(bytes int) error
|
||||
}
|
||||
|
||||
setWriteBuffer interface {
|
||||
SetWriteBuffer(bytes int) error
|
||||
}
|
||||
)
|
||||
|
||||
// newUDPSession create a new udp session for client or server
|
||||
func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn net.PacketConn, remote net.Addr, block BlockCrypt) *UDPSession {
|
||||
sess := new(UDPSession)
|
||||
sess.die = make(chan struct{})
|
||||
sess.nonce = new(nonceAES128)
|
||||
sess.nonce.Init()
|
||||
sess.chReadEvent = make(chan struct{}, 1)
|
||||
sess.chWriteEvent = make(chan struct{}, 1)
|
||||
sess.chReadError = make(chan error, 1)
|
||||
sess.chWriteError = make(chan error, 1)
|
||||
sess.remote = remote
|
||||
sess.conn = conn
|
||||
sess.l = l
|
||||
sess.block = block
|
||||
sess.recvbuf = make([]byte, mtuLimit)
|
||||
|
||||
// FEC codec initialization
|
||||
sess.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)
|
||||
if sess.block != nil {
|
||||
sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize)
|
||||
} else {
|
||||
sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0)
|
||||
}
|
||||
|
||||
// calculate additional header size introduced by FEC and encryption
|
||||
if sess.block != nil {
|
||||
sess.headerSize += cryptHeaderSize
|
||||
}
|
||||
if sess.fecEncoder != nil {
|
||||
sess.headerSize += fecHeaderSizePlus2
|
||||
}
|
||||
|
||||
// we only need to allocate extended packet buffer if we have the additional header
|
||||
if sess.headerSize > 0 {
|
||||
sess.ext = make([]byte, mtuLimit)
|
||||
}
|
||||
|
||||
sess.kcp = NewKCP(conv, func(buf []byte, size int) {
|
||||
if size >= IKCP_OVERHEAD {
|
||||
sess.output(buf[:size])
|
||||
}
|
||||
})
|
||||
sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize)
|
||||
|
||||
// register current session to the global updater,
|
||||
// which call sess.update() periodically.
|
||||
updater.addSession(sess)
|
||||
|
||||
if sess.l == nil { // it's a client connection
|
||||
go sess.readLoop()
|
||||
atomic.AddUint64(&DefaultSnmp.ActiveOpens, 1)
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1)
|
||||
}
|
||||
currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1)
|
||||
maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn)
|
||||
if currestab > maxconn {
|
||||
atomic.CompareAndSwapUint64(&DefaultSnmp.MaxConn, maxconn, currestab)
|
||||
}
|
||||
|
||||
return sess
|
||||
}
|
||||
|
||||
// Read implements net.Conn
|
||||
func (s *UDPSession) Read(b []byte) (n int, err error) {
|
||||
for {
|
||||
s.mu.Lock()
|
||||
if len(s.bufptr) > 0 { // copy from buffer into b
|
||||
n = copy(b, s.bufptr)
|
||||
s.bufptr = s.bufptr[n:]
|
||||
s.mu.Unlock()
|
||||
atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if s.isClosed {
|
||||
s.mu.Unlock()
|
||||
return 0, errors.New(errBrokenPipe)
|
||||
}
|
||||
|
||||
if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp
|
||||
if len(b) >= size { // receive data into 'b' directly
|
||||
s.kcp.Recv(b)
|
||||
s.mu.Unlock()
|
||||
atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(size))
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// if necessary resize the stream buffer to guarantee a sufficent buffer space
|
||||
if cap(s.recvbuf) < size {
|
||||
s.recvbuf = make([]byte, size)
|
||||
}
|
||||
|
||||
// resize the length of recvbuf to correspond to data size
|
||||
s.recvbuf = s.recvbuf[:size]
|
||||
s.kcp.Recv(s.recvbuf)
|
||||
n = copy(b, s.recvbuf) // copy to 'b'
|
||||
s.bufptr = s.recvbuf[n:] // pointer update
|
||||
s.mu.Unlock()
|
||||
atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// deadline for current reading operation
|
||||
var timeout *time.Timer
|
||||
var c <-chan time.Time
|
||||
if !s.rd.IsZero() {
|
||||
if time.Now().After(s.rd) {
|
||||
s.mu.Unlock()
|
||||
return 0, errTimeout{}
|
||||
}
|
||||
|
||||
delay := s.rd.Sub(time.Now())
|
||||
timeout = time.NewTimer(delay)
|
||||
c = timeout.C
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// wait for read event or timeout
|
||||
select {
|
||||
case <-s.chReadEvent:
|
||||
case <-c:
|
||||
case <-s.die:
|
||||
case err = <-s.chReadError:
|
||||
if timeout != nil {
|
||||
timeout.Stop()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if timeout != nil {
|
||||
timeout.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements net.Conn
|
||||
func (s *UDPSession) Write(b []byte) (n int, err error) {
|
||||
for {
|
||||
s.mu.Lock()
|
||||
if s.isClosed {
|
||||
s.mu.Unlock()
|
||||
return 0, errors.New(errBrokenPipe)
|
||||
}
|
||||
|
||||
// controls how much data will be sent to kcp core
|
||||
// to prevent the memory from exhuasting
|
||||
if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) {
|
||||
n = len(b)
|
||||
for {
|
||||
if len(b) <= int(s.kcp.mss) {
|
||||
s.kcp.Send(b)
|
||||
break
|
||||
} else {
|
||||
s.kcp.Send(b[:s.kcp.mss])
|
||||
b = b[s.kcp.mss:]
|
||||
}
|
||||
}
|
||||
|
||||
// flush immediately if the queue is full
|
||||
if s.kcp.WaitSnd() >= int(s.kcp.snd_wnd) || !s.writeDelay {
|
||||
s.kcp.flush(false)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// deadline for current writing operation
|
||||
var timeout *time.Timer
|
||||
var c <-chan time.Time
|
||||
if !s.wd.IsZero() {
|
||||
if time.Now().After(s.wd) {
|
||||
s.mu.Unlock()
|
||||
return 0, errTimeout{}
|
||||
}
|
||||
delay := s.wd.Sub(time.Now())
|
||||
timeout = time.NewTimer(delay)
|
||||
c = timeout.C
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// wait for write event or timeout
|
||||
select {
|
||||
case <-s.chWriteEvent:
|
||||
case <-c:
|
||||
case <-s.die:
|
||||
case err = <-s.chWriteError:
|
||||
if timeout != nil {
|
||||
timeout.Stop()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if timeout != nil {
|
||||
timeout.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (s *UDPSession) Close() error {
|
||||
// remove current session from updater & listener(if necessary)
|
||||
updater.removeSession(s)
|
||||
if s.l != nil { // notify listener
|
||||
s.l.closeSession(s.remote)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.isClosed {
|
||||
return errors.New(errBrokenPipe)
|
||||
}
|
||||
close(s.die)
|
||||
s.isClosed = true
|
||||
atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0))
|
||||
if s.l == nil { // client socket close
|
||||
return s.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
|
||||
func (s *UDPSession) LocalAddr() net.Addr { return s.conn.LocalAddr() }
|
||||
|
||||
// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
|
||||
func (s *UDPSession) RemoteAddr() net.Addr { return s.remote }
|
||||
|
||||
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
|
||||
func (s *UDPSession) SetDeadline(t time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.rd = t
|
||||
s.wd = t
|
||||
s.notifyReadEvent()
|
||||
s.notifyWriteEvent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the Conn SetReadDeadline method.
|
||||
func (s *UDPSession) SetReadDeadline(t time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.rd = t
|
||||
s.notifyReadEvent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the Conn SetWriteDeadline method.
|
||||
func (s *UDPSession) SetWriteDeadline(t time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wd = t
|
||||
s.notifyWriteEvent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDelay delays write for bulk transfer until the next update interval
|
||||
func (s *UDPSession) SetWriteDelay(delay bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.writeDelay = delay
|
||||
}
|
||||
|
||||
// SetWindowSize set maximum window size
|
||||
func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.kcp.WndSize(sndwnd, rcvwnd)
|
||||
}
|
||||
|
||||
// SetMtu sets the maximum transmission unit(not including UDP header)
|
||||
func (s *UDPSession) SetMtu(mtu int) bool {
|
||||
if mtu > mtuLimit {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.kcp.SetMtu(mtu - s.headerSize)
|
||||
return true
|
||||
}
|
||||
|
||||
// SetStreamMode toggles the stream mode on/off
|
||||
func (s *UDPSession) SetStreamMode(enable bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if enable {
|
||||
s.kcp.stream = 1
|
||||
} else {
|
||||
s.kcp.stream = 0
|
||||
}
|
||||
}
|
||||
|
||||
// SetACKNoDelay changes ack flush option, set true to flush ack immediately,
|
||||
func (s *UDPSession) SetACKNoDelay(nodelay bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ackNoDelay = nodelay
|
||||
}
|
||||
|
||||
// SetDUP duplicates udp packets for kcp output, for testing purpose only
|
||||
func (s *UDPSession) SetDUP(dup int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.dup = dup
|
||||
}
|
||||
|
||||
// SetNoDelay calls nodelay() of kcp
|
||||
// https://github.com/skywind3000/kcp/blob/master/README.en.md#protocol-configuration
|
||||
func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.kcp.NoDelay(nodelay, interval, resend, nc)
|
||||
}
|
||||
|
||||
// SetDSCP sets the 6bit DSCP field of IP header, no effect if it's accepted from Listener
|
||||
func (s *UDPSession) SetDSCP(dscp int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.l == nil {
|
||||
if nc, ok := s.conn.(net.Conn); ok {
|
||||
if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil {
|
||||
return ipv6.NewConn(nc).SetTrafficClass(dscp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New(errInvalidOperation)
|
||||
}
|
||||
|
||||
// SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener
|
||||
func (s *UDPSession) SetReadBuffer(bytes int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.l == nil {
|
||||
if nc, ok := s.conn.(setReadBuffer); ok {
|
||||
return nc.SetReadBuffer(bytes)
|
||||
}
|
||||
}
|
||||
return errors.New(errInvalidOperation)
|
||||
}
|
||||
|
||||
// SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener
|
||||
func (s *UDPSession) SetWriteBuffer(bytes int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.l == nil {
|
||||
if nc, ok := s.conn.(setWriteBuffer); ok {
|
||||
return nc.SetWriteBuffer(bytes)
|
||||
}
|
||||
}
|
||||
return errors.New(errInvalidOperation)
|
||||
}
|
||||
|
||||
// post-processing for sending a packet from kcp core
|
||||
// steps:
|
||||
// 0. Header extending
|
||||
// 1. FEC packet generation
|
||||
// 2. CRC32 integrity
|
||||
// 3. Encryption
|
||||
// 4. WriteTo kernel
|
||||
func (s *UDPSession) output(buf []byte) {
|
||||
var ecc [][]byte
|
||||
|
||||
// 0. extend buf's header space(if necessary)
|
||||
ext := buf
|
||||
if s.headerSize > 0 {
|
||||
ext = s.ext[:s.headerSize+len(buf)]
|
||||
copy(ext[s.headerSize:], buf)
|
||||
}
|
||||
|
||||
// 1. FEC encoding
|
||||
if s.fecEncoder != nil {
|
||||
ecc = s.fecEncoder.encode(ext)
|
||||
}
|
||||
|
||||
// 2&3. crc32 & encryption
|
||||
if s.block != nil {
|
||||
s.nonce.Fill(ext[:nonceSize])
|
||||
checksum := crc32.ChecksumIEEE(ext[cryptHeaderSize:])
|
||||
binary.LittleEndian.PutUint32(ext[nonceSize:], checksum)
|
||||
s.block.Encrypt(ext, ext)
|
||||
|
||||
for k := range ecc {
|
||||
s.nonce.Fill(ecc[k][:nonceSize])
|
||||
checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:])
|
||||
binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum)
|
||||
s.block.Encrypt(ecc[k], ecc[k])
|
||||
}
|
||||
}
|
||||
|
||||
// 4. WriteTo kernel
|
||||
nbytes := 0
|
||||
npkts := 0
|
||||
for i := 0; i < s.dup+1; i++ {
|
||||
if n, err := s.conn.WriteTo(ext, s.remote); err == nil {
|
||||
nbytes += n
|
||||
npkts++
|
||||
} else {
|
||||
s.notifyWriteError(err)
|
||||
}
|
||||
}
|
||||
|
||||
for k := range ecc {
|
||||
if n, err := s.conn.WriteTo(ecc[k], s.remote); err == nil {
|
||||
nbytes += n
|
||||
npkts++
|
||||
} else {
|
||||
s.notifyWriteError(err)
|
||||
}
|
||||
}
|
||||
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
|
||||
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
|
||||
}
|
||||
|
||||
// kcp update, returns interval for next calling
|
||||
func (s *UDPSession) update() (interval time.Duration) {
|
||||
s.mu.Lock()
|
||||
waitsnd := s.kcp.WaitSnd()
|
||||
interval = time.Duration(s.kcp.flush(false)) * time.Millisecond
|
||||
if s.kcp.WaitSnd() < waitsnd {
|
||||
s.notifyWriteEvent()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// GetConv gets conversation id of a session
|
||||
func (s *UDPSession) GetConv() uint32 { return s.kcp.conv }
|
||||
|
||||
func (s *UDPSession) notifyReadEvent() {
|
||||
select {
|
||||
case s.chReadEvent <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPSession) notifyWriteEvent() {
|
||||
select {
|
||||
case s.chWriteEvent <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPSession) notifyWriteError(err error) {
|
||||
select {
|
||||
case s.chWriteError <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPSession) kcpInput(data []byte) {
|
||||
var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64
|
||||
|
||||
if s.fecDecoder != nil {
|
||||
if len(data) > fecHeaderSize { // must be larger than fec header size
|
||||
f := s.fecDecoder.decodeBytes(data)
|
||||
if f.flag == typeData || f.flag == typeFEC { // header check
|
||||
if f.flag == typeFEC {
|
||||
fecParityShards++
|
||||
}
|
||||
recovers := s.fecDecoder.decode(f)
|
||||
|
||||
s.mu.Lock()
|
||||
waitsnd := s.kcp.WaitSnd()
|
||||
if f.flag == typeData {
|
||||
if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 {
|
||||
kcpInErrors++
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range recovers {
|
||||
if len(r) >= 2 { // must be larger than 2bytes
|
||||
sz := binary.LittleEndian.Uint16(r)
|
||||
if int(sz) <= len(r) && sz >= 2 {
|
||||
if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 {
|
||||
fecRecovered++
|
||||
} else {
|
||||
kcpInErrors++
|
||||
}
|
||||
} else {
|
||||
fecErrs++
|
||||
}
|
||||
} else {
|
||||
fecErrs++
|
||||
}
|
||||
}
|
||||
|
||||
// to notify the readers to receive the data
|
||||
if n := s.kcp.PeekSize(); n > 0 {
|
||||
s.notifyReadEvent()
|
||||
}
|
||||
// to notify the writers when queue is shorter(e.g. ACKed)
|
||||
if s.kcp.WaitSnd() < waitsnd {
|
||||
s.notifyWriteEvent()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
|
||||
}
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
|
||||
}
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
waitsnd := s.kcp.WaitSnd()
|
||||
if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 {
|
||||
kcpInErrors++
|
||||
}
|
||||
if n := s.kcp.PeekSize(); n > 0 {
|
||||
s.notifyReadEvent()
|
||||
}
|
||||
if s.kcp.WaitSnd() < waitsnd {
|
||||
s.notifyWriteEvent()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
atomic.AddUint64(&DefaultSnmp.InPkts, 1)
|
||||
atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data)))
|
||||
if fecParityShards > 0 {
|
||||
atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards)
|
||||
}
|
||||
if kcpInErrors > 0 {
|
||||
atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors)
|
||||
}
|
||||
if fecErrs > 0 {
|
||||
atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs)
|
||||
}
|
||||
if fecRecovered > 0 {
|
||||
atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered)
|
||||
}
|
||||
}
|
||||
|
||||
// the read loop for a client session
|
||||
func (s *UDPSession) readLoop() {
|
||||
buf := make([]byte, mtuLimit)
|
||||
var src string
|
||||
for {
|
||||
if n, addr, err := s.conn.ReadFrom(buf); err == nil {
|
||||
// make sure the packet is from the same source
|
||||
if src == "" { // set source address
|
||||
src = addr.String()
|
||||
} else if addr.String() != src {
|
||||
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
if n >= s.headerSize+IKCP_OVERHEAD {
|
||||
data := buf[:n]
|
||||
dataValid := false
|
||||
if s.block != nil {
|
||||
s.block.Decrypt(data, data)
|
||||
data = data[nonceSize:]
|
||||
checksum := crc32.ChecksumIEEE(data[crcSize:])
|
||||
if checksum == binary.LittleEndian.Uint32(data) {
|
||||
data = data[crcSize:]
|
||||
dataValid = true
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
|
||||
}
|
||||
} else if s.block == nil {
|
||||
dataValid = true
|
||||
}
|
||||
|
||||
if dataValid {
|
||||
s.kcpInput(data)
|
||||
}
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
|
||||
}
|
||||
} else {
|
||||
s.chReadError <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// Listener defines a server which will be waiting to accept incoming connections
|
||||
Listener struct {
|
||||
block BlockCrypt // block encryption
|
||||
dataShards int // FEC data shard
|
||||
parityShards int // FEC parity shard
|
||||
fecDecoder *fecDecoder // FEC mock initialization
|
||||
conn net.PacketConn // the underlying packet connection
|
||||
|
||||
sessions map[string]*UDPSession // all sessions accepted by this Listener
|
||||
sessionLock sync.Mutex
|
||||
chAccepts chan *UDPSession // Listen() backlog
|
||||
chSessionClosed chan net.Addr // session close queue
|
||||
headerSize int // the additional header to a KCP frame
|
||||
die chan struct{} // notify the listener has closed
|
||||
rd atomic.Value // read deadline for Accept()
|
||||
wd atomic.Value
|
||||
}
|
||||
)
|
||||
|
||||
// monitor incoming data for all connections of server
|
||||
func (l *Listener) monitor() {
|
||||
// a cache for session object last used
|
||||
var lastAddr string
|
||||
var lastSession *UDPSession
|
||||
buf := make([]byte, mtuLimit)
|
||||
for {
|
||||
if n, from, err := l.conn.ReadFrom(buf); err == nil {
|
||||
if n >= l.headerSize+IKCP_OVERHEAD {
|
||||
data := buf[:n]
|
||||
dataValid := false
|
||||
if l.block != nil {
|
||||
l.block.Decrypt(data, data)
|
||||
data = data[nonceSize:]
|
||||
checksum := crc32.ChecksumIEEE(data[crcSize:])
|
||||
if checksum == binary.LittleEndian.Uint32(data) {
|
||||
data = data[crcSize:]
|
||||
dataValid = true
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
|
||||
}
|
||||
} else if l.block == nil {
|
||||
dataValid = true
|
||||
}
|
||||
|
||||
if dataValid {
|
||||
addr := from.String()
|
||||
var s *UDPSession
|
||||
var ok bool
|
||||
|
||||
// the packets received from an address always come in batch,
|
||||
// cache the session for next packet, without querying map.
|
||||
if addr == lastAddr {
|
||||
s, ok = lastSession, true
|
||||
} else {
|
||||
l.sessionLock.Lock()
|
||||
if s, ok = l.sessions[addr]; ok {
|
||||
lastSession = s
|
||||
lastAddr = addr
|
||||
}
|
||||
l.sessionLock.Unlock()
|
||||
}
|
||||
|
||||
if !ok { // new session
|
||||
if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue
|
||||
var conv uint32
|
||||
convValid := false
|
||||
if l.fecDecoder != nil {
|
||||
isfec := binary.LittleEndian.Uint16(data[4:])
|
||||
if isfec == typeData {
|
||||
conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:])
|
||||
convValid = true
|
||||
}
|
||||
} else {
|
||||
conv = binary.LittleEndian.Uint32(data)
|
||||
convValid = true
|
||||
}
|
||||
|
||||
if convValid { // creates a new session only if the 'conv' field in kcp is accessible
|
||||
s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, from, l.block)
|
||||
s.kcpInput(data)
|
||||
l.sessionLock.Lock()
|
||||
l.sessions[addr] = s
|
||||
l.sessionLock.Unlock()
|
||||
l.chAccepts <- s
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.kcpInput(data)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
|
||||
}
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetReadBuffer sets the socket read buffer for the Listener
|
||||
func (l *Listener) SetReadBuffer(bytes int) error {
|
||||
if nc, ok := l.conn.(setReadBuffer); ok {
|
||||
return nc.SetReadBuffer(bytes)
|
||||
}
|
||||
return errors.New(errInvalidOperation)
|
||||
}
|
||||
|
||||
// SetWriteBuffer sets the socket write buffer for the Listener
|
||||
func (l *Listener) SetWriteBuffer(bytes int) error {
|
||||
if nc, ok := l.conn.(setWriteBuffer); ok {
|
||||
return nc.SetWriteBuffer(bytes)
|
||||
}
|
||||
return errors.New(errInvalidOperation)
|
||||
}
|
||||
|
||||
// SetDSCP sets the 6bit DSCP field of IP header
|
||||
func (l *Listener) SetDSCP(dscp int) error {
|
||||
if nc, ok := l.conn.(net.Conn); ok {
|
||||
if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil {
|
||||
return ipv6.NewConn(nc).SetTrafficClass(dscp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return errors.New(errInvalidOperation)
|
||||
}
|
||||
|
||||
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
return l.AcceptKCP()
|
||||
}
|
||||
|
||||
// AcceptKCP accepts a KCP connection
|
||||
func (l *Listener) AcceptKCP() (*UDPSession, error) {
|
||||
var timeout <-chan time.Time
|
||||
if tdeadline, ok := l.rd.Load().(time.Time); ok && !tdeadline.IsZero() {
|
||||
timeout = time.After(tdeadline.Sub(time.Now()))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
return nil, &errTimeout{}
|
||||
case c := <-l.chAccepts:
|
||||
return c, nil
|
||||
case <-l.die:
|
||||
return nil, errors.New(errBrokenPipe)
|
||||
}
|
||||
}
|
||||
|
||||
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
|
||||
func (l *Listener) SetDeadline(t time.Time) error {
|
||||
l.SetReadDeadline(t)
|
||||
l.SetWriteDeadline(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the Conn SetReadDeadline method.
|
||||
func (l *Listener) SetReadDeadline(t time.Time) error {
|
||||
l.rd.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the Conn SetWriteDeadline method.
|
||||
func (l *Listener) SetWriteDeadline(t time.Time) error {
|
||||
l.wd.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stops listening on the UDP address. Already Accepted connections are not closed.
|
||||
func (l *Listener) Close() error {
|
||||
close(l.die)
|
||||
return l.conn.Close()
|
||||
}
|
||||
|
||||
// closeSession notify the listener that a session has closed
|
||||
func (l *Listener) closeSession(remote net.Addr) (ret bool) {
|
||||
l.sessionLock.Lock()
|
||||
defer l.sessionLock.Unlock()
|
||||
if _, ok := l.sessions[remote.String()]; ok {
|
||||
delete(l.sessions, remote.String())
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it.
|
||||
func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() }
|
||||
|
||||
// Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp",
|
||||
func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) }
|
||||
|
||||
// ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption,
|
||||
// rdataShards, parityShards defines Reed-Solomon Erasure Coding parametes
|
||||
func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) {
|
||||
udpaddr, err := net.ResolveUDPAddr("udp", laddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "net.ResolveUDPAddr")
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", udpaddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "net.ListenUDP")
|
||||
}
|
||||
|
||||
return ServeConn(block, dataShards, parityShards, conn)
|
||||
}
|
||||
|
||||
// ServeConn serves KCP protocol for a single packet connection.
|
||||
func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) {
|
||||
l := new(Listener)
|
||||
l.conn = conn
|
||||
l.sessions = make(map[string]*UDPSession)
|
||||
l.chAccepts = make(chan *UDPSession, acceptBacklog)
|
||||
l.chSessionClosed = make(chan net.Addr)
|
||||
l.die = make(chan struct{})
|
||||
l.dataShards = dataShards
|
||||
l.parityShards = parityShards
|
||||
l.block = block
|
||||
l.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)
|
||||
|
||||
// calculate header size
|
||||
if l.block != nil {
|
||||
l.headerSize += cryptHeaderSize
|
||||
}
|
||||
if l.fecDecoder != nil {
|
||||
l.headerSize += fecHeaderSizePlus2
|
||||
}
|
||||
|
||||
go l.monitor()
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// Dial connects to the remote address "raddr" on the network "udp"
|
||||
func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) }
|
||||
|
||||
// DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption
|
||||
func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) {
|
||||
// network type detection
|
||||
udpaddr, err := net.ResolveUDPAddr("udp", raddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "net.ResolveUDPAddr")
|
||||
}
|
||||
network := "udp4"
|
||||
if udpaddr.IP.To4() == nil {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP(network, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "net.DialUDP")
|
||||
}
|
||||
|
||||
return NewConn(raddr, block, dataShards, parityShards, conn)
|
||||
}
|
||||
|
||||
// NewConn establishes a session and talks KCP protocol over a packet connection.
|
||||
func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
|
||||
udpaddr, err := net.ResolveUDPAddr("udp", raddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "net.ResolveUDPAddr")
|
||||
}
|
||||
|
||||
var convid uint32
|
||||
binary.Read(rand.Reader, binary.LittleEndian, &convid)
|
||||
return newUDPSession(convid, dataShards, parityShards, nil, conn, udpaddr, block), nil
|
||||
}
|
||||
|
||||
// monotonic reference time point
|
||||
var refTime time.Time = time.Now()
|
||||
|
||||
// currentMs returns current elasped monotonic milliseconds since program startup
|
||||
func currentMs() uint32 { return uint32(time.Now().Sub(refTime) / time.Millisecond) }
|
|
@ -0,0 +1,475 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const portEcho = "127.0.0.1:9999"
|
||||
const portSink = "127.0.0.1:19999"
|
||||
const portTinyBufferEcho = "127.0.0.1:29999"
|
||||
const portListerner = "127.0.0.1:9998"
|
||||
|
||||
var key = []byte("testkey")
|
||||
var pass = pbkdf2.Key(key, []byte(portSink), 4096, 32, sha1.New)
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
|
||||
go echoServer()
|
||||
go sinkServer()
|
||||
go tinyBufferEchoServer()
|
||||
println("beginning tests, encryption:salsa20, fec:10/3")
|
||||
}
|
||||
|
||||
func dialEcho() (*UDPSession, error) {
|
||||
//block, _ := NewNoneBlockCrypt(pass)
|
||||
//block, _ := NewSimpleXORBlockCrypt(pass)
|
||||
//block, _ := NewTEABlockCrypt(pass[:16])
|
||||
//block, _ := NewAESBlockCrypt(pass)
|
||||
block, _ := NewSalsa20BlockCrypt(pass)
|
||||
sess, err := DialWithOptions(portEcho, block, 10, 3)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sess.SetStreamMode(true)
|
||||
sess.SetStreamMode(false)
|
||||
sess.SetStreamMode(true)
|
||||
sess.SetWindowSize(1024, 1024)
|
||||
sess.SetReadBuffer(16 * 1024 * 1024)
|
||||
sess.SetWriteBuffer(16 * 1024 * 1024)
|
||||
sess.SetStreamMode(true)
|
||||
sess.SetNoDelay(1, 10, 2, 1)
|
||||
sess.SetMtu(1400)
|
||||
sess.SetMtu(1600)
|
||||
sess.SetMtu(1400)
|
||||
sess.SetACKNoDelay(true)
|
||||
sess.SetACKNoDelay(false)
|
||||
sess.SetDeadline(time.Now().Add(time.Minute))
|
||||
return sess, err
|
||||
}
|
||||
|
||||
func dialSink() (*UDPSession, error) {
|
||||
sess, err := DialWithOptions(portSink, nil, 0, 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sess.SetStreamMode(true)
|
||||
sess.SetWindowSize(1024, 1024)
|
||||
sess.SetReadBuffer(16 * 1024 * 1024)
|
||||
sess.SetWriteBuffer(16 * 1024 * 1024)
|
||||
sess.SetStreamMode(true)
|
||||
sess.SetNoDelay(1, 10, 2, 1)
|
||||
sess.SetMtu(1400)
|
||||
sess.SetACKNoDelay(false)
|
||||
sess.SetDeadline(time.Now().Add(time.Minute))
|
||||
return sess, err
|
||||
}
|
||||
|
||||
func dialTinyBufferEcho() (*UDPSession, error) {
|
||||
//block, _ := NewNoneBlockCrypt(pass)
|
||||
//block, _ := NewSimpleXORBlockCrypt(pass)
|
||||
//block, _ := NewTEABlockCrypt(pass[:16])
|
||||
//block, _ := NewAESBlockCrypt(pass)
|
||||
block, _ := NewSalsa20BlockCrypt(pass)
|
||||
sess, err := DialWithOptions(portTinyBufferEcho, block, 10, 3)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return sess, err
|
||||
}
|
||||
|
||||
//////////////////////////
|
||||
func listenEcho() (net.Listener, error) {
|
||||
//block, _ := NewNoneBlockCrypt(pass)
|
||||
//block, _ := NewSimpleXORBlockCrypt(pass)
|
||||
//block, _ := NewTEABlockCrypt(pass[:16])
|
||||
//block, _ := NewAESBlockCrypt(pass)
|
||||
block, _ := NewSalsa20BlockCrypt(pass)
|
||||
return ListenWithOptions(portEcho, block, 10, 3)
|
||||
}
|
||||
func listenTinyBufferEcho() (net.Listener, error) {
|
||||
//block, _ := NewNoneBlockCrypt(pass)
|
||||
//block, _ := NewSimpleXORBlockCrypt(pass)
|
||||
//block, _ := NewTEABlockCrypt(pass[:16])
|
||||
//block, _ := NewAESBlockCrypt(pass)
|
||||
block, _ := NewSalsa20BlockCrypt(pass)
|
||||
return ListenWithOptions(portTinyBufferEcho, block, 10, 3)
|
||||
}
|
||||
|
||||
func listenSink() (net.Listener, error) {
|
||||
return ListenWithOptions(portSink, nil, 0, 0)
|
||||
}
|
||||
|
||||
func echoServer() {
|
||||
l, err := listenEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
kcplistener := l.(*Listener)
|
||||
kcplistener.SetReadBuffer(4 * 1024 * 1024)
|
||||
kcplistener.SetWriteBuffer(4 * 1024 * 1024)
|
||||
kcplistener.SetDSCP(46)
|
||||
for {
|
||||
s, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// coverage test
|
||||
s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024)
|
||||
s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024)
|
||||
go handleEcho(s.(*UDPSession))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func sinkServer() {
|
||||
l, err := listenSink()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
kcplistener := l.(*Listener)
|
||||
kcplistener.SetReadBuffer(4 * 1024 * 1024)
|
||||
kcplistener.SetWriteBuffer(4 * 1024 * 1024)
|
||||
kcplistener.SetDSCP(46)
|
||||
for {
|
||||
s, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go handleSink(s.(*UDPSession))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func tinyBufferEchoServer() {
|
||||
l, err := listenTinyBufferEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
s, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go handleTinyBufferEcho(s.(*UDPSession))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
func handleEcho(conn *UDPSession) {
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWindowSize(4096, 4096)
|
||||
conn.SetNoDelay(1, 10, 2, 1)
|
||||
conn.SetDSCP(46)
|
||||
conn.SetMtu(1400)
|
||||
conn.SetACKNoDelay(false)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Hour))
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Hour))
|
||||
buf := make([]byte, 65536)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
conn.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func handleSink(conn *UDPSession) {
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWindowSize(4096, 4096)
|
||||
conn.SetNoDelay(1, 10, 2, 1)
|
||||
conn.SetDSCP(46)
|
||||
conn.SetMtu(1400)
|
||||
conn.SetACKNoDelay(false)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Hour))
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Hour))
|
||||
buf := make([]byte, 65536)
|
||||
for {
|
||||
_, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleTinyBufferEcho(conn *UDPSession) {
|
||||
conn.SetStreamMode(true)
|
||||
buf := make([]byte, 2)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
conn.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
cli, err := dialEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
buf := make([]byte, 10)
|
||||
|
||||
//timeout
|
||||
cli.SetDeadline(time.Now().Add(time.Second))
|
||||
<-time.After(2 * time.Second)
|
||||
n, err := cli.Read(buf)
|
||||
if n != 0 || err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
cli.Close()
|
||||
}
|
||||
|
||||
func TestSendRecv(t *testing.T) {
|
||||
cli, err := dialEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cli.SetWriteDelay(true)
|
||||
cli.SetDUP(1)
|
||||
const N = 100
|
||||
buf := make([]byte, 10)
|
||||
for i := 0; i < N; i++ {
|
||||
msg := fmt.Sprintf("hello%v", i)
|
||||
cli.Write([]byte(msg))
|
||||
if n, err := cli.Read(buf); err == nil {
|
||||
if string(buf[:n]) != msg {
|
||||
t.Fail()
|
||||
}
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
cli.Close()
|
||||
}
|
||||
|
||||
func TestTinyBufferReceiver(t *testing.T) {
|
||||
cli, err := dialTinyBufferEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
const N = 100
|
||||
snd := byte(0)
|
||||
fillBuffer := func(buf []byte) {
|
||||
for i := 0; i < len(buf); i++ {
|
||||
buf[i] = snd
|
||||
snd++
|
||||
}
|
||||
}
|
||||
|
||||
rcv := byte(0)
|
||||
check := func(buf []byte) bool {
|
||||
for i := 0; i < len(buf); i++ {
|
||||
if buf[i] != rcv {
|
||||
return false
|
||||
}
|
||||
rcv++
|
||||
}
|
||||
return true
|
||||
}
|
||||
sndbuf := make([]byte, 7)
|
||||
rcvbuf := make([]byte, 7)
|
||||
for i := 0; i < N; i++ {
|
||||
fillBuffer(sndbuf)
|
||||
cli.Write(sndbuf)
|
||||
if n, err := io.ReadFull(cli, rcvbuf); err == nil {
|
||||
if !check(rcvbuf[:n]) {
|
||||
t.Fail()
|
||||
}
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
cli.Close()
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
cli, err := dialEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
buf := make([]byte, 10)
|
||||
|
||||
cli.Close()
|
||||
if cli.Close() == nil {
|
||||
t.Fail()
|
||||
}
|
||||
n, err := cli.Write(buf)
|
||||
if n != 0 || err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
n, err = cli.Read(buf)
|
||||
if n != 0 || err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
cli.Close()
|
||||
}
|
||||
|
||||
func TestParallel1024CLIENT_64BMSG_64CNT(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1024)
|
||||
for i := 0; i < 1024; i++ {
|
||||
go parallel_client(&wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func parallel_client(wg *sync.WaitGroup) (err error) {
|
||||
cli, err := dialEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = echo_tester(cli, 64, 64)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
func BenchmarkEchoSpeed4K(b *testing.B) {
|
||||
speedclient(b, 4096)
|
||||
}
|
||||
|
||||
func BenchmarkEchoSpeed64K(b *testing.B) {
|
||||
speedclient(b, 65536)
|
||||
}
|
||||
|
||||
func BenchmarkEchoSpeed512K(b *testing.B) {
|
||||
speedclient(b, 524288)
|
||||
}
|
||||
|
||||
func BenchmarkEchoSpeed1M(b *testing.B) {
|
||||
speedclient(b, 1048576)
|
||||
}
|
||||
|
||||
func speedclient(b *testing.B, nbytes int) {
|
||||
b.ReportAllocs()
|
||||
cli, err := dialEcho()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := echo_tester(cli, nbytes, b.N); err != nil {
|
||||
b.Fail()
|
||||
}
|
||||
b.SetBytes(int64(nbytes))
|
||||
}
|
||||
|
||||
func BenchmarkSinkSpeed4K(b *testing.B) {
|
||||
sinkclient(b, 4096)
|
||||
}
|
||||
|
||||
func BenchmarkSinkSpeed64K(b *testing.B) {
|
||||
sinkclient(b, 65536)
|
||||
}
|
||||
|
||||
func BenchmarkSinkSpeed256K(b *testing.B) {
|
||||
sinkclient(b, 524288)
|
||||
}
|
||||
|
||||
func BenchmarkSinkSpeed1M(b *testing.B) {
|
||||
sinkclient(b, 1048576)
|
||||
}
|
||||
|
||||
func sinkclient(b *testing.B, nbytes int) {
|
||||
b.ReportAllocs()
|
||||
cli, err := dialSink()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sink_tester(cli, nbytes, b.N)
|
||||
b.SetBytes(int64(nbytes))
|
||||
}
|
||||
|
||||
func echo_tester(cli net.Conn, msglen, msgcount int) error {
|
||||
buf := make([]byte, msglen)
|
||||
for i := 0; i < msgcount; i++ {
|
||||
// send packet
|
||||
if _, err := cli.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// receive packet
|
||||
nrecv := 0
|
||||
for {
|
||||
n, err := cli.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
nrecv += n
|
||||
if nrecv == msglen {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sink_tester(cli *UDPSession, msglen, msgcount int) error {
|
||||
// sender
|
||||
buf := make([]byte, msglen)
|
||||
for i := 0; i < msgcount; i++ {
|
||||
if _, err := cli.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSNMP(t *testing.T) {
|
||||
t.Log(DefaultSnmp.Copy())
|
||||
t.Log(DefaultSnmp.Header())
|
||||
t.Log(DefaultSnmp.ToSlice())
|
||||
DefaultSnmp.Reset()
|
||||
t.Log(DefaultSnmp.ToSlice())
|
||||
}
|
||||
|
||||
func TestListenerClose(t *testing.T) {
|
||||
l, err := ListenWithOptions(portListerner, nil, 10, 3)
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
}
|
||||
l.SetReadDeadline(time.Now().Add(time.Second))
|
||||
l.SetWriteDeadline(time.Now().Add(time.Second))
|
||||
l.SetDeadline(time.Now().Add(time.Second))
|
||||
time.Sleep(2 * time.Second)
|
||||
if _, err := l.Accept(); err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
l.Close()
|
||||
fakeaddr, _ := net.ResolveUDPAddr("udp6", "127.0.0.1:1111")
|
||||
if l.closeSession(fakeaddr) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,164 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Snmp defines network statistics indicator
|
||||
type Snmp struct {
|
||||
BytesSent uint64 // bytes sent from upper level
|
||||
BytesReceived uint64 // bytes received to upper level
|
||||
MaxConn uint64 // max number of connections ever reached
|
||||
ActiveOpens uint64 // accumulated active open connections
|
||||
PassiveOpens uint64 // accumulated passive open connections
|
||||
CurrEstab uint64 // current number of established connections
|
||||
InErrs uint64 // UDP read errors reported from net.PacketConn
|
||||
InCsumErrors uint64 // checksum errors from CRC32
|
||||
KCPInErrors uint64 // packet iput errors reported from KCP
|
||||
InPkts uint64 // incoming packets count
|
||||
OutPkts uint64 // outgoing packets count
|
||||
InSegs uint64 // incoming KCP segments
|
||||
OutSegs uint64 // outgoing KCP segments
|
||||
InBytes uint64 // UDP bytes received
|
||||
OutBytes uint64 // UDP bytes sent
|
||||
RetransSegs uint64 // accmulated retransmited segments
|
||||
FastRetransSegs uint64 // accmulated fast retransmitted segments
|
||||
EarlyRetransSegs uint64 // accmulated early retransmitted segments
|
||||
LostSegs uint64 // number of segs infered as lost
|
||||
RepeatSegs uint64 // number of segs duplicated
|
||||
FECRecovered uint64 // correct packets recovered from FEC
|
||||
FECErrs uint64 // incorrect packets recovered from FEC
|
||||
FECParityShards uint64 // FEC segments received
|
||||
FECShortShards uint64 // number of data shards that's not enough for recovery
|
||||
}
|
||||
|
||||
func newSnmp() *Snmp {
|
||||
return new(Snmp)
|
||||
}
|
||||
|
||||
// Header returns all field names
|
||||
func (s *Snmp) Header() []string {
|
||||
return []string{
|
||||
"BytesSent",
|
||||
"BytesReceived",
|
||||
"MaxConn",
|
||||
"ActiveOpens",
|
||||
"PassiveOpens",
|
||||
"CurrEstab",
|
||||
"InErrs",
|
||||
"InCsumErrors",
|
||||
"KCPInErrors",
|
||||
"InPkts",
|
||||
"OutPkts",
|
||||
"InSegs",
|
||||
"OutSegs",
|
||||
"InBytes",
|
||||
"OutBytes",
|
||||
"RetransSegs",
|
||||
"FastRetransSegs",
|
||||
"EarlyRetransSegs",
|
||||
"LostSegs",
|
||||
"RepeatSegs",
|
||||
"FECParityShards",
|
||||
"FECErrs",
|
||||
"FECRecovered",
|
||||
"FECShortShards",
|
||||
}
|
||||
}
|
||||
|
||||
// ToSlice returns current snmp info as slice
|
||||
func (s *Snmp) ToSlice() []string {
|
||||
snmp := s.Copy()
|
||||
return []string{
|
||||
fmt.Sprint(snmp.BytesSent),
|
||||
fmt.Sprint(snmp.BytesReceived),
|
||||
fmt.Sprint(snmp.MaxConn),
|
||||
fmt.Sprint(snmp.ActiveOpens),
|
||||
fmt.Sprint(snmp.PassiveOpens),
|
||||
fmt.Sprint(snmp.CurrEstab),
|
||||
fmt.Sprint(snmp.InErrs),
|
||||
fmt.Sprint(snmp.InCsumErrors),
|
||||
fmt.Sprint(snmp.KCPInErrors),
|
||||
fmt.Sprint(snmp.InPkts),
|
||||
fmt.Sprint(snmp.OutPkts),
|
||||
fmt.Sprint(snmp.InSegs),
|
||||
fmt.Sprint(snmp.OutSegs),
|
||||
fmt.Sprint(snmp.InBytes),
|
||||
fmt.Sprint(snmp.OutBytes),
|
||||
fmt.Sprint(snmp.RetransSegs),
|
||||
fmt.Sprint(snmp.FastRetransSegs),
|
||||
fmt.Sprint(snmp.EarlyRetransSegs),
|
||||
fmt.Sprint(snmp.LostSegs),
|
||||
fmt.Sprint(snmp.RepeatSegs),
|
||||
fmt.Sprint(snmp.FECParityShards),
|
||||
fmt.Sprint(snmp.FECErrs),
|
||||
fmt.Sprint(snmp.FECRecovered),
|
||||
fmt.Sprint(snmp.FECShortShards),
|
||||
}
|
||||
}
|
||||
|
||||
// Copy make a copy of current snmp snapshot
|
||||
func (s *Snmp) Copy() *Snmp {
|
||||
d := newSnmp()
|
||||
d.BytesSent = atomic.LoadUint64(&s.BytesSent)
|
||||
d.BytesReceived = atomic.LoadUint64(&s.BytesReceived)
|
||||
d.MaxConn = atomic.LoadUint64(&s.MaxConn)
|
||||
d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens)
|
||||
d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens)
|
||||
d.CurrEstab = atomic.LoadUint64(&s.CurrEstab)
|
||||
d.InErrs = atomic.LoadUint64(&s.InErrs)
|
||||
d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors)
|
||||
d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors)
|
||||
d.InPkts = atomic.LoadUint64(&s.InPkts)
|
||||
d.OutPkts = atomic.LoadUint64(&s.OutPkts)
|
||||
d.InSegs = atomic.LoadUint64(&s.InSegs)
|
||||
d.OutSegs = atomic.LoadUint64(&s.OutSegs)
|
||||
d.InBytes = atomic.LoadUint64(&s.InBytes)
|
||||
d.OutBytes = atomic.LoadUint64(&s.OutBytes)
|
||||
d.RetransSegs = atomic.LoadUint64(&s.RetransSegs)
|
||||
d.FastRetransSegs = atomic.LoadUint64(&s.FastRetransSegs)
|
||||
d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs)
|
||||
d.LostSegs = atomic.LoadUint64(&s.LostSegs)
|
||||
d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs)
|
||||
d.FECParityShards = atomic.LoadUint64(&s.FECParityShards)
|
||||
d.FECErrs = atomic.LoadUint64(&s.FECErrs)
|
||||
d.FECRecovered = atomic.LoadUint64(&s.FECRecovered)
|
||||
d.FECShortShards = atomic.LoadUint64(&s.FECShortShards)
|
||||
return d
|
||||
}
|
||||
|
||||
// Reset values to zero
|
||||
func (s *Snmp) Reset() {
|
||||
atomic.StoreUint64(&s.BytesSent, 0)
|
||||
atomic.StoreUint64(&s.BytesReceived, 0)
|
||||
atomic.StoreUint64(&s.MaxConn, 0)
|
||||
atomic.StoreUint64(&s.ActiveOpens, 0)
|
||||
atomic.StoreUint64(&s.PassiveOpens, 0)
|
||||
atomic.StoreUint64(&s.CurrEstab, 0)
|
||||
atomic.StoreUint64(&s.InErrs, 0)
|
||||
atomic.StoreUint64(&s.InCsumErrors, 0)
|
||||
atomic.StoreUint64(&s.KCPInErrors, 0)
|
||||
atomic.StoreUint64(&s.InPkts, 0)
|
||||
atomic.StoreUint64(&s.OutPkts, 0)
|
||||
atomic.StoreUint64(&s.InSegs, 0)
|
||||
atomic.StoreUint64(&s.OutSegs, 0)
|
||||
atomic.StoreUint64(&s.InBytes, 0)
|
||||
atomic.StoreUint64(&s.OutBytes, 0)
|
||||
atomic.StoreUint64(&s.RetransSegs, 0)
|
||||
atomic.StoreUint64(&s.FastRetransSegs, 0)
|
||||
atomic.StoreUint64(&s.EarlyRetransSegs, 0)
|
||||
atomic.StoreUint64(&s.LostSegs, 0)
|
||||
atomic.StoreUint64(&s.RepeatSegs, 0)
|
||||
atomic.StoreUint64(&s.FECParityShards, 0)
|
||||
atomic.StoreUint64(&s.FECErrs, 0)
|
||||
atomic.StoreUint64(&s.FECRecovered, 0)
|
||||
atomic.StoreUint64(&s.FECShortShards, 0)
|
||||
}
|
||||
|
||||
// DefaultSnmp is the global KCP connection statistics collector
|
||||
var DefaultSnmp *Snmp
|
||||
|
||||
func init() {
|
||||
DefaultSnmp = newSnmp()
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
package kcp
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var updater updateHeap
|
||||
|
||||
func init() {
|
||||
updater.init()
|
||||
go updater.updateTask()
|
||||
}
|
||||
|
||||
// entry contains a session update info
|
||||
type entry struct {
|
||||
ts time.Time
|
||||
s *UDPSession
|
||||
}
|
||||
|
||||
// a global heap managed kcp.flush() caller
|
||||
type updateHeap struct {
|
||||
entries []entry
|
||||
mu sync.Mutex
|
||||
chWakeUp chan struct{}
|
||||
}
|
||||
|
||||
func (h *updateHeap) Len() int { return len(h.entries) }
|
||||
func (h *updateHeap) Less(i, j int) bool { return h.entries[i].ts.Before(h.entries[j].ts) }
|
||||
func (h *updateHeap) Swap(i, j int) {
|
||||
h.entries[i], h.entries[j] = h.entries[j], h.entries[i]
|
||||
h.entries[i].s.updaterIdx = i
|
||||
h.entries[j].s.updaterIdx = j
|
||||
}
|
||||
|
||||
func (h *updateHeap) Push(x interface{}) {
|
||||
h.entries = append(h.entries, x.(entry))
|
||||
n := len(h.entries)
|
||||
h.entries[n-1].s.updaterIdx = n - 1
|
||||
}
|
||||
|
||||
func (h *updateHeap) Pop() interface{} {
|
||||
n := len(h.entries)
|
||||
x := h.entries[n-1]
|
||||
h.entries[n-1].s.updaterIdx = -1
|
||||
h.entries[n-1] = entry{} // manual set nil for GC
|
||||
h.entries = h.entries[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
func (h *updateHeap) init() {
|
||||
h.chWakeUp = make(chan struct{}, 1)
|
||||
}
|
||||
|
||||
func (h *updateHeap) addSession(s *UDPSession) {
|
||||
h.mu.Lock()
|
||||
heap.Push(h, entry{time.Now(), s})
|
||||
h.mu.Unlock()
|
||||
h.wakeup()
|
||||
}
|
||||
|
||||
func (h *updateHeap) removeSession(s *UDPSession) {
|
||||
h.mu.Lock()
|
||||
if s.updaterIdx != -1 {
|
||||
heap.Remove(h, s.updaterIdx)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
func (h *updateHeap) wakeup() {
|
||||
select {
|
||||
case h.chWakeUp <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (h *updateHeap) updateTask() {
|
||||
var timer <-chan time.Time
|
||||
for {
|
||||
select {
|
||||
case <-timer:
|
||||
case <-h.chWakeUp:
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
hlen := h.Len()
|
||||
for i := 0; i < hlen; i++ {
|
||||
entry := &h.entries[0]
|
||||
if time.Now().After(entry.ts) {
|
||||
interval := entry.s.update()
|
||||
entry.ts = time.Now().Add(interval)
|
||||
heap.Fix(h, 0)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hlen > 0 {
|
||||
timer = time.After(h.entries[0].ts.Sub(time.Now()))
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package lib
|
||||
package lg
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
@ -9,10 +9,10 @@ import (
|
|||
|
||||
var Log *log.Logger
|
||||
|
||||
func InitLogFile(f string, isStdout bool) {
|
||||
func InitLogFile(f string, isStdout bool, logPath string) {
|
||||
var prefix string
|
||||
if !isStdout {
|
||||
logFile, err := os.OpenFile(filepath.Join(GetLogPath(), f+"_log.txt"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766)
|
||||
logFile, err := os.OpenFile(filepath.Join(logPath, f+"_log.txt"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766)
|
||||
if err != nil {
|
||||
log.Fatalln("open file error !", err)
|
||||
}
|
43
lib/pool.go
43
lib/pool.go
|
@ -1,43 +0,0 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const poolSize = 64 * 1024
|
||||
const poolSizeSmall = 100
|
||||
const poolSizeUdp = 1472
|
||||
const poolSizeCopy = 32 * 1024
|
||||
|
||||
var BufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolSize)
|
||||
},
|
||||
}
|
||||
|
||||
var BufPoolUdp = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolSizeUdp)
|
||||
},
|
||||
}
|
||||
var BufPoolMax = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolSize)
|
||||
},
|
||||
}
|
||||
var BufPoolSmall = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolSizeSmall)
|
||||
},
|
||||
}
|
||||
var BufPoolCopy = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolSizeCopy)
|
||||
},
|
||||
}
|
||||
|
||||
func PutBufPoolCopy(buf []byte) {
|
||||
if cap(buf) == poolSizeCopy {
|
||||
BufPoolCopy.Put(buf[:poolSizeCopy])
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const PoolSize = 64 * 1024
|
||||
const PoolSizeSmall = 100
|
||||
const PoolSizeUdp = 1472
|
||||
const PoolSizeCopy = 32 * 1024
|
||||
|
||||
var BufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, PoolSize)
|
||||
},
|
||||
}
|
||||
|
||||
var BufPoolUdp = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, PoolSizeUdp)
|
||||
},
|
||||
}
|
||||
var BufPoolMax = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, PoolSize)
|
||||
},
|
||||
}
|
||||
var BufPoolSmall = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, PoolSizeSmall)
|
||||
},
|
||||
}
|
||||
var BufPoolCopy = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, PoolSizeCopy)
|
||||
},
|
||||
}
|
||||
|
||||
func PutBufPoolCopy(buf []byte) {
|
||||
if cap(buf) == PoolSizeCopy {
|
||||
BufPoolCopy.Put(buf[:PoolSizeCopy])
|
||||
}
|
||||
}
|
||||
|
||||
func PutBufPoolUdp(buf []byte) {
|
||||
if cap(buf) == PoolSizeUdp {
|
||||
BufPoolUdp.Put(buf[:PoolSizeUdp])
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package lib
|
||||
package rate
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
|
@ -0,0 +1,237 @@
|
|||
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package snappy
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrCorrupt reports that the input is invalid.
|
||||
ErrCorrupt = errors.New("snappy: corrupt input")
|
||||
// ErrTooLarge reports that the uncompressed length is too large.
|
||||
ErrTooLarge = errors.New("snappy: decoded block is too large")
|
||||
// ErrUnsupported reports that the input isn't supported.
|
||||
ErrUnsupported = errors.New("snappy: unsupported input")
|
||||
|
||||
errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
|
||||
)
|
||||
|
||||
// DecodedLen returns the length of the decoded block.
|
||||
func DecodedLen(src []byte) (int, error) {
|
||||
v, _, err := decodedLen(src)
|
||||
return v, err
|
||||
}
|
||||
|
||||
// decodedLen returns the length of the decoded block and the number of bytes
|
||||
// that the length header occupied.
|
||||
func decodedLen(src []byte) (blockLen, headerLen int, err error) {
|
||||
v, n := binary.Uvarint(src)
|
||||
if n <= 0 || v > 0xffffffff {
|
||||
return 0, 0, ErrCorrupt
|
||||
}
|
||||
|
||||
const wordSize = 32 << (^uint(0) >> 32 & 1)
|
||||
if wordSize == 32 && v > 0x7fffffff {
|
||||
return 0, 0, ErrTooLarge
|
||||
}
|
||||
return int(v), n, nil
|
||||
}
|
||||
|
||||
const (
|
||||
decodeErrCodeCorrupt = 1
|
||||
decodeErrCodeUnsupportedLiteralLength = 2
|
||||
)
|
||||
|
||||
// Decode returns the decoded form of src. The returned slice may be a sub-
|
||||
// slice of dst if dst was large enough to hold the entire decoded block.
|
||||
// Otherwise, a newly allocated slice will be returned.
|
||||
//
|
||||
// The dst and src must not overlap. It is valid to pass a nil dst.
|
||||
func Decode(dst, src []byte) ([]byte, error) {
|
||||
dLen, s, err := decodedLen(src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dLen <= len(dst) {
|
||||
dst = dst[:dLen]
|
||||
} else {
|
||||
dst = make([]byte, dLen)
|
||||
}
|
||||
switch decode(dst, src[s:]) {
|
||||
case 0:
|
||||
return dst, nil
|
||||
case decodeErrCodeUnsupportedLiteralLength:
|
||||
return nil, errUnsupportedLiteralLength
|
||||
}
|
||||
return nil, ErrCorrupt
|
||||
}
|
||||
|
||||
// NewReader returns a new Reader that decompresses from r, using the framing
|
||||
// format described at
|
||||
// https://github.com/google/snappy/blob/master/framing_format.txt
|
||||
func NewReader(r io.Reader) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
decoded: make([]byte, maxBlockSize),
|
||||
buf: make([]byte, maxEncodedLenOfMaxBlockSize+checksumSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Reader is an io.Reader that can read Snappy-compressed bytes.
|
||||
type Reader struct {
|
||||
r io.Reader
|
||||
err error
|
||||
decoded []byte
|
||||
buf []byte
|
||||
// decoded[i:j] contains decoded bytes that have not yet been passed on.
|
||||
i, j int
|
||||
readHeader bool
|
||||
}
|
||||
|
||||
// Reset discards any buffered data, resets all state, and switches the Snappy
|
||||
// reader to read from r. This permits reusing a Reader rather than allocating
|
||||
// a new one.
|
||||
func (r *Reader) Reset(reader io.Reader) {
|
||||
r.r = reader
|
||||
r.err = nil
|
||||
r.i = 0
|
||||
r.j = 0
|
||||
r.readHeader = false
|
||||
}
|
||||
|
||||
func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
|
||||
if _, r.err = io.ReadFull(r.r, p); r.err != nil {
|
||||
if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
|
||||
r.err = ErrCorrupt
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Read satisfies the io.Reader interface.
|
||||
func (r *Reader) Read(p []byte) (int, error) {
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
for {
|
||||
if r.i < r.j {
|
||||
n := copy(p, r.decoded[r.i:r.j])
|
||||
r.i += n
|
||||
return n, nil
|
||||
}
|
||||
if !r.readFull(r.buf[:4], true) {
|
||||
return 0, r.err
|
||||
}
|
||||
chunkType := r.buf[0]
|
||||
if !r.readHeader {
|
||||
if chunkType != chunkTypeStreamIdentifier {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
r.readHeader = true
|
||||
}
|
||||
chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
|
||||
if chunkLen > len(r.buf) {
|
||||
r.err = ErrUnsupported
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
// The chunk types are specified at
|
||||
// https://github.com/google/snappy/blob/master/framing_format.txt
|
||||
switch chunkType {
|
||||
case chunkTypeCompressedData:
|
||||
// Section 4.2. Compressed data (chunk type 0x00).
|
||||
if chunkLen < checksumSize {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
buf := r.buf[:chunkLen]
|
||||
if !r.readFull(buf, false) {
|
||||
return 0, r.err
|
||||
}
|
||||
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
|
||||
buf = buf[checksumSize:]
|
||||
|
||||
n, err := DecodedLen(buf)
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return 0, r.err
|
||||
}
|
||||
if n > len(r.decoded) {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
if _, err := Decode(r.decoded, buf); err != nil {
|
||||
r.err = err
|
||||
return 0, r.err
|
||||
}
|
||||
if crc(r.decoded[:n]) != checksum {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
r.i, r.j = 0, n
|
||||
continue
|
||||
|
||||
case chunkTypeUncompressedData:
|
||||
// Section 4.3. Uncompressed data (chunk type 0x01).
|
||||
if chunkLen < checksumSize {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
buf := r.buf[:checksumSize]
|
||||
if !r.readFull(buf, false) {
|
||||
return 0, r.err
|
||||
}
|
||||
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
|
||||
// Read directly into r.decoded instead of via r.buf.
|
||||
n := chunkLen - checksumSize
|
||||
if n > len(r.decoded) {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
if !r.readFull(r.decoded[:n], false) {
|
||||
return 0, r.err
|
||||
}
|
||||
if crc(r.decoded[:n]) != checksum {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
r.i, r.j = 0, n
|
||||
continue
|
||||
|
||||
case chunkTypeStreamIdentifier:
|
||||
// Section 4.1. Stream identifier (chunk type 0xff).
|
||||
if chunkLen != len(magicBody) {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
if !r.readFull(r.buf[:len(magicBody)], false) {
|
||||
return 0, r.err
|
||||
}
|
||||
for i := 0; i < len(magicBody); i++ {
|
||||
if r.buf[i] != magicBody[i] {
|
||||
r.err = ErrCorrupt
|
||||
return 0, r.err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if chunkType <= 0x7f {
|
||||
// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
|
||||
r.err = ErrUnsupported
|
||||
return 0, r.err
|
||||
}
|
||||
// Section 4.4 Padding (chunk type 0xfe).
|
||||
// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
|
||||
if !r.readFull(r.buf[:chunkLen], false) {
|
||||
return 0, r.err
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
package snappy
|
||||
|
||||
// decode has the same semantics as in decode_other.go.
|
||||
//
|
||||
//go:noescape
|
||||
func decode(dst, src []byte) int
|
|
@ -0,0 +1,490 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// The asm code generally follows the pure Go code in decode_other.go, except
|
||||
// where marked with a "!!!".
|
||||
|
||||
// func decode(dst, src []byte) int
|
||||
//
|
||||
// All local variables fit into registers. The non-zero stack size is only to
|
||||
// spill registers and push args when issuing a CALL. The register allocation:
|
||||
// - AX scratch
|
||||
// - BX scratch
|
||||
// - CX length or x
|
||||
// - DX offset
|
||||
// - SI &src[s]
|
||||
// - DI &dst[d]
|
||||
// + R8 dst_base
|
||||
// + R9 dst_len
|
||||
// + R10 dst_base + dst_len
|
||||
// + R11 src_base
|
||||
// + R12 src_len
|
||||
// + R13 src_base + src_len
|
||||
// - R14 used by doCopy
|
||||
// - R15 used by doCopy
|
||||
//
|
||||
// The registers R8-R13 (marked with a "+") are set at the start of the
|
||||
// function, and after a CALL returns, and are not otherwise modified.
|
||||
//
|
||||
// The d variable is implicitly DI - R8, and len(dst)-d is R10 - DI.
|
||||
// The s variable is implicitly SI - R11, and len(src)-s is R13 - SI.
|
||||
TEXT ·decode(SB), NOSPLIT, $48-56
|
||||
// Initialize SI, DI and R8-R13.
|
||||
MOVQ dst_base+0(FP), R8
|
||||
MOVQ dst_len+8(FP), R9
|
||||
MOVQ R8, DI
|
||||
MOVQ R8, R10
|
||||
ADDQ R9, R10
|
||||
MOVQ src_base+24(FP), R11
|
||||
MOVQ src_len+32(FP), R12
|
||||
MOVQ R11, SI
|
||||
MOVQ R11, R13
|
||||
ADDQ R12, R13
|
||||
|
||||
loop:
|
||||
// for s < len(src)
|
||||
CMPQ SI, R13
|
||||
JEQ end
|
||||
|
||||
// CX = uint32(src[s])
|
||||
//
|
||||
// switch src[s] & 0x03
|
||||
MOVBLZX (SI), CX
|
||||
MOVL CX, BX
|
||||
ANDL $3, BX
|
||||
CMPL BX, $1
|
||||
JAE tagCopy
|
||||
|
||||
// ----------------------------------------
|
||||
// The code below handles literal tags.
|
||||
|
||||
// case tagLiteral:
|
||||
// x := uint32(src[s] >> 2)
|
||||
// switch
|
||||
SHRL $2, CX
|
||||
CMPL CX, $60
|
||||
JAE tagLit60Plus
|
||||
|
||||
// case x < 60:
|
||||
// s++
|
||||
INCQ SI
|
||||
|
||||
doLit:
|
||||
// This is the end of the inner "switch", when we have a literal tag.
|
||||
//
|
||||
// We assume that CX == x and x fits in a uint32, where x is the variable
|
||||
// used in the pure Go decode_other.go code.
|
||||
|
||||
// length = int(x) + 1
|
||||
//
|
||||
// Unlike the pure Go code, we don't need to check if length <= 0 because
|
||||
// CX can hold 64 bits, so the increment cannot overflow.
|
||||
INCQ CX
|
||||
|
||||
// Prepare to check if copying length bytes will run past the end of dst or
|
||||
// src.
|
||||
//
|
||||
// AX = len(dst) - d
|
||||
// BX = len(src) - s
|
||||
MOVQ R10, AX
|
||||
SUBQ DI, AX
|
||||
MOVQ R13, BX
|
||||
SUBQ SI, BX
|
||||
|
||||
// !!! Try a faster technique for short (16 or fewer bytes) copies.
|
||||
//
|
||||
// if length > 16 || len(dst)-d < 16 || len(src)-s < 16 {
|
||||
// goto callMemmove // Fall back on calling runtime·memmove.
|
||||
// }
|
||||
//
|
||||
// The C++ snappy code calls this TryFastAppend. It also checks len(src)-s
|
||||
// against 21 instead of 16, because it cannot assume that all of its input
|
||||
// is contiguous in memory and so it needs to leave enough source bytes to
|
||||
// read the next tag without refilling buffers, but Go's Decode assumes
|
||||
// contiguousness (the src argument is a []byte).
|
||||
CMPQ CX, $16
|
||||
JGT callMemmove
|
||||
CMPQ AX, $16
|
||||
JLT callMemmove
|
||||
CMPQ BX, $16
|
||||
JLT callMemmove
|
||||
|
||||
// !!! Implement the copy from src to dst as a 16-byte load and store.
|
||||
// (Decode's documentation says that dst and src must not overlap.)
|
||||
//
|
||||
// This always copies 16 bytes, instead of only length bytes, but that's
|
||||
// OK. If the input is a valid Snappy encoding then subsequent iterations
|
||||
// will fix up the overrun. Otherwise, Decode returns a nil []byte (and a
|
||||
// non-nil error), so the overrun will be ignored.
|
||||
//
|
||||
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
|
||||
// 16-byte loads and stores. This technique probably wouldn't be as
|
||||
// effective on architectures that are fussier about alignment.
|
||||
MOVOU 0(SI), X0
|
||||
MOVOU X0, 0(DI)
|
||||
|
||||
// d += length
|
||||
// s += length
|
||||
ADDQ CX, DI
|
||||
ADDQ CX, SI
|
||||
JMP loop
|
||||
|
||||
callMemmove:
|
||||
// if length > len(dst)-d || length > len(src)-s { etc }
|
||||
CMPQ CX, AX
|
||||
JGT errCorrupt
|
||||
CMPQ CX, BX
|
||||
JGT errCorrupt
|
||||
|
||||
// copy(dst[d:], src[s:s+length])
|
||||
//
|
||||
// This means calling runtime·memmove(&dst[d], &src[s], length), so we push
|
||||
// DI, SI and CX as arguments. Coincidentally, we also need to spill those
|
||||
// three registers to the stack, to save local variables across the CALL.
|
||||
MOVQ DI, 0(SP)
|
||||
MOVQ SI, 8(SP)
|
||||
MOVQ CX, 16(SP)
|
||||
MOVQ DI, 24(SP)
|
||||
MOVQ SI, 32(SP)
|
||||
MOVQ CX, 40(SP)
|
||||
CALL runtime·memmove(SB)
|
||||
|
||||
// Restore local variables: unspill registers from the stack and
|
||||
// re-calculate R8-R13.
|
||||
MOVQ 24(SP), DI
|
||||
MOVQ 32(SP), SI
|
||||
MOVQ 40(SP), CX
|
||||
MOVQ dst_base+0(FP), R8
|
||||
MOVQ dst_len+8(FP), R9
|
||||
MOVQ R8, R10
|
||||
ADDQ R9, R10
|
||||
MOVQ src_base+24(FP), R11
|
||||
MOVQ src_len+32(FP), R12
|
||||
MOVQ R11, R13
|
||||
ADDQ R12, R13
|
||||
|
||||
// d += length
|
||||
// s += length
|
||||
ADDQ CX, DI
|
||||
ADDQ CX, SI
|
||||
JMP loop
|
||||
|
||||
tagLit60Plus:
|
||||
// !!! This fragment does the
|
||||
//
|
||||
// s += x - 58; if uint(s) > uint(len(src)) { etc }
|
||||
//
|
||||
// checks. In the asm version, we code it once instead of once per switch case.
|
||||
ADDQ CX, SI
|
||||
SUBQ $58, SI
|
||||
MOVQ SI, BX
|
||||
SUBQ R11, BX
|
||||
CMPQ BX, R12
|
||||
JA errCorrupt
|
||||
|
||||
// case x == 60:
|
||||
CMPL CX, $61
|
||||
JEQ tagLit61
|
||||
JA tagLit62Plus
|
||||
|
||||
// x = uint32(src[s-1])
|
||||
MOVBLZX -1(SI), CX
|
||||
JMP doLit
|
||||
|
||||
tagLit61:
|
||||
// case x == 61:
|
||||
// x = uint32(src[s-2]) | uint32(src[s-1])<<8
|
||||
MOVWLZX -2(SI), CX
|
||||
JMP doLit
|
||||
|
||||
tagLit62Plus:
|
||||
CMPL CX, $62
|
||||
JA tagLit63
|
||||
|
||||
// case x == 62:
|
||||
// x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
|
||||
MOVWLZX -3(SI), CX
|
||||
MOVBLZX -1(SI), BX
|
||||
SHLL $16, BX
|
||||
ORL BX, CX
|
||||
JMP doLit
|
||||
|
||||
tagLit63:
|
||||
// case x == 63:
|
||||
// x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
|
||||
MOVL -4(SI), CX
|
||||
JMP doLit
|
||||
|
||||
// The code above handles literal tags.
|
||||
// ----------------------------------------
|
||||
// The code below handles copy tags.
|
||||
|
||||
tagCopy4:
|
||||
// case tagCopy4:
|
||||
// s += 5
|
||||
ADDQ $5, SI
|
||||
|
||||
// if uint(s) > uint(len(src)) { etc }
|
||||
MOVQ SI, BX
|
||||
SUBQ R11, BX
|
||||
CMPQ BX, R12
|
||||
JA errCorrupt
|
||||
|
||||
// length = 1 + int(src[s-5])>>2
|
||||
SHRQ $2, CX
|
||||
INCQ CX
|
||||
|
||||
// offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
|
||||
MOVLQZX -4(SI), DX
|
||||
JMP doCopy
|
||||
|
||||
tagCopy2:
|
||||
// case tagCopy2:
|
||||
// s += 3
|
||||
ADDQ $3, SI
|
||||
|
||||
// if uint(s) > uint(len(src)) { etc }
|
||||
MOVQ SI, BX
|
||||
SUBQ R11, BX
|
||||
CMPQ BX, R12
|
||||
JA errCorrupt
|
||||
|
||||
// length = 1 + int(src[s-3])>>2
|
||||
SHRQ $2, CX
|
||||
INCQ CX
|
||||
|
||||
// offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
|
||||
MOVWQZX -2(SI), DX
|
||||
JMP doCopy
|
||||
|
||||
tagCopy:
|
||||
// We have a copy tag. We assume that:
|
||||
// - BX == src[s] & 0x03
|
||||
// - CX == src[s]
|
||||
CMPQ BX, $2
|
||||
JEQ tagCopy2
|
||||
JA tagCopy4
|
||||
|
||||
// case tagCopy1:
|
||||
// s += 2
|
||||
ADDQ $2, SI
|
||||
|
||||
// if uint(s) > uint(len(src)) { etc }
|
||||
MOVQ SI, BX
|
||||
SUBQ R11, BX
|
||||
CMPQ BX, R12
|
||||
JA errCorrupt
|
||||
|
||||
// offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
|
||||
MOVQ CX, DX
|
||||
ANDQ $0xe0, DX
|
||||
SHLQ $3, DX
|
||||
MOVBQZX -1(SI), BX
|
||||
ORQ BX, DX
|
||||
|
||||
// length = 4 + int(src[s-2])>>2&0x7
|
||||
SHRQ $2, CX
|
||||
ANDQ $7, CX
|
||||
ADDQ $4, CX
|
||||
|
||||
doCopy:
|
||||
// This is the end of the outer "switch", when we have a copy tag.
|
||||
//
|
||||
// We assume that:
|
||||
// - CX == length && CX > 0
|
||||
// - DX == offset
|
||||
|
||||
// if offset <= 0 { etc }
|
||||
CMPQ DX, $0
|
||||
JLE errCorrupt
|
||||
|
||||
// if d < offset { etc }
|
||||
MOVQ DI, BX
|
||||
SUBQ R8, BX
|
||||
CMPQ BX, DX
|
||||
JLT errCorrupt
|
||||
|
||||
// if length > len(dst)-d { etc }
|
||||
MOVQ R10, BX
|
||||
SUBQ DI, BX
|
||||
CMPQ CX, BX
|
||||
JGT errCorrupt
|
||||
|
||||
// forwardCopy(dst[d:d+length], dst[d-offset:]); d += length
|
||||
//
|
||||
// Set:
|
||||
// - R14 = len(dst)-d
|
||||
// - R15 = &dst[d-offset]
|
||||
MOVQ R10, R14
|
||||
SUBQ DI, R14
|
||||
MOVQ DI, R15
|
||||
SUBQ DX, R15
|
||||
|
||||
// !!! Try a faster technique for short (16 or fewer bytes) forward copies.
|
||||
//
|
||||
// First, try using two 8-byte load/stores, similar to the doLit technique
|
||||
// above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is
|
||||
// still OK if offset >= 8. Note that this has to be two 8-byte load/stores
|
||||
// and not one 16-byte load/store, and the first store has to be before the
|
||||
// second load, due to the overlap if offset is in the range [8, 16).
|
||||
//
|
||||
// if length > 16 || offset < 8 || len(dst)-d < 16 {
|
||||
// goto slowForwardCopy
|
||||
// }
|
||||
// copy 16 bytes
|
||||
// d += length
|
||||
CMPQ CX, $16
|
||||
JGT slowForwardCopy
|
||||
CMPQ DX, $8
|
||||
JLT slowForwardCopy
|
||||
CMPQ R14, $16
|
||||
JLT slowForwardCopy
|
||||
MOVQ 0(R15), AX
|
||||
MOVQ AX, 0(DI)
|
||||
MOVQ 8(R15), BX
|
||||
MOVQ BX, 8(DI)
|
||||
ADDQ CX, DI
|
||||
JMP loop
|
||||
|
||||
slowForwardCopy:
|
||||
// !!! If the forward copy is longer than 16 bytes, or if offset < 8, we
|
||||
// can still try 8-byte load stores, provided we can overrun up to 10 extra
|
||||
// bytes. As above, the overrun will be fixed up by subsequent iterations
|
||||
// of the outermost loop.
|
||||
//
|
||||
// The C++ snappy code calls this technique IncrementalCopyFastPath. Its
|
||||
// commentary says:
|
||||
//
|
||||
// ----
|
||||
//
|
||||
// The main part of this loop is a simple copy of eight bytes at a time
|
||||
// until we've copied (at least) the requested amount of bytes. However,
|
||||
// if d and d-offset are less than eight bytes apart (indicating a
|
||||
// repeating pattern of length < 8), we first need to expand the pattern in
|
||||
// order to get the correct results. For instance, if the buffer looks like
|
||||
// this, with the eight-byte <d-offset> and <d> patterns marked as
|
||||
// intervals:
|
||||
//
|
||||
// abxxxxxxxxxxxx
|
||||
// [------] d-offset
|
||||
// [------] d
|
||||
//
|
||||
// a single eight-byte copy from <d-offset> to <d> will repeat the pattern
|
||||
// once, after which we can move <d> two bytes without moving <d-offset>:
|
||||
//
|
||||
// ababxxxxxxxxxx
|
||||
// [------] d-offset
|
||||
// [------] d
|
||||
//
|
||||
// and repeat the exercise until the two no longer overlap.
|
||||
//
|
||||
// This allows us to do very well in the special case of one single byte
|
||||
// repeated many times, without taking a big hit for more general cases.
|
||||
//
|
||||
// The worst case of extra writing past the end of the match occurs when
|
||||
// offset == 1 and length == 1; the last copy will read from byte positions
|
||||
// [0..7] and write to [4..11], whereas it was only supposed to write to
|
||||
// position 1. Thus, ten excess bytes.
|
||||
//
|
||||
// ----
|
||||
//
|
||||
// That "10 byte overrun" worst case is confirmed by Go's
|
||||
// TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy
|
||||
// and finishSlowForwardCopy algorithm.
|
||||
//
|
||||
// if length > len(dst)-d-10 {
|
||||
// goto verySlowForwardCopy
|
||||
// }
|
||||
SUBQ $10, R14
|
||||
CMPQ CX, R14
|
||||
JGT verySlowForwardCopy
|
||||
|
||||
makeOffsetAtLeast8:
|
||||
// !!! As above, expand the pattern so that offset >= 8 and we can use
|
||||
// 8-byte load/stores.
|
||||
//
|
||||
// for offset < 8 {
|
||||
// copy 8 bytes from dst[d-offset:] to dst[d:]
|
||||
// length -= offset
|
||||
// d += offset
|
||||
// offset += offset
|
||||
// // The two previous lines together means that d-offset, and therefore
|
||||
// // R15, is unchanged.
|
||||
// }
|
||||
CMPQ DX, $8
|
||||
JGE fixUpSlowForwardCopy
|
||||
MOVQ (R15), BX
|
||||
MOVQ BX, (DI)
|
||||
SUBQ DX, CX
|
||||
ADDQ DX, DI
|
||||
ADDQ DX, DX
|
||||
JMP makeOffsetAtLeast8
|
||||
|
||||
fixUpSlowForwardCopy:
|
||||
// !!! Add length (which might be negative now) to d (implied by DI being
|
||||
// &dst[d]) so that d ends up at the right place when we jump back to the
|
||||
// top of the loop. Before we do that, though, we save DI to AX so that, if
|
||||
// length is positive, copying the remaining length bytes will write to the
|
||||
// right place.
|
||||
MOVQ DI, AX
|
||||
ADDQ CX, DI
|
||||
|
||||
finishSlowForwardCopy:
|
||||
// !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative
|
||||
// length means that we overrun, but as above, that will be fixed up by
|
||||
// subsequent iterations of the outermost loop.
|
||||
CMPQ CX, $0
|
||||
JLE loop
|
||||
MOVQ (R15), BX
|
||||
MOVQ BX, (AX)
|
||||
ADDQ $8, R15
|
||||
ADDQ $8, AX
|
||||
SUBQ $8, CX
|
||||
JMP finishSlowForwardCopy
|
||||
|
||||
verySlowForwardCopy:
|
||||
// verySlowForwardCopy is a simple implementation of forward copy. In C
|
||||
// parlance, this is a do/while loop instead of a while loop, since we know
|
||||
// that length > 0. In Go syntax:
|
||||
//
|
||||
// for {
|
||||
// dst[d] = dst[d - offset]
|
||||
// d++
|
||||
// length--
|
||||
// if length == 0 {
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
MOVB (R15), BX
|
||||
MOVB BX, (DI)
|
||||
INCQ R15
|
||||
INCQ DI
|
||||
DECQ CX
|
||||
JNZ verySlowForwardCopy
|
||||
JMP loop
|
||||
|
||||
// The code above handles copy tags.
|
||||
// ----------------------------------------
|
||||
|
||||
end:
|
||||
// This is the end of the "for s < len(src)".
|
||||
//
|
||||
// if d != len(dst) { etc }
|
||||
CMPQ DI, R10
|
||||
JNE errCorrupt
|
||||
|
||||
// return 0
|
||||
MOVQ $0, ret+48(FP)
|
||||
RET
|
||||
|
||||
errCorrupt:
|
||||
// return decodeErrCodeCorrupt
|
||||
MOVQ $1, ret+48(FP)
|
||||
RET
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !amd64 appengine !gc noasm
|
||||
|
||||
package snappy
|
||||
|
||||
// decode writes the decoding of src to dst. It assumes that the varint-encoded
|
||||
// length of the decompressed bytes has already been read, and that len(dst)
|
||||
// equals that length.
|
||||
//
|
||||
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
|
||||
func decode(dst, src []byte) int {
|
||||
var d, s, offset, length int
|
||||
for s < len(src) {
|
||||
switch src[s] & 0x03 {
|
||||
case tagLiteral:
|
||||
x := uint32(src[s] >> 2)
|
||||
switch {
|
||||
case x < 60:
|
||||
s++
|
||||
case x == 60:
|
||||
s += 2
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
x = uint32(src[s-1])
|
||||
case x == 61:
|
||||
s += 3
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
x = uint32(src[s-2]) | uint32(src[s-1])<<8
|
||||
case x == 62:
|
||||
s += 4
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
|
||||
case x == 63:
|
||||
s += 5
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
|
||||
}
|
||||
length = int(x) + 1
|
||||
if length <= 0 {
|
||||
return decodeErrCodeUnsupportedLiteralLength
|
||||
}
|
||||
if length > len(dst)-d || length > len(src)-s {
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
copy(dst[d:], src[s:s+length])
|
||||
d += length
|
||||
s += length
|
||||
continue
|
||||
|
||||
case tagCopy1:
|
||||
s += 2
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
length = 4 + int(src[s-2])>>2&0x7
|
||||
offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
|
||||
|
||||
case tagCopy2:
|
||||
s += 3
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
length = 1 + int(src[s-3])>>2
|
||||
offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
|
||||
|
||||
case tagCopy4:
|
||||
s += 5
|
||||
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
length = 1 + int(src[s-5])>>2
|
||||
offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
|
||||
}
|
||||
|
||||
if offset <= 0 || d < offset || length > len(dst)-d {
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
// Copy from an earlier sub-slice of dst to a later sub-slice. Unlike
|
||||
// the built-in copy function, this byte-by-byte copy always runs
|
||||
// forwards, even if the slices overlap. Conceptually, this is:
|
||||
//
|
||||
// d += forwardCopy(dst[d:d+length], dst[d-offset:])
|
||||
for end := d + length; d != end; d++ {
|
||||
dst[d] = dst[d-offset]
|
||||
}
|
||||
}
|
||||
if d != len(dst) {
|
||||
return decodeErrCodeCorrupt
|
||||
}
|
||||
return 0
|
||||
}
|
|
@ -0,0 +1,285 @@
|
|||
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package snappy
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Encode returns the encoded form of src. The returned slice may be a sub-
|
||||
// slice of dst if dst was large enough to hold the entire encoded block.
|
||||
// Otherwise, a newly allocated slice will be returned.
|
||||
//
|
||||
// The dst and src must not overlap. It is valid to pass a nil dst.
|
||||
func Encode(dst, src []byte) []byte {
|
||||
if n := MaxEncodedLen(len(src)); n < 0 {
|
||||
panic(ErrTooLarge)
|
||||
} else if len(dst) < n {
|
||||
dst = make([]byte, n)
|
||||
}
|
||||
|
||||
// The block starts with the varint-encoded length of the decompressed bytes.
|
||||
d := binary.PutUvarint(dst, uint64(len(src)))
|
||||
|
||||
for len(src) > 0 {
|
||||
p := src
|
||||
src = nil
|
||||
if len(p) > maxBlockSize {
|
||||
p, src = p[:maxBlockSize], p[maxBlockSize:]
|
||||
}
|
||||
if len(p) < minNonLiteralBlockSize {
|
||||
d += emitLiteral(dst[d:], p)
|
||||
} else {
|
||||
d += encodeBlock(dst[d:], p)
|
||||
}
|
||||
}
|
||||
return dst[:d]
|
||||
}
|
||||
|
||||
// inputMargin is the minimum number of extra input bytes to keep, inside
|
||||
// encodeBlock's inner loop. On some architectures, this margin lets us
|
||||
// implement a fast path for emitLiteral, where the copy of short (<= 16 byte)
|
||||
// literals can be implemented as a single load to and store from a 16-byte
|
||||
// register. That literal's actual length can be as short as 1 byte, so this
|
||||
// can copy up to 15 bytes too much, but that's OK as subsequent iterations of
|
||||
// the encoding loop will fix up the copy overrun, and this inputMargin ensures
|
||||
// that we don't overrun the dst and src buffers.
|
||||
const inputMargin = 16 - 1
|
||||
|
||||
// minNonLiteralBlockSize is the minimum size of the input to encodeBlock that
|
||||
// could be encoded with a copy tag. This is the minimum with respect to the
|
||||
// algorithm used by encodeBlock, not a minimum enforced by the file format.
|
||||
//
|
||||
// The encoded output must start with at least a 1 byte literal, as there are
|
||||
// no previous bytes to copy. A minimal (1 byte) copy after that, generated
|
||||
// from an emitCopy call in encodeBlock's main loop, would require at least
|
||||
// another inputMargin bytes, for the reason above: we want any emitLiteral
|
||||
// calls inside encodeBlock's main loop to use the fast path if possible, which
|
||||
// requires being able to overrun by inputMargin bytes. Thus,
|
||||
// minNonLiteralBlockSize equals 1 + 1 + inputMargin.
|
||||
//
|
||||
// The C++ code doesn't use this exact threshold, but it could, as discussed at
|
||||
// https://groups.google.com/d/topic/snappy-compression/oGbhsdIJSJ8/discussion
|
||||
// The difference between Go (2+inputMargin) and C++ (inputMargin) is purely an
|
||||
// optimization. It should not affect the encoded form. This is tested by
|
||||
// TestSameEncodingAsCppShortCopies.
|
||||
const minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||
|
||||
// MaxEncodedLen returns the maximum length of a snappy block, given its
|
||||
// uncompressed length.
|
||||
//
|
||||
// It will return a negative value if srcLen is too large to encode.
|
||||
func MaxEncodedLen(srcLen int) int {
|
||||
n := uint64(srcLen)
|
||||
if n > 0xffffffff {
|
||||
return -1
|
||||
}
|
||||
// Compressed data can be defined as:
|
||||
// compressed := item* literal*
|
||||
// item := literal* copy
|
||||
//
|
||||
// The trailing literal sequence has a space blowup of at most 62/60
|
||||
// since a literal of length 60 needs one tag byte + one extra byte
|
||||
// for length information.
|
||||
//
|
||||
// Item blowup is trickier to measure. Suppose the "copy" op copies
|
||||
// 4 bytes of data. Because of a special check in the encoding code,
|
||||
// we produce a 4-byte copy only if the offset is < 65536. Therefore
|
||||
// the copy op takes 3 bytes to encode, and this type of item leads
|
||||
// to at most the 62/60 blowup for representing literals.
|
||||
//
|
||||
// Suppose the "copy" op copies 5 bytes of data. If the offset is big
|
||||
// enough, it will take 5 bytes to encode the copy op. Therefore the
|
||||
// worst case here is a one-byte literal followed by a five-byte copy.
|
||||
// That is, 6 bytes of input turn into 7 bytes of "compressed" data.
|
||||
//
|
||||
// This last factor dominates the blowup, so the final estimate is:
|
||||
n = 32 + n + n/6
|
||||
if n > 0xffffffff {
|
||||
return -1
|
||||
}
|
||||
return int(n)
|
||||
}
|
||||
|
||||
var errClosed = errors.New("snappy: Writer is closed")
|
||||
|
||||
// NewWriter returns a new Writer that compresses to w.
|
||||
//
|
||||
// The Writer returned does not buffer writes. There is no need to Flush or
|
||||
// Close such a Writer.
|
||||
//
|
||||
// Deprecated: the Writer returned is not suitable for many small writes, only
|
||||
// for few large writes. Use NewBufferedWriter instead, which is efficient
|
||||
// regardless of the frequency and shape of the writes, and remember to Close
|
||||
// that Writer when done.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
obuf: make([]byte, obufLen),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBufferedWriter returns a new Writer that compresses to w, using the
|
||||
// framing format described at
|
||||
// https://github.com/google/snappy/blob/master/framing_format.txt
|
||||
//
|
||||
// The Writer returned buffers writes. Users must call Close to guarantee all
|
||||
// data has been forwarded to the underlying io.Writer. They may also call
|
||||
// Flush zero or more times before calling Close.
|
||||
func NewBufferedWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
ibuf: make([]byte, 0, maxBlockSize),
|
||||
obuf: make([]byte, obufLen),
|
||||
}
|
||||
}
|
||||
|
||||
// Writer is an io.Writer that can write Snappy-compressed bytes.
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
err error
|
||||
|
||||
// ibuf is a buffer for the incoming (uncompressed) bytes.
|
||||
//
|
||||
// Its use is optional. For backwards compatibility, Writers created by the
|
||||
// NewWriter function have ibuf == nil, do not buffer incoming bytes, and
|
||||
// therefore do not need to be Flush'ed or Close'd.
|
||||
ibuf []byte
|
||||
|
||||
// obuf is a buffer for the outgoing (compressed) bytes.
|
||||
obuf []byte
|
||||
|
||||
// wroteStreamHeader is whether we have written the stream header.
|
||||
wroteStreamHeader bool
|
||||
}
|
||||
|
||||
// Reset discards the writer's state and switches the Snappy writer to write to
|
||||
// w. This permits reusing a Writer rather than allocating a new one.
|
||||
func (w *Writer) Reset(writer io.Writer) {
|
||||
w.w = writer
|
||||
w.err = nil
|
||||
if w.ibuf != nil {
|
||||
w.ibuf = w.ibuf[:0]
|
||||
}
|
||||
w.wroteStreamHeader = false
|
||||
}
|
||||
|
||||
// Write satisfies the io.Writer interface.
|
||||
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
|
||||
if w.ibuf == nil {
|
||||
// Do not buffer incoming bytes. This does not perform or compress well
|
||||
// if the caller of Writer.Write writes many small slices. This
|
||||
// behavior is therefore deprecated, but still supported for backwards
|
||||
// compatibility with code that doesn't explicitly Flush or Close.
|
||||
return w.write(p)
|
||||
}
|
||||
|
||||
// The remainder of this method is based on bufio.Writer.Write from the
|
||||
// standard library.
|
||||
|
||||
for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err == nil {
|
||||
var n int
|
||||
if len(w.ibuf) == 0 {
|
||||
// Large write, empty buffer.
|
||||
// Write directly from p to avoid copy.
|
||||
n, _ = w.write(p)
|
||||
} else {
|
||||
n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
|
||||
w.ibuf = w.ibuf[:len(w.ibuf)+n]
|
||||
w.Flush()
|
||||
}
|
||||
nRet += n
|
||||
p = p[n:]
|
||||
}
|
||||
if w.err != nil {
|
||||
return nRet, w.err
|
||||
}
|
||||
n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
|
||||
w.ibuf = w.ibuf[:len(w.ibuf)+n]
|
||||
nRet += n
|
||||
return nRet, nil
|
||||
}
|
||||
|
||||
func (w *Writer) write(p []byte) (nRet int, errRet error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
for len(p) > 0 {
|
||||
obufStart := len(magicChunk)
|
||||
if !w.wroteStreamHeader {
|
||||
w.wroteStreamHeader = true
|
||||
copy(w.obuf, magicChunk)
|
||||
obufStart = 0
|
||||
}
|
||||
|
||||
var uncompressed []byte
|
||||
if len(p) > maxBlockSize {
|
||||
uncompressed, p = p[:maxBlockSize], p[maxBlockSize:]
|
||||
} else {
|
||||
uncompressed, p = p, nil
|
||||
}
|
||||
checksum := crc(uncompressed)
|
||||
|
||||
// Compress the buffer, discarding the result if the improvement
|
||||
// isn't at least 12.5%.
|
||||
compressed := Encode(w.obuf[obufHeaderLen:], uncompressed)
|
||||
chunkType := uint8(chunkTypeCompressedData)
|
||||
chunkLen := 4 + len(compressed)
|
||||
obufEnd := obufHeaderLen + len(compressed)
|
||||
if len(compressed) >= len(uncompressed)-len(uncompressed)/8 {
|
||||
chunkType = chunkTypeUncompressedData
|
||||
chunkLen = 4 + len(uncompressed)
|
||||
obufEnd = obufHeaderLen
|
||||
}
|
||||
|
||||
// Fill in the per-chunk header that comes before the body.
|
||||
w.obuf[len(magicChunk)+0] = chunkType
|
||||
w.obuf[len(magicChunk)+1] = uint8(chunkLen >> 0)
|
||||
w.obuf[len(magicChunk)+2] = uint8(chunkLen >> 8)
|
||||
w.obuf[len(magicChunk)+3] = uint8(chunkLen >> 16)
|
||||
w.obuf[len(magicChunk)+4] = uint8(checksum >> 0)
|
||||
w.obuf[len(magicChunk)+5] = uint8(checksum >> 8)
|
||||
w.obuf[len(magicChunk)+6] = uint8(checksum >> 16)
|
||||
w.obuf[len(magicChunk)+7] = uint8(checksum >> 24)
|
||||
|
||||
if _, err := w.w.Write(w.obuf[obufStart:obufEnd]); err != nil {
|
||||
w.err = err
|
||||
return nRet, err
|
||||
}
|
||||
if chunkType == chunkTypeUncompressedData {
|
||||
if _, err := w.w.Write(uncompressed); err != nil {
|
||||
w.err = err
|
||||
return nRet, err
|
||||
}
|
||||
}
|
||||
nRet += len(uncompressed)
|
||||
}
|
||||
return nRet, nil
|
||||
}
|
||||
|
||||
// Flush flushes the Writer to its underlying io.Writer.
|
||||
func (w *Writer) Flush() error {
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
if len(w.ibuf) == 0 {
|
||||
return nil
|
||||
}
|
||||
w.write(w.ibuf)
|
||||
w.ibuf = w.ibuf[:0]
|
||||
return w.err
|
||||
}
|
||||
|
||||
// Close calls Flush and then closes the Writer.
|
||||
func (w *Writer) Close() error {
|
||||
w.Flush()
|
||||
ret := w.err
|
||||
if w.err == nil {
|
||||
w.err = errClosed
|
||||
}
|
||||
return ret
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
package snappy
|
||||
|
||||
// emitLiteral has the same semantics as in encode_other.go.
|
||||
//
|
||||
//go:noescape
|
||||
func emitLiteral(dst, lit []byte) int
|
||||
|
||||
// emitCopy has the same semantics as in encode_other.go.
|
||||
//
|
||||
//go:noescape
|
||||
func emitCopy(dst []byte, offset, length int) int
|
||||
|
||||
// extendMatch has the same semantics as in encode_other.go.
|
||||
//
|
||||
//go:noescape
|
||||
func extendMatch(src []byte, i, j int) int
|
||||
|
||||
// encodeBlock has the same semantics as in encode_other.go.
|
||||
//
|
||||
//go:noescape
|
||||
func encodeBlock(dst, src []byte) (d int)
|
|
@ -0,0 +1,730 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// The XXX lines assemble on Go 1.4, 1.5 and 1.7, but not 1.6, due to a
|
||||
// Go toolchain regression. See https://github.com/golang/go/issues/15426 and
|
||||
// https://github.com/golang/snappy/issues/29
|
||||
//
|
||||
// As a workaround, the package was built with a known good assembler, and
|
||||
// those instructions were disassembled by "objdump -d" to yield the
|
||||
// 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
|
||||
// style comments, in AT&T asm syntax. Note that rsp here is a physical
|
||||
// register, not Go/asm's SP pseudo-register (see https://golang.org/doc/asm).
|
||||
// The instructions were then encoded as "BYTE $0x.." sequences, which assemble
|
||||
// fine on Go 1.6.
|
||||
|
||||
// The asm code generally follows the pure Go code in encode_other.go, except
|
||||
// where marked with a "!!!".
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// func emitLiteral(dst, lit []byte) int
|
||||
//
|
||||
// All local variables fit into registers. The register allocation:
|
||||
// - AX len(lit)
|
||||
// - BX n
|
||||
// - DX return value
|
||||
// - DI &dst[i]
|
||||
// - R10 &lit[0]
|
||||
//
|
||||
// The 24 bytes of stack space is to call runtime·memmove.
|
||||
//
|
||||
// The unusual register allocation of local variables, such as R10 for the
|
||||
// source pointer, matches the allocation used at the call site in encodeBlock,
|
||||
// which makes it easier to manually inline this function.
|
||||
TEXT ·emitLiteral(SB), NOSPLIT, $24-56
|
||||
MOVQ dst_base+0(FP), DI
|
||||
MOVQ lit_base+24(FP), R10
|
||||
MOVQ lit_len+32(FP), AX
|
||||
MOVQ AX, DX
|
||||
MOVL AX, BX
|
||||
SUBL $1, BX
|
||||
|
||||
CMPL BX, $60
|
||||
JLT oneByte
|
||||
CMPL BX, $256
|
||||
JLT twoBytes
|
||||
|
||||
threeBytes:
|
||||
MOVB $0xf4, 0(DI)
|
||||
MOVW BX, 1(DI)
|
||||
ADDQ $3, DI
|
||||
ADDQ $3, DX
|
||||
JMP memmove
|
||||
|
||||
twoBytes:
|
||||
MOVB $0xf0, 0(DI)
|
||||
MOVB BX, 1(DI)
|
||||
ADDQ $2, DI
|
||||
ADDQ $2, DX
|
||||
JMP memmove
|
||||
|
||||
oneByte:
|
||||
SHLB $2, BX
|
||||
MOVB BX, 0(DI)
|
||||
ADDQ $1, DI
|
||||
ADDQ $1, DX
|
||||
|
||||
memmove:
|
||||
MOVQ DX, ret+48(FP)
|
||||
|
||||
// copy(dst[i:], lit)
|
||||
//
|
||||
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
|
||||
// DI, R10 and AX as arguments.
|
||||
MOVQ DI, 0(SP)
|
||||
MOVQ R10, 8(SP)
|
||||
MOVQ AX, 16(SP)
|
||||
CALL runtime·memmove(SB)
|
||||
RET
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// func emitCopy(dst []byte, offset, length int) int
|
||||
//
|
||||
// All local variables fit into registers. The register allocation:
|
||||
// - AX length
|
||||
// - SI &dst[0]
|
||||
// - DI &dst[i]
|
||||
// - R11 offset
|
||||
//
|
||||
// The unusual register allocation of local variables, such as R11 for the
|
||||
// offset, matches the allocation used at the call site in encodeBlock, which
|
||||
// makes it easier to manually inline this function.
|
||||
TEXT ·emitCopy(SB), NOSPLIT, $0-48
|
||||
MOVQ dst_base+0(FP), DI
|
||||
MOVQ DI, SI
|
||||
MOVQ offset+24(FP), R11
|
||||
MOVQ length+32(FP), AX
|
||||
|
||||
loop0:
|
||||
// for length >= 68 { etc }
|
||||
CMPL AX, $68
|
||||
JLT step1
|
||||
|
||||
// Emit a length 64 copy, encoded as 3 bytes.
|
||||
MOVB $0xfe, 0(DI)
|
||||
MOVW R11, 1(DI)
|
||||
ADDQ $3, DI
|
||||
SUBL $64, AX
|
||||
JMP loop0
|
||||
|
||||
step1:
|
||||
// if length > 64 { etc }
|
||||
CMPL AX, $64
|
||||
JLE step2
|
||||
|
||||
// Emit a length 60 copy, encoded as 3 bytes.
|
||||
MOVB $0xee, 0(DI)
|
||||
MOVW R11, 1(DI)
|
||||
ADDQ $3, DI
|
||||
SUBL $60, AX
|
||||
|
||||
step2:
|
||||
// if length >= 12 || offset >= 2048 { goto step3 }
|
||||
CMPL AX, $12
|
||||
JGE step3
|
||||
CMPL R11, $2048
|
||||
JGE step3
|
||||
|
||||
// Emit the remaining copy, encoded as 2 bytes.
|
||||
MOVB R11, 1(DI)
|
||||
SHRL $8, R11
|
||||
SHLB $5, R11
|
||||
SUBB $4, AX
|
||||
SHLB $2, AX
|
||||
ORB AX, R11
|
||||
ORB $1, R11
|
||||
MOVB R11, 0(DI)
|
||||
ADDQ $2, DI
|
||||
|
||||
// Return the number of bytes written.
|
||||
SUBQ SI, DI
|
||||
MOVQ DI, ret+40(FP)
|
||||
RET
|
||||
|
||||
step3:
|
||||
// Emit the remaining copy, encoded as 3 bytes.
|
||||
SUBL $1, AX
|
||||
SHLB $2, AX
|
||||
ORB $2, AX
|
||||
MOVB AX, 0(DI)
|
||||
MOVW R11, 1(DI)
|
||||
ADDQ $3, DI
|
||||
|
||||
// Return the number of bytes written.
|
||||
SUBQ SI, DI
|
||||
MOVQ DI, ret+40(FP)
|
||||
RET
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// func extendMatch(src []byte, i, j int) int
|
||||
//
|
||||
// All local variables fit into registers. The register allocation:
|
||||
// - DX &src[0]
|
||||
// - SI &src[j]
|
||||
// - R13 &src[len(src) - 8]
|
||||
// - R14 &src[len(src)]
|
||||
// - R15 &src[i]
|
||||
//
|
||||
// The unusual register allocation of local variables, such as R15 for a source
|
||||
// pointer, matches the allocation used at the call site in encodeBlock, which
|
||||
// makes it easier to manually inline this function.
|
||||
TEXT ·extendMatch(SB), NOSPLIT, $0-48
|
||||
MOVQ src_base+0(FP), DX
|
||||
MOVQ src_len+8(FP), R14
|
||||
MOVQ i+24(FP), R15
|
||||
MOVQ j+32(FP), SI
|
||||
ADDQ DX, R14
|
||||
ADDQ DX, R15
|
||||
ADDQ DX, SI
|
||||
MOVQ R14, R13
|
||||
SUBQ $8, R13
|
||||
|
||||
cmp8:
|
||||
// As long as we are 8 or more bytes before the end of src, we can load and
|
||||
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
|
||||
CMPQ SI, R13
|
||||
JA cmp1
|
||||
MOVQ (R15), AX
|
||||
MOVQ (SI), BX
|
||||
CMPQ AX, BX
|
||||
JNE bsf
|
||||
ADDQ $8, R15
|
||||
ADDQ $8, SI
|
||||
JMP cmp8
|
||||
|
||||
bsf:
|
||||
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
|
||||
// the index of the first byte that differs. The BSF instruction finds the
|
||||
// least significant 1 bit, the amd64 architecture is little-endian, and
|
||||
// the shift by 3 converts a bit index to a byte index.
|
||||
XORQ AX, BX
|
||||
BSFQ BX, BX
|
||||
SHRQ $3, BX
|
||||
ADDQ BX, SI
|
||||
|
||||
// Convert from &src[ret] to ret.
|
||||
SUBQ DX, SI
|
||||
MOVQ SI, ret+40(FP)
|
||||
RET
|
||||
|
||||
cmp1:
|
||||
// In src's tail, compare 1 byte at a time.
|
||||
CMPQ SI, R14
|
||||
JAE extendMatchEnd
|
||||
MOVB (R15), AX
|
||||
MOVB (SI), BX
|
||||
CMPB AX, BX
|
||||
JNE extendMatchEnd
|
||||
ADDQ $1, R15
|
||||
ADDQ $1, SI
|
||||
JMP cmp1
|
||||
|
||||
extendMatchEnd:
|
||||
// Convert from &src[ret] to ret.
|
||||
SUBQ DX, SI
|
||||
MOVQ SI, ret+40(FP)
|
||||
RET
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// func encodeBlock(dst, src []byte) (d int)
|
||||
//
|
||||
// All local variables fit into registers, other than "var table". The register
|
||||
// allocation:
|
||||
// - AX . .
|
||||
// - BX . .
|
||||
// - CX 56 shift (note that amd64 shifts by non-immediates must use CX).
|
||||
// - DX 64 &src[0], tableSize
|
||||
// - SI 72 &src[s]
|
||||
// - DI 80 &dst[d]
|
||||
// - R9 88 sLimit
|
||||
// - R10 . &src[nextEmit]
|
||||
// - R11 96 prevHash, currHash, nextHash, offset
|
||||
// - R12 104 &src[base], skip
|
||||
// - R13 . &src[nextS], &src[len(src) - 8]
|
||||
// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x
|
||||
// - R15 112 candidate
|
||||
//
|
||||
// The second column (56, 64, etc) is the stack offset to spill the registers
|
||||
// when calling other functions. We could pack this slightly tighter, but it's
|
||||
// simpler to have a dedicated spill map independent of the function called.
|
||||
//
|
||||
// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An
|
||||
// extra 56 bytes, to call other functions, and an extra 64 bytes, to spill
|
||||
// local variables (registers) during calls gives 32768 + 56 + 64 = 32888.
|
||||
TEXT ·encodeBlock(SB), 0, $32888-56
|
||||
MOVQ dst_base+0(FP), DI
|
||||
MOVQ src_base+24(FP), SI
|
||||
MOVQ src_len+32(FP), R14
|
||||
|
||||
// shift, tableSize := uint32(32-8), 1<<8
|
||||
MOVQ $24, CX
|
||||
MOVQ $256, DX
|
||||
|
||||
calcShift:
|
||||
// for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
|
||||
// shift--
|
||||
// }
|
||||
CMPQ DX, $16384
|
||||
JGE varTable
|
||||
CMPQ DX, R14
|
||||
JGE varTable
|
||||
SUBQ $1, CX
|
||||
SHLQ $1, DX
|
||||
JMP calcShift
|
||||
|
||||
varTable:
|
||||
// var table [maxTableSize]uint16
|
||||
//
|
||||
// In the asm code, unlike the Go code, we can zero-initialize only the
|
||||
// first tableSize elements. Each uint16 element is 2 bytes and each MOVOU
|
||||
// writes 16 bytes, so we can do only tableSize/8 writes instead of the
|
||||
// 2048 writes that would zero-initialize all of table's 32768 bytes.
|
||||
SHRQ $3, DX
|
||||
LEAQ table-32768(SP), BX
|
||||
PXOR X0, X0
|
||||
|
||||
memclr:
|
||||
MOVOU X0, 0(BX)
|
||||
ADDQ $16, BX
|
||||
SUBQ $1, DX
|
||||
JNZ memclr
|
||||
|
||||
// !!! DX = &src[0]
|
||||
MOVQ SI, DX
|
||||
|
||||
// sLimit := len(src) - inputMargin
|
||||
MOVQ R14, R9
|
||||
SUBQ $15, R9
|
||||
|
||||
// !!! Pre-emptively spill CX, DX and R9 to the stack. Their values don't
|
||||
// change for the rest of the function.
|
||||
MOVQ CX, 56(SP)
|
||||
MOVQ DX, 64(SP)
|
||||
MOVQ R9, 88(SP)
|
||||
|
||||
// nextEmit := 0
|
||||
MOVQ DX, R10
|
||||
|
||||
// s := 1
|
||||
ADDQ $1, SI
|
||||
|
||||
// nextHash := hash(load32(src, s), shift)
|
||||
MOVL 0(SI), R11
|
||||
IMULL $0x1e35a7bd, R11
|
||||
SHRL CX, R11
|
||||
|
||||
outer:
|
||||
// for { etc }
|
||||
|
||||
// skip := 32
|
||||
MOVQ $32, R12
|
||||
|
||||
// nextS := s
|
||||
MOVQ SI, R13
|
||||
|
||||
// candidate := 0
|
||||
MOVQ $0, R15
|
||||
|
||||
inner0:
|
||||
// for { etc }
|
||||
|
||||
// s := nextS
|
||||
MOVQ R13, SI
|
||||
|
||||
// bytesBetweenHashLookups := skip >> 5
|
||||
MOVQ R12, R14
|
||||
SHRQ $5, R14
|
||||
|
||||
// nextS = s + bytesBetweenHashLookups
|
||||
ADDQ R14, R13
|
||||
|
||||
// skip += bytesBetweenHashLookups
|
||||
ADDQ R14, R12
|
||||
|
||||
// if nextS > sLimit { goto emitRemainder }
|
||||
MOVQ R13, AX
|
||||
SUBQ DX, AX
|
||||
CMPQ AX, R9
|
||||
JA emitRemainder
|
||||
|
||||
// candidate = int(table[nextHash])
|
||||
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
|
||||
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
|
||||
BYTE $0x4e
|
||||
BYTE $0x0f
|
||||
BYTE $0xb7
|
||||
BYTE $0x7c
|
||||
BYTE $0x5c
|
||||
BYTE $0x78
|
||||
|
||||
// table[nextHash] = uint16(s)
|
||||
MOVQ SI, AX
|
||||
SUBQ DX, AX
|
||||
|
||||
// XXX: MOVW AX, table-32768(SP)(R11*2)
|
||||
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
|
||||
BYTE $0x66
|
||||
BYTE $0x42
|
||||
BYTE $0x89
|
||||
BYTE $0x44
|
||||
BYTE $0x5c
|
||||
BYTE $0x78
|
||||
|
||||
// nextHash = hash(load32(src, nextS), shift)
|
||||
MOVL 0(R13), R11
|
||||
IMULL $0x1e35a7bd, R11
|
||||
SHRL CX, R11
|
||||
|
||||
// if load32(src, s) != load32(src, candidate) { continue } break
|
||||
MOVL 0(SI), AX
|
||||
MOVL (DX)(R15*1), BX
|
||||
CMPL AX, BX
|
||||
JNE inner0
|
||||
|
||||
fourByteMatch:
|
||||
// As per the encode_other.go code:
|
||||
//
|
||||
// A 4-byte match has been found. We'll later see etc.
|
||||
|
||||
// !!! Jump to a fast path for short (<= 16 byte) literals. See the comment
|
||||
// on inputMargin in encode.go.
|
||||
MOVQ SI, AX
|
||||
SUBQ R10, AX
|
||||
CMPQ AX, $16
|
||||
JLE emitLiteralFastPath
|
||||
|
||||
// ----------------------------------------
|
||||
// Begin inline of the emitLiteral call.
|
||||
//
|
||||
// d += emitLiteral(dst[d:], src[nextEmit:s])
|
||||
|
||||
MOVL AX, BX
|
||||
SUBL $1, BX
|
||||
|
||||
CMPL BX, $60
|
||||
JLT inlineEmitLiteralOneByte
|
||||
CMPL BX, $256
|
||||
JLT inlineEmitLiteralTwoBytes
|
||||
|
||||
inlineEmitLiteralThreeBytes:
|
||||
MOVB $0xf4, 0(DI)
|
||||
MOVW BX, 1(DI)
|
||||
ADDQ $3, DI
|
||||
JMP inlineEmitLiteralMemmove
|
||||
|
||||
inlineEmitLiteralTwoBytes:
|
||||
MOVB $0xf0, 0(DI)
|
||||
MOVB BX, 1(DI)
|
||||
ADDQ $2, DI
|
||||
JMP inlineEmitLiteralMemmove
|
||||
|
||||
inlineEmitLiteralOneByte:
|
||||
SHLB $2, BX
|
||||
MOVB BX, 0(DI)
|
||||
ADDQ $1, DI
|
||||
|
||||
inlineEmitLiteralMemmove:
|
||||
// Spill local variables (registers) onto the stack; call; unspill.
|
||||
//
|
||||
// copy(dst[i:], lit)
|
||||
//
|
||||
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
|
||||
// DI, R10 and AX as arguments.
|
||||
MOVQ DI, 0(SP)
|
||||
MOVQ R10, 8(SP)
|
||||
MOVQ AX, 16(SP)
|
||||
ADDQ AX, DI // Finish the "d +=" part of "d += emitLiteral(etc)".
|
||||
MOVQ SI, 72(SP)
|
||||
MOVQ DI, 80(SP)
|
||||
MOVQ R15, 112(SP)
|
||||
CALL runtime·memmove(SB)
|
||||
MOVQ 56(SP), CX
|
||||
MOVQ 64(SP), DX
|
||||
MOVQ 72(SP), SI
|
||||
MOVQ 80(SP), DI
|
||||
MOVQ 88(SP), R9
|
||||
MOVQ 112(SP), R15
|
||||
JMP inner1
|
||||
|
||||
inlineEmitLiteralEnd:
|
||||
// End inline of the emitLiteral call.
|
||||
// ----------------------------------------
|
||||
|
||||
emitLiteralFastPath:
|
||||
// !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2".
|
||||
MOVB AX, BX
|
||||
SUBB $1, BX
|
||||
SHLB $2, BX
|
||||
MOVB BX, (DI)
|
||||
ADDQ $1, DI
|
||||
|
||||
// !!! Implement the copy from lit to dst as a 16-byte load and store.
|
||||
// (Encode's documentation says that dst and src must not overlap.)
|
||||
//
|
||||
// This always copies 16 bytes, instead of only len(lit) bytes, but that's
|
||||
// OK. Subsequent iterations will fix up the overrun.
|
||||
//
|
||||
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
|
||||
// 16-byte loads and stores. This technique probably wouldn't be as
|
||||
// effective on architectures that are fussier about alignment.
|
||||
MOVOU 0(R10), X0
|
||||
MOVOU X0, 0(DI)
|
||||
ADDQ AX, DI
|
||||
|
||||
inner1:
|
||||
// for { etc }
|
||||
|
||||
// base := s
|
||||
MOVQ SI, R12
|
||||
|
||||
// !!! offset := base - candidate
|
||||
MOVQ R12, R11
|
||||
SUBQ R15, R11
|
||||
SUBQ DX, R11
|
||||
|
||||
// ----------------------------------------
|
||||
// Begin inline of the extendMatch call.
|
||||
//
|
||||
// s = extendMatch(src, candidate+4, s+4)
|
||||
|
||||
// !!! R14 = &src[len(src)]
|
||||
MOVQ src_len+32(FP), R14
|
||||
ADDQ DX, R14
|
||||
|
||||
// !!! R13 = &src[len(src) - 8]
|
||||
MOVQ R14, R13
|
||||
SUBQ $8, R13
|
||||
|
||||
// !!! R15 = &src[candidate + 4]
|
||||
ADDQ $4, R15
|
||||
ADDQ DX, R15
|
||||
|
||||
// !!! s += 4
|
||||
ADDQ $4, SI
|
||||
|
||||
inlineExtendMatchCmp8:
|
||||
// As long as we are 8 or more bytes before the end of src, we can load and
|
||||
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
|
||||
CMPQ SI, R13
|
||||
JA inlineExtendMatchCmp1
|
||||
MOVQ (R15), AX
|
||||
MOVQ (SI), BX
|
||||
CMPQ AX, BX
|
||||
JNE inlineExtendMatchBSF
|
||||
ADDQ $8, R15
|
||||
ADDQ $8, SI
|
||||
JMP inlineExtendMatchCmp8
|
||||
|
||||
inlineExtendMatchBSF:
|
||||
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
|
||||
// the index of the first byte that differs. The BSF instruction finds the
|
||||
// least significant 1 bit, the amd64 architecture is little-endian, and
|
||||
// the shift by 3 converts a bit index to a byte index.
|
||||
XORQ AX, BX
|
||||
BSFQ BX, BX
|
||||
SHRQ $3, BX
|
||||
ADDQ BX, SI
|
||||
JMP inlineExtendMatchEnd
|
||||
|
||||
inlineExtendMatchCmp1:
|
||||
// In src's tail, compare 1 byte at a time.
|
||||
CMPQ SI, R14
|
||||
JAE inlineExtendMatchEnd
|
||||
MOVB (R15), AX
|
||||
MOVB (SI), BX
|
||||
CMPB AX, BX
|
||||
JNE inlineExtendMatchEnd
|
||||
ADDQ $1, R15
|
||||
ADDQ $1, SI
|
||||
JMP inlineExtendMatchCmp1
|
||||
|
||||
inlineExtendMatchEnd:
|
||||
// End inline of the extendMatch call.
|
||||
// ----------------------------------------
|
||||
|
||||
// ----------------------------------------
|
||||
// Begin inline of the emitCopy call.
|
||||
//
|
||||
// d += emitCopy(dst[d:], base-candidate, s-base)
|
||||
|
||||
// !!! length := s - base
|
||||
MOVQ SI, AX
|
||||
SUBQ R12, AX
|
||||
|
||||
inlineEmitCopyLoop0:
|
||||
// for length >= 68 { etc }
|
||||
CMPL AX, $68
|
||||
JLT inlineEmitCopyStep1
|
||||
|
||||
// Emit a length 64 copy, encoded as 3 bytes.
|
||||
MOVB $0xfe, 0(DI)
|
||||
MOVW R11, 1(DI)
|
||||
ADDQ $3, DI
|
||||
SUBL $64, AX
|
||||
JMP inlineEmitCopyLoop0
|
||||
|
||||
inlineEmitCopyStep1:
|
||||
// if length > 64 { etc }
|
||||
CMPL AX, $64
|
||||
JLE inlineEmitCopyStep2
|
||||
|
||||
// Emit a length 60 copy, encoded as 3 bytes.
|
||||
MOVB $0xee, 0(DI)
|
||||
MOVW R11, 1(DI)
|
||||
ADDQ $3, DI
|
||||
SUBL $60, AX
|
||||
|
||||
inlineEmitCopyStep2:
|
||||
// if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 }
|
||||
CMPL AX, $12
|
||||
JGE inlineEmitCopyStep3
|
||||
CMPL R11, $2048
|
||||
JGE inlineEmitCopyStep3
|
||||
|
||||
// Emit the remaining copy, encoded as 2 bytes.
|
||||
MOVB R11, 1(DI)
|
||||
SHRL $8, R11
|
||||
SHLB $5, R11
|
||||
SUBB $4, AX
|
||||
SHLB $2, AX
|
||||
ORB AX, R11
|
||||
ORB $1, R11
|
||||
MOVB R11, 0(DI)
|
||||
ADDQ $2, DI
|
||||
JMP inlineEmitCopyEnd
|
||||
|
||||
inlineEmitCopyStep3:
|
||||
// Emit the remaining copy, encoded as 3 bytes.
|
||||
SUBL $1, AX
|
||||
SHLB $2, AX
|
||||
ORB $2, AX
|
||||
MOVB AX, 0(DI)
|
||||
MOVW R11, 1(DI)
|
||||
ADDQ $3, DI
|
||||
|
||||
inlineEmitCopyEnd:
|
||||
// End inline of the emitCopy call.
|
||||
// ----------------------------------------
|
||||
|
||||
// nextEmit = s
|
||||
MOVQ SI, R10
|
||||
|
||||
// if s >= sLimit { goto emitRemainder }
|
||||
MOVQ SI, AX
|
||||
SUBQ DX, AX
|
||||
CMPQ AX, R9
|
||||
JAE emitRemainder
|
||||
|
||||
// As per the encode_other.go code:
|
||||
//
|
||||
// We could immediately etc.
|
||||
|
||||
// x := load64(src, s-1)
|
||||
MOVQ -1(SI), R14
|
||||
|
||||
// prevHash := hash(uint32(x>>0), shift)
|
||||
MOVL R14, R11
|
||||
IMULL $0x1e35a7bd, R11
|
||||
SHRL CX, R11
|
||||
|
||||
// table[prevHash] = uint16(s-1)
|
||||
MOVQ SI, AX
|
||||
SUBQ DX, AX
|
||||
SUBQ $1, AX
|
||||
|
||||
// XXX: MOVW AX, table-32768(SP)(R11*2)
|
||||
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
|
||||
BYTE $0x66
|
||||
BYTE $0x42
|
||||
BYTE $0x89
|
||||
BYTE $0x44
|
||||
BYTE $0x5c
|
||||
BYTE $0x78
|
||||
|
||||
// currHash := hash(uint32(x>>8), shift)
|
||||
SHRQ $8, R14
|
||||
MOVL R14, R11
|
||||
IMULL $0x1e35a7bd, R11
|
||||
SHRL CX, R11
|
||||
|
||||
// candidate = int(table[currHash])
|
||||
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
|
||||
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
|
||||
BYTE $0x4e
|
||||
BYTE $0x0f
|
||||
BYTE $0xb7
|
||||
BYTE $0x7c
|
||||
BYTE $0x5c
|
||||
BYTE $0x78
|
||||
|
||||
// table[currHash] = uint16(s)
|
||||
ADDQ $1, AX
|
||||
|
||||
// XXX: MOVW AX, table-32768(SP)(R11*2)
|
||||
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
|
||||
BYTE $0x66
|
||||
BYTE $0x42
|
||||
BYTE $0x89
|
||||
BYTE $0x44
|
||||
BYTE $0x5c
|
||||
BYTE $0x78
|
||||
|
||||
// if uint32(x>>8) == load32(src, candidate) { continue }
|
||||
MOVL (DX)(R15*1), BX
|
||||
CMPL R14, BX
|
||||
JEQ inner1
|
||||
|
||||
// nextHash = hash(uint32(x>>16), shift)
|
||||
SHRQ $8, R14
|
||||
MOVL R14, R11
|
||||
IMULL $0x1e35a7bd, R11
|
||||
SHRL CX, R11
|
||||
|
||||
// s++
|
||||
ADDQ $1, SI
|
||||
|
||||
// break out of the inner1 for loop, i.e. continue the outer loop.
|
||||
JMP outer
|
||||
|
||||
emitRemainder:
|
||||
// if nextEmit < len(src) { etc }
|
||||
MOVQ src_len+32(FP), AX
|
||||
ADDQ DX, AX
|
||||
CMPQ R10, AX
|
||||
JEQ encodeBlockEnd
|
||||
|
||||
// d += emitLiteral(dst[d:], src[nextEmit:])
|
||||
//
|
||||
// Push args.
|
||||
MOVQ DI, 0(SP)
|
||||
MOVQ $0, 8(SP) // Unnecessary, as the callee ignores it, but conservative.
|
||||
MOVQ $0, 16(SP) // Unnecessary, as the callee ignores it, but conservative.
|
||||
MOVQ R10, 24(SP)
|
||||
SUBQ R10, AX
|
||||
MOVQ AX, 32(SP)
|
||||
MOVQ AX, 40(SP) // Unnecessary, as the callee ignores it, but conservative.
|
||||
|
||||
// Spill local variables (registers) onto the stack; call; unspill.
|
||||
MOVQ DI, 80(SP)
|
||||
CALL ·emitLiteral(SB)
|
||||
MOVQ 80(SP), DI
|
||||
|
||||
// Finish the "d +=" part of "d += emitLiteral(etc)".
|
||||
ADDQ 48(SP), DI
|
||||
|
||||
encodeBlockEnd:
|
||||
MOVQ dst_base+0(FP), AX
|
||||
SUBQ AX, DI
|
||||
MOVQ DI, d+48(FP)
|
||||
RET
|
|
@ -0,0 +1,238 @@
|
|||
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !amd64 appengine !gc noasm
|
||||
|
||||
package snappy
|
||||
|
||||
func load32(b []byte, i int) uint32 {
|
||||
b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line.
|
||||
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
|
||||
}
|
||||
|
||||
func load64(b []byte, i int) uint64 {
|
||||
b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line.
|
||||
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
|
||||
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
|
||||
}
|
||||
|
||||
// emitLiteral writes a literal chunk and returns the number of bytes written.
|
||||
//
|
||||
// It assumes that:
|
||||
// dst is long enough to hold the encoded bytes
|
||||
// 1 <= len(lit) && len(lit) <= 65536
|
||||
func emitLiteral(dst, lit []byte) int {
|
||||
i, n := 0, uint(len(lit)-1)
|
||||
switch {
|
||||
case n < 60:
|
||||
dst[0] = uint8(n)<<2 | tagLiteral
|
||||
i = 1
|
||||
case n < 1<<8:
|
||||
dst[0] = 60<<2 | tagLiteral
|
||||
dst[1] = uint8(n)
|
||||
i = 2
|
||||
default:
|
||||
dst[0] = 61<<2 | tagLiteral
|
||||
dst[1] = uint8(n)
|
||||
dst[2] = uint8(n >> 8)
|
||||
i = 3
|
||||
}
|
||||
return i + copy(dst[i:], lit)
|
||||
}
|
||||
|
||||
// emitCopy writes a copy chunk and returns the number of bytes written.
|
||||
//
|
||||
// It assumes that:
|
||||
// dst is long enough to hold the encoded bytes
|
||||
// 1 <= offset && offset <= 65535
|
||||
// 4 <= length && length <= 65535
|
||||
func emitCopy(dst []byte, offset, length int) int {
|
||||
i := 0
|
||||
// The maximum length for a single tagCopy1 or tagCopy2 op is 64 bytes. The
|
||||
// threshold for this loop is a little higher (at 68 = 64 + 4), and the
|
||||
// length emitted down below is is a little lower (at 60 = 64 - 4), because
|
||||
// it's shorter to encode a length 67 copy as a length 60 tagCopy2 followed
|
||||
// by a length 7 tagCopy1 (which encodes as 3+2 bytes) than to encode it as
|
||||
// a length 64 tagCopy2 followed by a length 3 tagCopy2 (which encodes as
|
||||
// 3+3 bytes). The magic 4 in the 64±4 is because the minimum length for a
|
||||
// tagCopy1 op is 4 bytes, which is why a length 3 copy has to be an
|
||||
// encodes-as-3-bytes tagCopy2 instead of an encodes-as-2-bytes tagCopy1.
|
||||
for length >= 68 {
|
||||
// Emit a length 64 copy, encoded as 3 bytes.
|
||||
dst[i+0] = 63<<2 | tagCopy2
|
||||
dst[i+1] = uint8(offset)
|
||||
dst[i+2] = uint8(offset >> 8)
|
||||
i += 3
|
||||
length -= 64
|
||||
}
|
||||
if length > 64 {
|
||||
// Emit a length 60 copy, encoded as 3 bytes.
|
||||
dst[i+0] = 59<<2 | tagCopy2
|
||||
dst[i+1] = uint8(offset)
|
||||
dst[i+2] = uint8(offset >> 8)
|
||||
i += 3
|
||||
length -= 60
|
||||
}
|
||||
if length >= 12 || offset >= 2048 {
|
||||
// Emit the remaining copy, encoded as 3 bytes.
|
||||
dst[i+0] = uint8(length-1)<<2 | tagCopy2
|
||||
dst[i+1] = uint8(offset)
|
||||
dst[i+2] = uint8(offset >> 8)
|
||||
return i + 3
|
||||
}
|
||||
// Emit the remaining copy, encoded as 2 bytes.
|
||||
dst[i+0] = uint8(offset>>8)<<5 | uint8(length-4)<<2 | tagCopy1
|
||||
dst[i+1] = uint8(offset)
|
||||
return i + 2
|
||||
}
|
||||
|
||||
// extendMatch returns the largest k such that k <= len(src) and that
|
||||
// src[i:i+k-j] and src[j:k] have the same contents.
|
||||
//
|
||||
// It assumes that:
|
||||
// 0 <= i && i < j && j <= len(src)
|
||||
func extendMatch(src []byte, i, j int) int {
|
||||
for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 {
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
func hash(u, shift uint32) uint32 {
|
||||
return (u * 0x1e35a7bd) >> shift
|
||||
}
|
||||
|
||||
// encodeBlock encodes a non-empty src to a guaranteed-large-enough dst. It
|
||||
// assumes that the varint-encoded length of the decompressed bytes has already
|
||||
// been written.
|
||||
//
|
||||
// It also assumes that:
|
||||
// len(dst) >= MaxEncodedLen(len(src)) &&
|
||||
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
|
||||
func encodeBlock(dst, src []byte) (d int) {
|
||||
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
|
||||
// The table element type is uint16, as s < sLimit and sLimit < len(src)
|
||||
// and len(src) <= maxBlockSize and maxBlockSize == 65536.
|
||||
const (
|
||||
maxTableSize = 1 << 14
|
||||
// tableMask is redundant, but helps the compiler eliminate bounds
|
||||
// checks.
|
||||
tableMask = maxTableSize - 1
|
||||
)
|
||||
shift := uint32(32 - 8)
|
||||
for tableSize := 1 << 8; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
|
||||
shift--
|
||||
}
|
||||
// In Go, all array elements are zero-initialized, so there is no advantage
|
||||
// to a smaller tableSize per se. However, it matches the C++ algorithm,
|
||||
// and in the asm versions of this code, we can get away with zeroing only
|
||||
// the first tableSize elements.
|
||||
var table [maxTableSize]uint16
|
||||
|
||||
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||
// looking for copies.
|
||||
sLimit := len(src) - inputMargin
|
||||
|
||||
// nextEmit is where in src the next emitLiteral should start from.
|
||||
nextEmit := 0
|
||||
|
||||
// The encoded form must start with a literal, as there are no previous
|
||||
// bytes to copy, so we start looking for hash matches at s == 1.
|
||||
s := 1
|
||||
nextHash := hash(load32(src, s), shift)
|
||||
|
||||
for {
|
||||
// Copied from the C++ snappy implementation:
|
||||
//
|
||||
// Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
// found, start looking only at every other byte. If 32 more bytes are
|
||||
// scanned (or skipped), look at every third byte, etc.. When a match
|
||||
// is found, immediately go back to looking at every byte. This is a
|
||||
// small loss (~5% performance, ~0.1% density) for compressible data
|
||||
// due to more bookkeeping, but for non-compressible data (such as
|
||||
// JPEG) it's a huge win since the compressor quickly "realizes" the
|
||||
// data is incompressible and doesn't bother looking for matches
|
||||
// everywhere.
|
||||
//
|
||||
// The "skip" variable keeps track of how many bytes there are since
|
||||
// the last match; dividing it by 32 (ie. right-shifting by five) gives
|
||||
// the number of bytes to move ahead for each iteration.
|
||||
skip := 32
|
||||
|
||||
nextS := s
|
||||
candidate := 0
|
||||
for {
|
||||
s = nextS
|
||||
bytesBetweenHashLookups := skip >> 5
|
||||
nextS = s + bytesBetweenHashLookups
|
||||
skip += bytesBetweenHashLookups
|
||||
if nextS > sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
candidate = int(table[nextHash&tableMask])
|
||||
table[nextHash&tableMask] = uint16(s)
|
||||
nextHash = hash(load32(src, nextS), shift)
|
||||
if load32(src, s) == load32(src, candidate) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||
// them as literal bytes.
|
||||
d += emitLiteral(dst[d:], src[nextEmit:s])
|
||||
|
||||
// Call emitCopy, and then see if another emitCopy could be our next
|
||||
// move. Repeat until we find no match for the input immediately after
|
||||
// what was consumed by the last emitCopy call.
|
||||
//
|
||||
// If we exit this loop normally then we need to call emitLiteral next,
|
||||
// though we don't yet know how big the literal will be. We handle that
|
||||
// by proceeding to the next iteration of the main loop. We also can
|
||||
// exit this loop via goto if we get close to exhausting the input.
|
||||
for {
|
||||
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||
// literal bytes prior to s.
|
||||
base := s
|
||||
|
||||
// Extend the 4-byte match as long as possible.
|
||||
//
|
||||
// This is an inlined version of:
|
||||
// s = extendMatch(src, candidate+4, s+4)
|
||||
s += 4
|
||||
for i := candidate + 4; s < len(src) && src[i] == src[s]; i, s = i+1, s+1 {
|
||||
}
|
||||
|
||||
d += emitCopy(dst[d:], base-candidate, s-base)
|
||||
nextEmit = s
|
||||
if s >= sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
|
||||
// We could immediately start working at s now, but to improve
|
||||
// compression we first update the hash table at s-1 and at s. If
|
||||
// another emitCopy is not our next move, also calculate nextHash
|
||||
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||
// are faster as one load64 call (with some shifts) instead of
|
||||
// three load32 calls.
|
||||
x := load64(src, s-1)
|
||||
prevHash := hash(uint32(x>>0), shift)
|
||||
table[prevHash&tableMask] = uint16(s - 1)
|
||||
currHash := hash(uint32(x>>8), shift)
|
||||
candidate = int(table[currHash&tableMask])
|
||||
table[currHash&tableMask] = uint16(s)
|
||||
if uint32(x>>8) != load32(src, candidate) {
|
||||
nextHash = hash(uint32(x>>16), shift)
|
||||
s++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emitRemainder:
|
||||
if nextEmit < len(src) {
|
||||
d += emitLiteral(dst[d:], src[nextEmit:])
|
||||
}
|
||||
return d
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,98 @@
|
|||
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package snappy implements the Snappy compression format. It aims for very
|
||||
// high speeds and reasonable compression.
|
||||
//
|
||||
// There are actually two Snappy formats: block and stream. They are related,
|
||||
// but different: trying to decompress block-compressed data as a Snappy stream
|
||||
// will fail, and vice versa. The block format is the Decode and Encode
|
||||
// functions and the stream format is the Reader and Writer types.
|
||||
//
|
||||
// The block format, the more common case, is used when the complete size (the
|
||||
// number of bytes) of the original data is known upfront, at the time
|
||||
// compression starts. The stream format, also known as the framing format, is
|
||||
// for when that isn't always true.
|
||||
//
|
||||
// The canonical, C++ implementation is at https://github.com/google/snappy and
|
||||
// it only implements the block format.
|
||||
package snappy
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
)
|
||||
|
||||
/*
|
||||
Each encoded block begins with the varint-encoded length of the decoded data,
|
||||
followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
|
||||
first byte of each chunk is broken into its 2 least and 6 most significant bits
|
||||
called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag.
|
||||
Zero means a literal tag. All other values mean a copy tag.
|
||||
|
||||
For literal tags:
|
||||
- If m < 60, the next 1 + m bytes are literal bytes.
|
||||
- Otherwise, let n be the little-endian unsigned integer denoted by the next
|
||||
m - 59 bytes. The next 1 + n bytes after that are literal bytes.
|
||||
|
||||
For copy tags, length bytes are copied from offset bytes ago, in the style of
|
||||
Lempel-Ziv compression algorithms. In particular:
|
||||
- For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12).
|
||||
The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10
|
||||
of the offset. The next byte is bits 0-7 of the offset.
|
||||
- For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65).
|
||||
The length is 1 + m. The offset is the little-endian unsigned integer
|
||||
denoted by the next 2 bytes.
|
||||
- For l == 3, this tag is a legacy format that is no longer issued by most
|
||||
encoders. Nonetheless, the offset ranges in [0, 1<<32) and the length in
|
||||
[1, 65). The length is 1 + m. The offset is the little-endian unsigned
|
||||
integer denoted by the next 4 bytes.
|
||||
*/
|
||||
const (
|
||||
tagLiteral = 0x00
|
||||
tagCopy1 = 0x01
|
||||
tagCopy2 = 0x02
|
||||
tagCopy4 = 0x03
|
||||
)
|
||||
|
||||
const (
|
||||
checksumSize = 4
|
||||
chunkHeaderSize = 4
|
||||
magicChunk = "\xff\x06\x00\x00" + magicBody
|
||||
magicBody = "sNaPpY"
|
||||
|
||||
// maxBlockSize is the maximum size of the input to encodeBlock. It is not
|
||||
// part of the wire format per se, but some parts of the encoder assume
|
||||
// that an offset fits into a uint16.
|
||||
//
|
||||
// Also, for the framing format (Writer type instead of Encode function),
|
||||
// https://github.com/google/snappy/blob/master/framing_format.txt says
|
||||
// that "the uncompressed data in a chunk must be no longer than 65536
|
||||
// bytes".
|
||||
maxBlockSize = 65536
|
||||
|
||||
// maxEncodedLenOfMaxBlockSize equals MaxEncodedLen(maxBlockSize), but is
|
||||
// hard coded to be a const instead of a variable, so that obufLen can also
|
||||
// be a const. Their equivalence is confirmed by
|
||||
// TestMaxEncodedLenOfMaxBlockSize.
|
||||
maxEncodedLenOfMaxBlockSize = 76490
|
||||
|
||||
obufHeaderLen = len(magicChunk) + checksumSize + chunkHeaderSize
|
||||
obufLen = obufHeaderLen + maxEncodedLenOfMaxBlockSize
|
||||
)
|
||||
|
||||
const (
|
||||
chunkTypeCompressedData = 0x00
|
||||
chunkTypeUncompressedData = 0x01
|
||||
chunkTypePadding = 0xfe
|
||||
chunkTypeStreamIdentifier = 0xff
|
||||
)
|
||||
|
||||
var crcTable = crc32.MakeTable(crc32.Castagnoli)
|
||||
|
||||
// crc implements the checksum specified in section 3 of
|
||||
// https://github.com/google/snappy/blob/master/framing_format.txt
|
||||
func crc(b []byte) uint32 {
|
||||
c := crc32.Update(0, crcTable, b)
|
||||
return uint32(c>>15|c<<17) + 0xa282ead8
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -3,7 +3,10 @@ package server
|
|||
import (
|
||||
"errors"
|
||||
"github.com/cnlh/nps/bridge"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/pool"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
@ -13,8 +16,8 @@ import (
|
|||
type server struct {
|
||||
id int
|
||||
bridge *bridge.Bridge
|
||||
task *lib.Tunnel
|
||||
config *lib.Config
|
||||
task *file.Tunnel
|
||||
config *file.Config
|
||||
errorContent []byte
|
||||
sync.Mutex
|
||||
}
|
||||
|
@ -26,7 +29,7 @@ func (s *server) FlowAdd(in, out int64) {
|
|||
s.task.Flow.InletFlow += in
|
||||
}
|
||||
|
||||
func (s *server) FlowAddHost(host *lib.Host, in, out int64) {
|
||||
func (s *server) FlowAddHost(host *file.Host, in, out int64) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
host.Flow.ExportFlow += out
|
||||
|
@ -36,7 +39,7 @@ func (s *server) FlowAddHost(host *lib.Host, in, out int64) {
|
|||
//热更新配置
|
||||
func (s *server) ResetConfig() bool {
|
||||
//获取最新数据
|
||||
task, err := lib.GetCsvDb().GetTask(s.task.Id)
|
||||
task, err := file.GetCsvDb().GetTask(s.task.Id)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -45,7 +48,7 @@ func (s *server) ResetConfig() bool {
|
|||
}
|
||||
s.task.UseClientCnf = task.UseClientCnf
|
||||
//使用客户端配置
|
||||
client, err := lib.GetCsvDb().GetClient(s.task.Client.Id)
|
||||
client, err := file.GetCsvDb().GetClient(s.task.Client.Id)
|
||||
if s.task.UseClientCnf {
|
||||
if err == nil {
|
||||
s.config.U = client.Cnf.U
|
||||
|
@ -62,11 +65,11 @@ func (s *server) ResetConfig() bool {
|
|||
}
|
||||
}
|
||||
s.task.Client.Rate = client.Rate
|
||||
s.config.CompressDecode, s.config.CompressEncode = lib.GetCompressType(s.config.Compress)
|
||||
s.config.CompressDecode, s.config.CompressEncode = common.GetCompressType(s.config.Compress)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Conn, flow *lib.Flow) {
|
||||
func (s *server) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) {
|
||||
if rb != nil {
|
||||
if _, err := tunnel.SendMsg(rb, link); err != nil {
|
||||
c.Close()
|
||||
|
@ -74,32 +77,32 @@ func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Co
|
|||
}
|
||||
flow.Add(len(rb), 0)
|
||||
}
|
||||
|
||||
buf := pool.BufPoolCopy.Get().([]byte)
|
||||
for {
|
||||
buf := lib.BufPoolCopy.Get().([]byte)
|
||||
if n, err := c.Read(buf); err != nil {
|
||||
tunnel.SendMsg([]byte(lib.IO_EOF), link)
|
||||
tunnel.SendMsg([]byte(common.IO_EOF), link)
|
||||
break
|
||||
} else {
|
||||
if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
|
||||
lib.PutBufPoolCopy(buf)
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
lib.PutBufPoolCopy(buf)
|
||||
flow.Add(n, 0)
|
||||
}
|
||||
}
|
||||
pool.PutBufPoolCopy(buf)
|
||||
}
|
||||
|
||||
func (s *server) writeConnFail(c net.Conn) {
|
||||
c.Write([]byte(lib.ConnectionFailBytes))
|
||||
c.Write([]byte(common.ConnectionFailBytes))
|
||||
c.Write(s.errorContent)
|
||||
}
|
||||
|
||||
//权限认证
|
||||
func (s *server) auth(r *http.Request, c *lib.Conn, u, p string) error {
|
||||
if u != "" && p != "" && !lib.CheckAuth(r, u, p) {
|
||||
c.Write([]byte(lib.UnauthorizedBytes))
|
||||
func (s *server) auth(r *http.Request, c *conn.Conn, u, p string) error {
|
||||
if u != "" && p != "" && !common.CheckAuth(r, u, p) {
|
||||
c.Write([]byte(common.UnauthorizedBytes))
|
||||
c.Close()
|
||||
return errors.New("401 Unauthorized")
|
||||
}
|
||||
|
|
|
@ -3,9 +3,13 @@ package server
|
|||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
"github.com/cnlh/nps/bridge"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"path/filepath"
|
||||
|
@ -22,7 +26,7 @@ type httpServer struct {
|
|||
stop chan bool
|
||||
}
|
||||
|
||||
func NewHttp(bridge *bridge.Bridge, c *lib.Tunnel) *httpServer {
|
||||
func NewHttp(bridge *bridge.Bridge, c *file.Tunnel) *httpServer {
|
||||
httpPort, _ := beego.AppConfig.Int("httpProxyPort")
|
||||
httpsPort, _ := beego.AppConfig.Int("httpsProxyPort")
|
||||
pemPath := beego.AppConfig.String("pemPath")
|
||||
|
@ -44,33 +48,33 @@ func NewHttp(bridge *bridge.Bridge, c *lib.Tunnel) *httpServer {
|
|||
func (s *httpServer) Start() error {
|
||||
var err error
|
||||
var http, https *http.Server
|
||||
if s.errorContent, err = lib.ReadAllFromFile(filepath.Join(lib.GetRunPath(), "web", "static", "page", "error.html")); err != nil {
|
||||
if s.errorContent, err = common.ReadAllFromFile(filepath.Join(common.GetRunPath(), "web", "static", "page", "error.html")); err != nil {
|
||||
s.errorContent = []byte("easyProxy 404")
|
||||
}
|
||||
|
||||
if s.httpPort > 0 {
|
||||
http = s.NewServer(s.httpPort)
|
||||
go func() {
|
||||
lib.Println("启动http监听,端口为", s.httpPort)
|
||||
lg.Println("启动http监听,端口为", s.httpPort)
|
||||
err := http.ListenAndServe()
|
||||
if err != nil {
|
||||
lib.Fatalln(err)
|
||||
lg.Fatalln(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if s.httpsPort > 0 {
|
||||
if !lib.FileExists(s.pemPath) {
|
||||
lib.Fatalf("ssl certFile文件%s不存在", s.pemPath)
|
||||
if !common.FileExists(s.pemPath) {
|
||||
lg.Fatalf("ssl certFile文件%s不存在", s.pemPath)
|
||||
}
|
||||
if !lib.FileExists(s.keyPath) {
|
||||
lib.Fatalf("ssl keyFile文件%s不存在", s.keyPath)
|
||||
if !common.FileExists(s.keyPath) {
|
||||
lg.Fatalf("ssl keyFile文件%s不存在", s.keyPath)
|
||||
}
|
||||
https = s.NewServer(s.httpsPort)
|
||||
go func() {
|
||||
lib.Println("启动https监听,端口为", s.httpsPort)
|
||||
lg.Println("启动https监听,端口为", s.httpsPort)
|
||||
err := https.ListenAndServeTLS(s.pemPath, s.keyPath)
|
||||
if err != nil {
|
||||
lib.Fatalln(err)
|
||||
lg.Fatalln(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -96,40 +100,41 @@ func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, _, err := hijacker.Hijack()
|
||||
c, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
}
|
||||
s.process(lib.NewConn(conn), r)
|
||||
s.process(conn.NewConn(c), r)
|
||||
}
|
||||
|
||||
func (s *httpServer) process(c *lib.Conn, r *http.Request) {
|
||||
func (s *httpServer) process(c *conn.Conn, r *http.Request) {
|
||||
//多客户端域名代理
|
||||
var (
|
||||
isConn = true
|
||||
link *lib.Link
|
||||
host *lib.Host
|
||||
tunnel *lib.Conn
|
||||
lk *conn.Link
|
||||
host *file.Host
|
||||
tunnel *conn.Conn
|
||||
err error
|
||||
)
|
||||
for {
|
||||
//首次获取conn
|
||||
if isConn {
|
||||
if host, err = GetInfoByHost(r.Host); err != nil {
|
||||
lib.Printf("the host %s is not found !", r.Host)
|
||||
lg.Printf("the host %s is not found !", r.Host)
|
||||
break
|
||||
}
|
||||
//流量限制
|
||||
if host.Client.Flow.FlowLimit > 0 && (host.Client.Flow.FlowLimit<<20) < (host.Client.Flow.ExportFlow+host.Client.Flow.InletFlow) {
|
||||
break
|
||||
}
|
||||
host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = lib.GetCompressType(host.Client.Cnf.Compress)
|
||||
host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = common.GetCompressType(host.Client.Cnf.Compress)
|
||||
//权限控制
|
||||
if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
|
||||
break
|
||||
}
|
||||
link = lib.NewLink(host.Client.GetId(), lib.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.CompressEncode, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, c, host.Flow, nil, host.Client.Rate, nil)
|
||||
if tunnel, err = s.bridge.SendLinkInfo(host.Client.Id, link); err != nil {
|
||||
lk = conn.NewLink(host.Client.GetId(), common.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.CompressEncode, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, c, host.Flow, nil, host.Client.Rate, nil)
|
||||
if tunnel, err = s.bridge.SendLinkInfo(host.Client.Id, lk); err != nil {
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
isConn = false
|
||||
|
@ -140,13 +145,13 @@ func (s *httpServer) process(c *lib.Conn, r *http.Request) {
|
|||
}
|
||||
}
|
||||
//根据设定,修改header和host
|
||||
lib.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String())
|
||||
common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String())
|
||||
b, err := httputil.DumpRequest(r, true)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
host.Flow.Add(len(b), 0)
|
||||
if _, err := tunnel.SendMsg(b, link); err != nil {
|
||||
if _, err := tunnel.SendMsg(b, lk); err != nil {
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
|
@ -155,7 +160,7 @@ func (s *httpServer) process(c *lib.Conn, r *http.Request) {
|
|||
if isConn {
|
||||
s.writeConnFail(c.Conn)
|
||||
} else {
|
||||
tunnel.SendMsg([]byte(lib.IO_EOF), link)
|
||||
tunnel.SendMsg([]byte(common.IO_EOF), lk)
|
||||
}
|
||||
|
||||
c.Close()
|
||||
|
|
|
@ -3,7 +3,8 @@ package server
|
|||
import (
|
||||
"errors"
|
||||
"github.com/cnlh/nps/bridge"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
@ -11,44 +12,41 @@ import (
|
|||
var (
|
||||
Bridge *bridge.Bridge
|
||||
RunList map[int]interface{} //运行中的任务
|
||||
startFinish chan bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
RunList = make(map[int]interface{})
|
||||
startFinish = make(chan bool)
|
||||
}
|
||||
|
||||
//从csv文件中恢复任务
|
||||
func InitFromCsv() {
|
||||
for _, v := range lib.GetCsvDb().Tasks {
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
if v.Status {
|
||||
lib.Println("启动模式:", v.Mode, "监听端口:", v.TcpPort)
|
||||
lg.Println("启动模式:", v.Mode, "监听端口:", v.TcpPort)
|
||||
AddTask(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//start a new server
|
||||
func StartNewServer(bridgePort int, cnf *lib.Tunnel) {
|
||||
Bridge = bridge.NewTunnel(bridgePort, RunList)
|
||||
func StartNewServer(bridgePort int, cnf *file.Tunnel, bridgeType string) {
|
||||
Bridge = bridge.NewTunnel(bridgePort, RunList, bridgeType)
|
||||
if err := Bridge.StartTunnel(); err != nil {
|
||||
lib.Fatalln("服务端开启失败", err)
|
||||
lg.Fatalln("服务端开启失败", err)
|
||||
}
|
||||
if svr := NewMode(Bridge, cnf); svr != nil {
|
||||
RunList[cnf.Id] = svr
|
||||
err := reflect.ValueOf(svr).MethodByName("Start").Call(nil)[0]
|
||||
if err.Interface() != nil {
|
||||
lib.Fatalln(err)
|
||||
lg.Fatalln(err)
|
||||
}
|
||||
} else {
|
||||
lib.Fatalln("启动模式不正确")
|
||||
lg.Fatalln("启动模式不正确")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//new a server by mode name
|
||||
func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} {
|
||||
func NewMode(Bridge *bridge.Bridge, c *file.Tunnel) interface{} {
|
||||
switch c.Mode {
|
||||
case "tunnelServer":
|
||||
return NewTunnelModeServer(ProcessTunnel, Bridge, c)
|
||||
|
@ -60,17 +58,15 @@ func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} {
|
|||
return NewUdpModeServer(Bridge, c)
|
||||
case "webServer":
|
||||
InitFromCsv()
|
||||
t := &lib.Tunnel{
|
||||
t := &file.Tunnel{
|
||||
TcpPort: 0,
|
||||
Mode: "httpHostServer",
|
||||
Target: "",
|
||||
Config: &lib.Config{},
|
||||
Config: &file.Config{},
|
||||
Status: true,
|
||||
}
|
||||
AddTask(t)
|
||||
return NewWebServer(Bridge)
|
||||
case "hostServer":
|
||||
return NewHostServer(c)
|
||||
case "httpHostServer":
|
||||
return NewHttp(Bridge, c)
|
||||
}
|
||||
|
@ -81,11 +77,11 @@ func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} {
|
|||
func StopServer(id int) error {
|
||||
if v, ok := RunList[id]; ok {
|
||||
reflect.ValueOf(v).MethodByName("Close").Call(nil)
|
||||
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
|
||||
if t, err := file.GetCsvDb().GetTask(id); err != nil {
|
||||
return err
|
||||
} else {
|
||||
t.Status = false
|
||||
lib.GetCsvDb().UpdateTask(t)
|
||||
file.GetCsvDb().UpdateTask(t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -93,13 +89,13 @@ func StopServer(id int) error {
|
|||
}
|
||||
|
||||
//add task
|
||||
func AddTask(t *lib.Tunnel) error {
|
||||
func AddTask(t *file.Tunnel) error {
|
||||
if svr := NewMode(Bridge, t); svr != nil {
|
||||
RunList[t.Id] = svr
|
||||
go func() {
|
||||
err := reflect.ValueOf(svr).MethodByName("Start").Call(nil)[0]
|
||||
if err.Interface() != nil {
|
||||
lib.Fatalln("服务端", t.Id, "启动失败,错误:", err)
|
||||
lg.Fatalln("客户端", t.Id, "启动失败,错误:", err)
|
||||
delete(RunList, t.Id)
|
||||
}
|
||||
}()
|
||||
|
@ -111,12 +107,12 @@ func AddTask(t *lib.Tunnel) error {
|
|||
|
||||
//start task
|
||||
func StartTask(id int) error {
|
||||
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
|
||||
if t, err := file.GetCsvDb().GetTask(id); err != nil {
|
||||
return err
|
||||
} else {
|
||||
AddTask(t)
|
||||
t.Status = true
|
||||
lib.GetCsvDb().UpdateTask(t)
|
||||
file.GetCsvDb().UpdateTask(t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -126,12 +122,12 @@ func DelTask(id int) error {
|
|||
if err := StopServer(id); err != nil {
|
||||
return err
|
||||
}
|
||||
return lib.GetCsvDb().DelTask(id)
|
||||
return file.GetCsvDb().DelTask(id)
|
||||
}
|
||||
|
||||
//get key by host from x
|
||||
func GetInfoByHost(host string) (h *lib.Host, err error) {
|
||||
for _, v := range lib.GetCsvDb().Hosts {
|
||||
func GetInfoByHost(host string) (h *file.Host, err error) {
|
||||
for _, v := range file.GetCsvDb().Hosts {
|
||||
s := strings.Split(host, ":")
|
||||
if s[0] == v.Host {
|
||||
h = v
|
||||
|
@ -143,10 +139,10 @@ func GetInfoByHost(host string) (h *lib.Host, err error) {
|
|||
}
|
||||
|
||||
//get task list by page num
|
||||
func GetTunnel(start, length int, typeVal string, clientId int) ([]*lib.Tunnel, int) {
|
||||
list := make([]*lib.Tunnel, 0)
|
||||
func GetTunnel(start, length int, typeVal string, clientId int) ([]*file.Tunnel, int) {
|
||||
list := make([]*file.Tunnel, 0)
|
||||
var cnt int
|
||||
for _, v := range lib.GetCsvDb().Tasks {
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
if (typeVal != "" && v.Mode != typeVal) || (typeVal == "" && clientId != v.Client.Id) {
|
||||
continue
|
||||
}
|
||||
|
@ -171,13 +167,13 @@ func GetTunnel(start, length int, typeVal string, clientId int) ([]*lib.Tunnel,
|
|||
}
|
||||
|
||||
//获取客户端列表
|
||||
func GetClientList(start, length int) (list []*lib.Client, cnt int) {
|
||||
list, cnt = lib.GetCsvDb().GetClientList(start, length)
|
||||
func GetClientList(start, length int) (list []*file.Client, cnt int) {
|
||||
list, cnt = file.GetCsvDb().GetClientList(start, length)
|
||||
dealClientData(list)
|
||||
return
|
||||
}
|
||||
|
||||
func dealClientData(list []*lib.Client) {
|
||||
func dealClientData(list []*file.Client) {
|
||||
for _, v := range list {
|
||||
if _, ok := Bridge.Client[v.Id]; ok {
|
||||
v.IsConnect = true
|
||||
|
@ -186,13 +182,13 @@ func dealClientData(list []*lib.Client) {
|
|||
}
|
||||
v.Flow.InletFlow = 0
|
||||
v.Flow.ExportFlow = 0
|
||||
for _, h := range lib.GetCsvDb().Hosts {
|
||||
for _, h := range file.GetCsvDb().Hosts {
|
||||
if h.Client.Id == v.Id {
|
||||
v.Flow.InletFlow += h.Flow.InletFlow
|
||||
v.Flow.ExportFlow += h.Flow.ExportFlow
|
||||
}
|
||||
}
|
||||
for _, t := range lib.GetCsvDb().Tasks {
|
||||
for _, t := range file.GetCsvDb().Tasks {
|
||||
if t.Client.Id == v.Id {
|
||||
v.Flow.InletFlow += t.Flow.InletFlow
|
||||
v.Flow.ExportFlow += t.Flow.ExportFlow
|
||||
|
@ -204,14 +200,14 @@ func dealClientData(list []*lib.Client) {
|
|||
|
||||
//根据客户端id删除其所属的所有隧道和域名
|
||||
func DelTunnelAndHostByClientId(clientId int) {
|
||||
for _, v := range lib.GetCsvDb().Tasks {
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
if v.Client.Id == clientId {
|
||||
DelTask(v.Id)
|
||||
}
|
||||
}
|
||||
for _, v := range lib.GetCsvDb().Hosts {
|
||||
for _, v := range file.GetCsvDb().Hosts {
|
||||
if v.Client.Id == clientId {
|
||||
lib.GetCsvDb().DelHost(v.Host)
|
||||
file.GetCsvDb().DelHost(v.Host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -223,9 +219,9 @@ func DelClientConnect(clientId int) {
|
|||
|
||||
func GetDashboardData() map[string]int {
|
||||
data := make(map[string]int)
|
||||
data["hostCount"] = len(lib.GetCsvDb().Hosts)
|
||||
data["clientCount"] = len(lib.GetCsvDb().Clients)
|
||||
list := lib.GetCsvDb().Clients
|
||||
data["hostCount"] = len(file.GetCsvDb().Hosts)
|
||||
data["clientCount"] = len(file.GetCsvDb().Clients)
|
||||
list := file.GetCsvDb().Clients
|
||||
dealClientData(list)
|
||||
c := 0
|
||||
var in, out int64
|
||||
|
@ -239,7 +235,7 @@ func GetDashboardData() map[string]int {
|
|||
data["clientOnlineCount"] = c
|
||||
data["inletFlowCount"] = int(in)
|
||||
data["exportFlowCount"] = int(out)
|
||||
for _, v := range lib.GetCsvDb().Tasks {
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
switch v.Mode {
|
||||
case "tunnelServer":
|
||||
data["tunnelServerCount"] += 1
|
||||
|
|
|
@ -4,7 +4,10 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/cnlh/nps/bridge"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -65,7 +68,7 @@ func (s *Sock5ModeServer) handleRequest(c net.Conn) {
|
|||
_, err := io.ReadFull(c, header)
|
||||
|
||||
if err != nil {
|
||||
lib.Println("illegal request", err)
|
||||
lg.Println("illegal request", err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
@ -135,18 +138,18 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) {
|
|||
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
var ltype string
|
||||
if command == associateMethod {
|
||||
ltype = lib.CONN_UDP
|
||||
ltype = common.CONN_UDP
|
||||
} else {
|
||||
ltype = lib.CONN_TCP
|
||||
ltype = common.CONN_TCP
|
||||
}
|
||||
link := lib.NewLink(s.task.Client.GetId(), ltype, addr, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, lib.NewConn(c), s.task.Flow, nil, s.task.Client.Rate, nil)
|
||||
link := conn.NewLink(s.task.Client.GetId(), ltype, addr, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, conn.NewConn(c), s.task.Flow, nil, s.task.Client.Rate, nil)
|
||||
|
||||
if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil {
|
||||
c.Close()
|
||||
return
|
||||
} else {
|
||||
s.sendReply(c, succeeded)
|
||||
s.linkCopy(link, lib.NewConn(c), nil, tunnel, s.task.Flow)
|
||||
s.linkCopy(link, conn.NewConn(c), nil, tunnel, s.task.Flow)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -162,7 +165,7 @@ func (s *Sock5ModeServer) handleBind(c net.Conn) {
|
|||
|
||||
//udp
|
||||
func (s *Sock5ModeServer) handleUDP(c net.Conn) {
|
||||
lib.Println("UDP Associate")
|
||||
lg.Println("UDP Associate")
|
||||
/*
|
||||
+----+------+------+----------+----------+----------+
|
||||
|RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||
|
@ -175,7 +178,7 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) {
|
|||
// relay udp datagram silently, without any notification to the requesting client
|
||||
if buf[2] != 0 {
|
||||
// does not support fragmentation, drop it
|
||||
lib.Println("does not support fragmentation, drop")
|
||||
lg.Println("does not support fragmentation, drop")
|
||||
dummy := make([]byte, maxUDPPacketSize)
|
||||
c.Read(dummy)
|
||||
}
|
||||
|
@ -187,13 +190,13 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) {
|
|||
func (s *Sock5ModeServer) handleConn(c net.Conn) {
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
lib.Println("negotiation err", err)
|
||||
lg.Println("negotiation err", err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if version := buf[0]; version != 5 {
|
||||
lib.Println("only support socks5, request from: ", c.RemoteAddr())
|
||||
lg.Println("only support socks5, request from: ", c.RemoteAddr())
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
@ -201,7 +204,7 @@ func (s *Sock5ModeServer) handleConn(c net.Conn) {
|
|||
|
||||
methods := make([]byte, nMethods)
|
||||
if len, err := c.Read(methods); len != int(nMethods) || err != nil {
|
||||
lib.Println("wrong method")
|
||||
lg.Println("wrong method")
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
@ -210,7 +213,7 @@ func (s *Sock5ModeServer) handleConn(c net.Conn) {
|
|||
c.Write(buf)
|
||||
if err := s.Auth(c); err != nil {
|
||||
c.Close()
|
||||
lib.Println("验证失败:", err)
|
||||
lg.Println("验证失败:", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
@ -269,7 +272,7 @@ func (s *Sock5ModeServer) Start() error {
|
|||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
break
|
||||
}
|
||||
lib.Fatalln("accept error: ", err)
|
||||
lg.Fatalln("accept error: ", err)
|
||||
}
|
||||
if !s.ResetConfig() {
|
||||
conn.Close()
|
||||
|
@ -286,11 +289,11 @@ func (s *Sock5ModeServer) Close() error {
|
|||
}
|
||||
|
||||
//new
|
||||
func NewSock5ModeServer(bridge *bridge.Bridge, task *lib.Tunnel) *Sock5ModeServer {
|
||||
func NewSock5ModeServer(bridge *bridge.Bridge, task *file.Tunnel) *Sock5ModeServer {
|
||||
s := new(Sock5ModeServer)
|
||||
s.bridge = bridge
|
||||
s.task = task
|
||||
s.config = lib.DeepCopyConfig(task.Config)
|
||||
s.config = file.DeepCopyConfig(task.Config)
|
||||
if s.config.U != "" && s.config.P != "" {
|
||||
s.isVerify = true
|
||||
} else {
|
||||
|
|
|
@ -2,9 +2,12 @@ package server
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
"github.com/cnlh/nps/bridge"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/lg"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
@ -17,12 +20,12 @@ type TunnelModeServer struct {
|
|||
}
|
||||
|
||||
//tcp|http|host
|
||||
func NewTunnelModeServer(process process, bridge *bridge.Bridge, task *lib.Tunnel) *TunnelModeServer {
|
||||
func NewTunnelModeServer(process process, bridge *bridge.Bridge, task *file.Tunnel) *TunnelModeServer {
|
||||
s := new(TunnelModeServer)
|
||||
s.bridge = bridge
|
||||
s.process = process
|
||||
s.task = task
|
||||
s.config = lib.DeepCopyConfig(task.Config)
|
||||
s.config = file.DeepCopyConfig(task.Config)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -34,22 +37,22 @@ func (s *TunnelModeServer) Start() error {
|
|||
return err
|
||||
}
|
||||
for {
|
||||
conn, err := s.listener.AcceptTCP()
|
||||
c, err := s.listener.AcceptTCP()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
break
|
||||
}
|
||||
lib.Println(err)
|
||||
lg.Println(err)
|
||||
continue
|
||||
}
|
||||
go s.process(lib.NewConn(conn), s)
|
||||
go s.process(conn.NewConn(c), s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//与客户端建立通道
|
||||
func (s *TunnelModeServer) dealClient(c *lib.Conn, cnf *lib.Config, addr string, method string, rb []byte) error {
|
||||
link := lib.NewLink(s.task.Client.GetId(), lib.CONN_TCP, addr, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil)
|
||||
func (s *TunnelModeServer) dealClient(c *conn.Conn, cnf *file.Config, addr string, method string, rb []byte) error {
|
||||
link := conn.NewLink(s.task.Client.GetId(), common.CONN_TCP, addr, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil)
|
||||
|
||||
if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil {
|
||||
c.Close()
|
||||
|
@ -73,13 +76,13 @@ type WebServer struct {
|
|||
//开始
|
||||
func (s *WebServer) Start() error {
|
||||
p, _ := beego.AppConfig.Int("httpport")
|
||||
if !lib.TestTcpPort(p) {
|
||||
lib.Fatalln("web管理端口", p, "被占用!")
|
||||
if !common.TestTcpPort(p) {
|
||||
lg.Fatalln("web管理端口", p, "被占用!")
|
||||
}
|
||||
beego.BConfig.WebConfig.Session.SessionOn = true
|
||||
lib.Println("web管理启动,访问端口为", beego.AppConfig.String("httpport"))
|
||||
beego.SetStaticPath("/static", filepath.Join(lib.GetRunPath(), "web", "static"))
|
||||
beego.SetViewsPath(filepath.Join(lib.GetRunPath(), "web", "views"))
|
||||
lg.Println("web管理启动,访问端口为", p)
|
||||
beego.SetStaticPath("/static", filepath.Join(common.GetRunPath(), "web", "static"))
|
||||
beego.SetViewsPath(filepath.Join(common.GetRunPath(), "web", "views"))
|
||||
beego.Run()
|
||||
return errors.New("web管理启动失败")
|
||||
}
|
||||
|
@ -91,32 +94,10 @@ func NewWebServer(bridge *bridge.Bridge) *WebServer {
|
|||
return s
|
||||
}
|
||||
|
||||
//host
|
||||
type HostServer struct {
|
||||
server
|
||||
}
|
||||
|
||||
//开始
|
||||
func (s *HostServer) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewHostServer(task *lib.Tunnel) *HostServer {
|
||||
s := new(HostServer)
|
||||
s.task = task
|
||||
s.config = lib.DeepCopyConfig(task.Config)
|
||||
return s
|
||||
}
|
||||
|
||||
//close
|
||||
func (s *HostServer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type process func(c *lib.Conn, s *TunnelModeServer) error
|
||||
type process func(c *conn.Conn, s *TunnelModeServer) error
|
||||
|
||||
//tcp隧道模式
|
||||
func ProcessTunnel(c *lib.Conn, s *TunnelModeServer) error {
|
||||
func ProcessTunnel(c *conn.Conn, s *TunnelModeServer) error {
|
||||
if !s.ResetConfig() {
|
||||
c.Close()
|
||||
return errors.New("流量超出")
|
||||
|
@ -125,7 +106,7 @@ func ProcessTunnel(c *lib.Conn, s *TunnelModeServer) error {
|
|||
}
|
||||
|
||||
//http代理模式
|
||||
func ProcessHttp(c *lib.Conn, s *TunnelModeServer) error {
|
||||
func ProcessHttp(c *conn.Conn, s *TunnelModeServer) error {
|
||||
if !s.ResetConfig() {
|
||||
c.Close()
|
||||
return errors.New("流量超出")
|
||||
|
|
|
@ -1,54 +1,78 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"log"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func TestServerConfig() {
|
||||
var postArr []int
|
||||
for _, v := range lib.GetCsvDb().Tasks {
|
||||
isInArr(&postArr, v.TcpPort, v.Remark)
|
||||
var postTcpArr []int
|
||||
var postUdpArr []int
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
if v.Mode == "udpServer" {
|
||||
isInArr(&postUdpArr, v.TcpPort, v.Remark, "udp")
|
||||
} else {
|
||||
isInArr(&postTcpArr, v.TcpPort, v.Remark, "tcp")
|
||||
}
|
||||
}
|
||||
p, err := beego.AppConfig.Int("httpport")
|
||||
if err != nil {
|
||||
log.Fatalln("Getting web management port error :", err)
|
||||
} else {
|
||||
isInArr(&postArr, p, "WebmManagement port")
|
||||
isInArr(&postTcpArr, p, "Web Management port", "tcp")
|
||||
}
|
||||
|
||||
if p := beego.AppConfig.String("bridgePort"); p != "" {
|
||||
if port, err := strconv.Atoi(p); err != nil {
|
||||
log.Fatalln("get Server and client communication portserror:", err)
|
||||
} else if beego.AppConfig.String("bridgeType") == "kcp" {
|
||||
isInArr(&postUdpArr, port, "Server and client communication ports", "udp")
|
||||
} else {
|
||||
isInArr(&postTcpArr, port, "Server and client communication ports", "tcp")
|
||||
}
|
||||
}
|
||||
|
||||
if p := beego.AppConfig.String("httpProxyPort"); p != "" {
|
||||
if port, err := strconv.Atoi(p); err != nil {
|
||||
log.Fatalln("get http port error:", err)
|
||||
} else {
|
||||
isInArr(&postArr, port, "https port")
|
||||
isInArr(&postTcpArr, port, "https port", "tcp")
|
||||
}
|
||||
}
|
||||
if p := beego.AppConfig.String("httpsProxyPort"); p != "" {
|
||||
if port, err := strconv.Atoi(p); err != nil {
|
||||
log.Fatalln("get https port error", err)
|
||||
} else {
|
||||
if !lib.FileExists(beego.AppConfig.String("pemPath")) {
|
||||
if !common.FileExists(beego.AppConfig.String("pemPath")) {
|
||||
log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath"))
|
||||
}
|
||||
if !lib.FileExists(beego.AppConfig.String("ketPath")) {
|
||||
if !common.FileExists(beego.AppConfig.String("ketPath")) {
|
||||
log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath"))
|
||||
}
|
||||
isInArr(&postArr, port, "http port")
|
||||
isInArr(&postTcpArr, port, "http port", "tcp")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isInArr(arr *[]int, val int, remark string) {
|
||||
func isInArr(arr *[]int, val int, remark string, tp string) {
|
||||
for _, v := range *arr {
|
||||
if v == val {
|
||||
log.Fatalf("the port %d is reused,remark: %s", val, remark)
|
||||
}
|
||||
}
|
||||
if !lib.TestTcpPort(val) {
|
||||
log.Fatalf("open the %d port error ,remark: %s", val, remark)
|
||||
if tp == "tcp" {
|
||||
if !common.TestTcpPort(val) {
|
||||
log.Fatalf("open the %d port error ,remark: %s", val, remark)
|
||||
}
|
||||
} else {
|
||||
if !common.TestUdpPort(val) {
|
||||
log.Fatalf("open the %d port error ,remark: %s", val, remark)
|
||||
}
|
||||
}
|
||||
|
||||
*arr = append(*arr, val)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,7 +2,10 @@ package server
|
|||
|
||||
import (
|
||||
"github.com/cnlh/nps/bridge"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/lib/conn"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/pool"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
@ -10,15 +13,15 @@ import (
|
|||
type UdpModeServer struct {
|
||||
server
|
||||
listener *net.UDPConn
|
||||
udpMap map[string]*lib.Conn
|
||||
udpMap map[string]*conn.Conn
|
||||
}
|
||||
|
||||
func NewUdpModeServer(bridge *bridge.Bridge, task *lib.Tunnel) *UdpModeServer {
|
||||
func NewUdpModeServer(bridge *bridge.Bridge, task *file.Tunnel) *UdpModeServer {
|
||||
s := new(UdpModeServer)
|
||||
s.bridge = bridge
|
||||
s.udpMap = make(map[string]*lib.Conn)
|
||||
s.udpMap = make(map[string]*conn.Conn)
|
||||
s.task = task
|
||||
s.config = lib.DeepCopyConfig(task.Config)
|
||||
s.config = file.DeepCopyConfig(task.Config)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -29,7 +32,7 @@ func (s *UdpModeServer) Start() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := lib.BufPoolUdp.Get().([]byte)
|
||||
buf := pool.BufPoolUdp.Get().([]byte)
|
||||
for {
|
||||
n, addr, err := s.listener.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
|
@ -47,13 +50,14 @@ func (s *UdpModeServer) Start() error {
|
|||
}
|
||||
|
||||
func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) {
|
||||
link := lib.NewLink(s.task.Client.GetId(), lib.CONN_UDP, s.task.Target, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, nil, s.task.Flow, s.listener, s.task.Client.Rate, addr)
|
||||
link := conn.NewLink(s.task.Client.GetId(), common.CONN_UDP, s.task.Target, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, nil, s.task.Flow, s.listener, s.task.Client.Rate, addr)
|
||||
|
||||
if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil {
|
||||
return
|
||||
} else {
|
||||
s.task.Flow.Add(len(data), 0)
|
||||
tunnel.SendMsg(data, link)
|
||||
pool.PutBufPoolUdp(data)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
"github.com/cnlh/nps/lib/common"
|
||||
"github.com/cnlh/nps/server"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -40,7 +40,7 @@ func (s *BaseController) display(tpl ...string) {
|
|||
}
|
||||
ip := s.Ctx.Request.Host
|
||||
if strings.LastIndex(ip, ":") > 0 {
|
||||
arr := strings.Split(lib.GetHostByName(ip), ":")
|
||||
arr := strings.Split(common.GetHostByName(ip), ":")
|
||||
s.Data["ip"] = arr[0]
|
||||
}
|
||||
s.Data["p"] = server.Bridge.TunnelPort
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/crypt"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/lib/rate"
|
||||
"github.com/cnlh/nps/server"
|
||||
)
|
||||
|
||||
|
@ -28,29 +30,29 @@ func (s *ClientController) Add() {
|
|||
s.SetInfo("新增")
|
||||
s.display()
|
||||
} else {
|
||||
t := &lib.Client{
|
||||
VerifyKey: lib.GetRandomString(16),
|
||||
Id: lib.GetCsvDb().GetClientId(),
|
||||
t := &file.Client{
|
||||
VerifyKey: crypt.GetRandomString(16),
|
||||
Id: file.GetCsvDb().GetClientId(),
|
||||
Status: true,
|
||||
Remark: s.GetString("remark"),
|
||||
Cnf: &lib.Config{
|
||||
Cnf: &file.Config{
|
||||
U: s.GetString("u"),
|
||||
P: s.GetString("p"),
|
||||
Compress: s.GetString("compress"),
|
||||
Crypt: s.GetBoolNoErr("crypt"),
|
||||
},
|
||||
RateLimit: s.GetIntNoErr("rate_limit"),
|
||||
Flow: &lib.Flow{
|
||||
Flow: &file.Flow{
|
||||
ExportFlow: 0,
|
||||
InletFlow: 0,
|
||||
FlowLimit: int64(s.GetIntNoErr("flow_limit")),
|
||||
},
|
||||
}
|
||||
if t.RateLimit > 0 {
|
||||
t.Rate = lib.NewRate(int64(t.RateLimit * 1024))
|
||||
t.Rate = rate.NewRate(int64(t.RateLimit * 1024))
|
||||
t.Rate.Start()
|
||||
}
|
||||
lib.GetCsvDb().NewClient(t)
|
||||
file.GetCsvDb().NewClient(t)
|
||||
s.AjaxOk("添加成功")
|
||||
}
|
||||
}
|
||||
|
@ -58,7 +60,7 @@ func (s *ClientController) GetClient() {
|
|||
if s.Ctx.Request.Method == "POST" {
|
||||
id := s.GetIntNoErr("id")
|
||||
data := make(map[string]interface{})
|
||||
if c, err := lib.GetCsvDb().GetClient(id); err != nil {
|
||||
if c, err := file.GetCsvDb().GetClient(id); err != nil {
|
||||
data["code"] = 0
|
||||
} else {
|
||||
data["code"] = 1
|
||||
|
@ -74,7 +76,7 @@ func (s *ClientController) Edit() {
|
|||
id := s.GetIntNoErr("id")
|
||||
if s.Ctx.Request.Method == "GET" {
|
||||
s.Data["menu"] = "client"
|
||||
if c, err := lib.GetCsvDb().GetClient(id); err != nil {
|
||||
if c, err := file.GetCsvDb().GetClient(id); err != nil {
|
||||
s.error()
|
||||
} else {
|
||||
s.Data["c"] = c
|
||||
|
@ -82,7 +84,7 @@ func (s *ClientController) Edit() {
|
|||
s.SetInfo("修改")
|
||||
s.display()
|
||||
} else {
|
||||
if c, err := lib.GetCsvDb().GetClient(id); err != nil {
|
||||
if c, err := file.GetCsvDb().GetClient(id); err != nil {
|
||||
s.error()
|
||||
} else {
|
||||
c.Remark = s.GetString("remark")
|
||||
|
@ -96,12 +98,12 @@ func (s *ClientController) Edit() {
|
|||
c.Rate.Stop()
|
||||
}
|
||||
if c.RateLimit > 0 {
|
||||
c.Rate = lib.NewRate(int64(c.RateLimit * 1024))
|
||||
c.Rate = rate.NewRate(int64(c.RateLimit * 1024))
|
||||
c.Rate.Start()
|
||||
} else {
|
||||
c.Rate = nil
|
||||
}
|
||||
lib.GetCsvDb().UpdateClient(c)
|
||||
file.GetCsvDb().UpdateClient(c)
|
||||
}
|
||||
s.AjaxOk("修改成功")
|
||||
}
|
||||
|
@ -110,7 +112,7 @@ func (s *ClientController) Edit() {
|
|||
//更改状态
|
||||
func (s *ClientController) ChangeStatus() {
|
||||
id := s.GetIntNoErr("id")
|
||||
if client, err := lib.GetCsvDb().GetClient(id); err == nil {
|
||||
if client, err := file.GetCsvDb().GetClient(id); err == nil {
|
||||
client.Status = s.GetBoolNoErr("status")
|
||||
if client.Status == false {
|
||||
server.DelClientConnect(client.Id)
|
||||
|
@ -123,7 +125,7 @@ func (s *ClientController) ChangeStatus() {
|
|||
//删除客户端
|
||||
func (s *ClientController) Del() {
|
||||
id := s.GetIntNoErr("id")
|
||||
if err := lib.GetCsvDb().DelClient(id); err != nil {
|
||||
if err := file.GetCsvDb().DelClient(id); err != nil {
|
||||
s.AjaxErr("删除失败")
|
||||
}
|
||||
server.DelTunnelAndHostByClientId(id)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/cnlh/nps/lib"
|
||||
"github.com/cnlh/nps/lib/file"
|
||||
"github.com/cnlh/nps/server"
|
||||
)
|
||||
|
||||
|
@ -72,27 +72,27 @@ func (s *IndexController) Add() {
|
|||
s.SetInfo("新增")
|
||||
s.display()
|
||||
} else {
|
||||
t := &lib.Tunnel{
|
||||
t := &file.Tunnel{
|
||||
TcpPort: s.GetIntNoErr("port"),
|
||||
Mode: s.GetString("type"),
|
||||
Target: s.GetString("target"),
|
||||
Config: &lib.Config{
|
||||
Config: &file.Config{
|
||||
U: s.GetString("u"),
|
||||
P: s.GetString("p"),
|
||||
Compress: s.GetString("compress"),
|
||||
Crypt: s.GetBoolNoErr("crypt"),
|
||||
},
|
||||
Id: lib.GetCsvDb().GetTaskId(),
|
||||
Id: file.GetCsvDb().GetTaskId(),
|
||||
UseClientCnf: s.GetBoolNoErr("use_client"),
|
||||
Status: true,
|
||||
Remark: s.GetString("remark"),
|
||||
Flow: &lib.Flow{},
|
||||
Flow: &file.Flow{},
|
||||
}
|
||||
var err error
|
||||
if t.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
if t.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
s.AjaxErr(err.Error())
|
||||
}
|
||||
lib.GetCsvDb().NewTask(t)
|
||||
file.GetCsvDb().NewTask(t)
|
||||
if err := server.AddTask(t); err != nil {
|
||||
s.AjaxErr(err.Error())
|
||||
} else {
|
||||
|
@ -103,7 +103,7 @@ func (s *IndexController) Add() {
|
|||
func (s *IndexController) GetOneTunnel() {
|
||||
id := s.GetIntNoErr("id")
|
||||
data := make(map[string]interface{})
|
||||
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
|
||||
if t, err := file.GetCsvDb().GetTask(id); err != nil {
|
||||
data["code"] = 0
|
||||
} else {
|
||||
data["code"] = 1
|
||||
|
@ -115,7 +115,7 @@ func (s *IndexController) GetOneTunnel() {
|
|||
func (s *IndexController) Edit() {
|
||||
id := s.GetIntNoErr("id")
|
||||
if s.Ctx.Request.Method == "GET" {
|
||||
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
|
||||
if t, err := file.GetCsvDb().GetTask(id); err != nil {
|
||||
s.error()
|
||||
} else {
|
||||
s.Data["t"] = t
|
||||
|
@ -123,7 +123,7 @@ func (s *IndexController) Edit() {
|
|||
s.SetInfo("修改")
|
||||
s.display()
|
||||
} else {
|
||||
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
|
||||
if t, err := file.GetCsvDb().GetTask(id); err != nil {
|
||||
s.error()
|
||||
} else {
|
||||
t.TcpPort = s.GetIntNoErr("port")
|
||||
|
@ -137,10 +137,10 @@ func (s *IndexController) Edit() {
|
|||
t.Config.Crypt = s.GetBoolNoErr("crypt")
|
||||
t.UseClientCnf = s.GetBoolNoErr("use_client")
|
||||
t.Remark = s.GetString("remark")
|
||||
if t.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
if t.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
s.AjaxErr("修改失败")
|
||||
}
|
||||
lib.GetCsvDb().UpdateTask(t)
|
||||
file.GetCsvDb().UpdateTask(t)
|
||||
}
|
||||
s.AjaxOk("修改成功")
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ func (s *IndexController) HostList() {
|
|||
} else {
|
||||
start, length := s.GetAjaxParams()
|
||||
clientId := s.GetIntNoErr("client_id")
|
||||
list, cnt := lib.GetCsvDb().GetHost(start, length, clientId)
|
||||
list, cnt := file.GetCsvDb().GetHost(start, length, clientId)
|
||||
s.AjaxTable(list, cnt, cnt)
|
||||
}
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ func (s *IndexController) GetHost() {
|
|||
|
||||
func (s *IndexController) DelHost() {
|
||||
host := s.GetString("host")
|
||||
if err := lib.GetCsvDb().DelHost(host); err != nil {
|
||||
if err := file.GetCsvDb().DelHost(host); err != nil {
|
||||
s.AjaxErr("删除失败")
|
||||
}
|
||||
s.AjaxOk("删除成功")
|
||||
|
@ -213,19 +213,19 @@ func (s *IndexController) AddHost() {
|
|||
s.SetInfo("新增")
|
||||
s.display("index/hadd")
|
||||
} else {
|
||||
h := &lib.Host{
|
||||
h := &file.Host{
|
||||
Host: s.GetString("host"),
|
||||
Target: s.GetString("target"),
|
||||
HeaderChange: s.GetString("header"),
|
||||
HostChange: s.GetString("hostchange"),
|
||||
Remark: s.GetString("remark"),
|
||||
Flow: &lib.Flow{},
|
||||
Flow: &file.Flow{},
|
||||
}
|
||||
var err error
|
||||
if h.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
if h.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
s.AjaxErr("添加失败")
|
||||
}
|
||||
lib.GetCsvDb().NewHost(h)
|
||||
file.GetCsvDb().NewHost(h)
|
||||
s.AjaxOk("添加成功")
|
||||
}
|
||||
}
|
||||
|
@ -251,9 +251,9 @@ func (s *IndexController) EditHost() {
|
|||
h.HostChange = s.GetString("hostchange")
|
||||
h.Remark = s.GetString("remark")
|
||||
h.TargetArr = nil
|
||||
lib.GetCsvDb().UpdateHost(h)
|
||||
file.GetCsvDb().UpdateHost(h)
|
||||
var err error
|
||||
if h.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
if h.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
|
||||
s.AjaxErr("修改失败")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
)
|
||||
|
||||
type LoginController struct {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package routers
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/cnlh/nps/lib/beego"
|
||||
"github.com/cnlh/nps/web/controllers"
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue