结构调整、kcp支持

pull/1219/head v0.0.14
刘河 2019-02-09 17:07:47 +08:00
parent 2e8af6f120
commit 59d789d253
60 changed files with 11097 additions and 783 deletions

View File

@ -45,7 +45,7 @@ go语言编写无第三方依赖各个平台都已经编译在release中
* [与nginx配合](#与nginx配合)
* [关闭http|https代理](#关闭代理)
* [将nps安装到系统](#将nps安装到系统)
* 单隧道模式及介绍
* 单隧道模式及介绍(即将移除)
* [tcp隧道模式](#tcp隧道模式)
* [udp隧道模式](#udp隧道模式)
* [socks5代理模式](#socks5代理模式)
@ -62,6 +62,7 @@ go语言编写无第三方依赖各个平台都已经编译在release中
* [带宽限制](#带宽限制)
* [负载均衡](#负载均衡)
* [守护进程](#守护进程)
* [KCP协议支持](#KCP协议支持)
* [相关说明](#相关说明)
* [流量统计](#流量统计)
* [热更新支持](#热更新支持)
@ -138,12 +139,13 @@ go语言编写无第三方依赖各个平台都已经编译在release中
---|---
httpport | web管理端口
password | web界面管理密码
tcpport | 服务端客户端通信端口
bridgePort | 服务端客户端通信端口
pemPath | ssl certFile绝对路径
keyPath | ssl keyFile绝对路径
httpsProxyPort | 域名代理https代理监听端口
httpProxyPort | 域名代理http代理监听端口
authip|web api免验证IP地址
bridgeType|客户端与服务端连接方式kcp或tcp
### 详细说明
@ -539,12 +541,23 @@ authip | 免验证ip适用于web api
### 守护进程
本代理支持守护进程,使用示例如下,服务端、客户端所有模式通用,支持 linux、darwin、windows。
```
./(nps|npc) start|stop|restart|status xxxxxx
./(nps|npc) start|stop|restart|status 若有其他参数可加其他参数
```
```
(nps|npc).exe start|stop|restart|status xxxxxx
(nps|npc).exe start|stop|restart|status 若有其他参数可加其他参数
```
### KCP协议支持
KCP 是一个快速可靠协议,能以比 TCP 多浪费 10%-20% 带宽的代价,换取平均延迟降低 30%-40%,在弱网环境下对性能有一定的提升。可在 app.conf 中修改 bridgeType 为 kcp
设置后本代理将开启udp端口bridgePort
注意:当服务端为 kcp 时,客户端连接时也需要加上参数
```
-type=kcp
```
## 相关说明
### 获取用户真实ip

View File

@ -2,71 +2,103 @@ package bridge
import (
"errors"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/kcp"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/lib/pool"
"github.com/cnlh/nps/lib/common"
"net"
"strconv"
"sync"
"time"
)
type Client struct {
tunnel *lib.Conn
signal *lib.Conn
linkMap map[int]*lib.Link
tunnel *conn.Conn
signal *conn.Conn
linkMap map[int]*conn.Link
linkStatusMap map[int]bool
stop chan bool
sync.RWMutex
}
type Bridge struct {
TunnelPort int //通信隧道端口
listener *net.TCPListener //server端监听
Client map[int]*Client
RunList map[int]interface{} //运行中的任务
lock sync.Mutex
tunnelLock sync.Mutex
clientLock sync.Mutex
func NewClient(t *conn.Conn, s *conn.Conn) *Client {
return &Client{
linkMap: make(map[int]*conn.Link),
stop: make(chan bool),
linkStatusMap: make(map[int]bool),
signal: s,
tunnel: t,
}
}
func NewTunnel(tunnelPort int, runList map[int]interface{}) *Bridge {
type Bridge struct {
TunnelPort int //通信隧道端口
tcpListener *net.TCPListener //server端监听
kcpListener *kcp.Listener //server端监听
Client map[int]*Client
RunList map[int]interface{} //运行中的任务
tunnelType string //bridge type kcp or tcp
lock sync.Mutex
tunnelLock sync.Mutex
clientLock sync.RWMutex
}
func NewTunnel(tunnelPort int, runList map[int]interface{}, tunnelType string) *Bridge {
t := new(Bridge)
t.TunnelPort = tunnelPort
t.Client = make(map[int]*Client)
t.RunList = runList
t.tunnelType = tunnelType
return t
}
func (s *Bridge) StartTunnel() error {
var err error
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.TunnelPort, ""})
if err != nil {
return err
if s.tunnelType == "kcp" {
s.kcpListener, err = kcp.ListenWithOptions(":"+strconv.Itoa(s.TunnelPort), nil, 150, 3)
if err != nil {
return err
}
go func() {
for {
c, err := s.kcpListener.AcceptKCP()
conn.SetUdpSession(c)
if err != nil {
lg.Println(err)
continue
}
go s.cliProcess(conn.NewConn(c))
}
}()
} else {
s.tcpListener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.TunnelPort, ""})
if err != nil {
return err
}
go func() {
for {
c, err := s.tcpListener.Accept()
if err != nil {
lg.Println(err)
continue
}
go s.cliProcess(conn.NewConn(c))
}
}()
}
go s.tunnelProcess()
return nil
}
//tcp server
func (s *Bridge) tunnelProcess() error {
var err error
for {
conn, err := s.listener.Accept()
if err != nil {
lib.Println(err)
continue
}
go s.cliProcess(lib.NewConn(conn))
}
return err
}
//验证失败返回错误验证flag并且关闭连接
func (s *Bridge) verifyError(c *lib.Conn) {
c.Write([]byte(lib.VERIFY_EER))
func (s *Bridge) verifyError(c *conn.Conn) {
c.Write([]byte(common.VERIFY_EER))
c.Conn.Close()
}
func (s *Bridge) cliProcess(c *lib.Conn) {
c.SetReadDeadline(5)
func (s *Bridge) cliProcess(c *conn.Conn) {
c.SetReadDeadline(5, s.tunnelType)
var buf []byte
var err error
if buf, err = c.ReadLen(32); err != nil {
@ -74,9 +106,9 @@ func (s *Bridge) cliProcess(c *lib.Conn) {
return
}
//验证
id, err := lib.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String())
id, err := file.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String())
if err != nil {
lib.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr())
lg.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr())
s.verifyError(c)
return
}
@ -97,40 +129,39 @@ func (s *Bridge) closeClient(id int) {
}
//tcp连接类型区分
func (s *Bridge) typeDeal(typeVal string, c *lib.Conn, id int) {
func (s *Bridge) typeDeal(typeVal string, c *conn.Conn, id int) {
switch typeVal {
case lib.WORK_MAIN:
case common.WORK_MAIN:
//客户端已经存在,下线
s.clientLock.Lock()
if _, ok := s.Client[id]; ok {
s.clientLock.Unlock()
s.closeClient(id)
} else {
s.clientLock.Unlock()
}
s.clientLock.Lock()
s.Client[id] = &Client{
linkMap: make(map[int]*lib.Link),
stop: make(chan bool),
linkStatusMap: make(map[int]bool),
}
lib.Printf("客户端%d连接成功,地址为:%s", id, c.Conn.RemoteAddr())
s.Client[id].signal = c
s.clientLock.Unlock()
go s.GetStatus(id)
case lib.WORK_CHAN:
s.clientLock.Lock()
if v, ok := s.Client[id]; ok {
s.clientLock.Unlock()
v.tunnel = c
if v.signal != nil {
v.signal.WriteClose()
}
v.Lock()
v.signal = c
v.Unlock()
} else {
s.Client[id] = NewClient(nil, c)
s.clientLock.Unlock()
}
lg.Printf("客户端%d连接成功,地址为:%s", id, c.Conn.RemoteAddr())
go s.GetStatus(id)
case common.WORK_CHAN:
s.clientLock.Lock()
if v, ok := s.Client[id]; ok {
s.clientLock.Unlock()
v.Lock()
v.tunnel = c
v.Unlock()
} else {
s.Client[id] = NewClient(c, nil)
s.clientLock.Unlock()
return
}
go s.clientCopy(id)
}
c.SetAlive()
c.SetAlive(s.tunnelType)
return
}
@ -161,13 +192,13 @@ func (s *Bridge) waitStatus(clientId, id int) (bool) {
return false
}
func (s *Bridge) SendLinkInfo(clientId int, link *lib.Link) (tunnel *lib.Conn, err error) {
func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link) (tunnel *conn.Conn, err error) {
s.clientLock.Lock()
if v, ok := s.Client[clientId]; ok {
s.clientLock.Unlock()
v.signal.SendLinkInfo(link)
if err != nil {
lib.Println("send error:", err, link.Id)
lg.Println("send error:", err, link.Id)
s.DelClient(clientId)
return
}
@ -192,7 +223,7 @@ func (s *Bridge) SendLinkInfo(clientId int, link *lib.Link) (tunnel *lib.Conn, e
}
//得到一个tcp隧道
func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *lib.Conn, err error) {
func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *conn.Conn, err error) {
s.clientLock.Lock()
defer s.clientLock.Unlock()
if v, ok := s.Client[id]; !ok {
@ -204,7 +235,7 @@ func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *lib.Conn,
}
//得到一个通信通道
func (s *Bridge) GetSignal(id int) (conn *lib.Conn, err error) {
func (s *Bridge) GetSignal(id int) (conn *conn.Conn, err error) {
s.clientLock.Lock()
defer s.clientLock.Unlock()
if v, ok := s.Client[id]; !ok {
@ -257,19 +288,19 @@ func (s *Bridge) clientCopy(clientId int) {
for {
if id, err := client.tunnel.GetLen(); err != nil {
s.closeClient(clientId)
lib.Println("读取msg id 错误", err, id)
lg.Println("读取msg id 错误", err, id)
break
} else {
client.Lock()
if link, ok := client.linkMap[id]; ok {
client.Unlock()
if content, err := client.tunnel.GetMsgContent(link); err != nil {
lib.PutBufPoolCopy(content)
pool.PutBufPoolCopy(content)
s.closeClient(clientId)
lib.Println("read msg content error", err, "close client")
lg.Println("read msg content error", err, "close client")
break
} else {
if len(content) == len(lib.IO_EOF) && string(content) == lib.IO_EOF {
if len(content) == len(common.IO_EOF) && string(content) == common.IO_EOF {
if link.Conn != nil {
link.Conn.Close()
}
@ -281,7 +312,7 @@ func (s *Bridge) clientCopy(clientId int) {
}
link.Flow.Add(0, len(content))
}
lib.PutBufPoolCopy(content)
pool.PutBufPoolCopy(content)
}
} else {
client.Unlock()
@ -289,5 +320,4 @@ func (s *Bridge) clientCopy(clientId int) {
}
}
}
}

View File

@ -1,30 +1,35 @@
package client
import (
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/kcp"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/lib/pool"
"net"
"sync"
"time"
)
type TRPClient struct {
svrAddr string
linkMap map[int]*lib.Link
stop chan bool
tunnel *lib.Conn
svrAddr string
linkMap map[int]*conn.Link
stop chan bool
tunnel *conn.Conn
bridgeConnType string
sync.Mutex
vKey string
}
//new client
func NewRPClient(svraddr string, vKey string) *TRPClient {
func NewRPClient(svraddr string, vKey string, bridgeConnType string) *TRPClient {
return &TRPClient{
svrAddr: svraddr,
linkMap: make(map[int]*lib.Link),
stop: make(chan bool),
tunnel: nil,
Mutex: sync.Mutex{},
vKey: vKey,
svrAddr: svraddr,
linkMap: make(map[int]*conn.Link),
stop: make(chan bool),
Mutex: sync.Mutex{},
vKey: vKey,
bridgeConnType: bridgeConnType,
}
}
@ -36,37 +41,44 @@ func (s *TRPClient) Start() error {
//新建
func (s *TRPClient) NewConn() {
var err error
var c net.Conn
retry:
conn, err := net.Dial("tcp", s.svrAddr)
if s.bridgeConnType == "tcp" {
c, err = net.Dial("tcp", s.svrAddr)
} else {
var sess *kcp.UDPSession
sess, err = kcp.DialWithOptions(s.svrAddr, nil, 150, 3)
conn.SetUdpSession(sess)
c = sess
}
if err != nil {
lib.Println("连接服务端失败,五秒后将重连")
lg.Println("连接服务端失败,五秒后将重连")
time.Sleep(time.Second * 5)
goto retry
return
}
s.processor(lib.NewConn(conn))
s.processor(conn.NewConn(c))
}
//处理
func (s *TRPClient) processor(c *lib.Conn) {
c.SetAlive()
if _, err := c.Write([]byte(lib.Getverifyval(s.vKey))); err != nil {
func (s *TRPClient) processor(c *conn.Conn) {
c.SetAlive(s.bridgeConnType)
if _, err := c.Write([]byte(common.Getverifyval(s.vKey))); err != nil {
return
}
c.WriteMain()
go s.dealChan()
for {
flags, err := c.ReadFlag()
if err != nil {
lib.Println("服务端断开,正在重新连接")
lg.Println("服务端断开,正在重新连接")
break
}
switch flags {
case lib.VERIFY_EER:
lib.Fatalf("vKey:%s不正确,服务端拒绝连接,请检查", s.vKey)
case lib.NEW_CONN:
case common.VERIFY_EER:
lg.Fatalf("vKey:%s不正确,服务端拒绝连接,请检查", s.vKey)
case common.NEW_CONN:
if link, err := c.GetLinkInfo(); err != nil {
break
} else {
@ -75,54 +87,46 @@ func (s *TRPClient) processor(c *lib.Conn) {
s.Unlock()
go s.linkProcess(link, c)
}
case lib.RES_CLOSE:
lib.Fatalln("该vkey被另一客户连接")
case lib.RES_MSG:
lib.Println("服务端返回错误,重新连接")
case common.RES_CLOSE:
lg.Fatalln("该vkey被另一客户连接")
case common.RES_MSG:
lg.Println("服务端返回错误,重新连接")
break
default:
lib.Println("无法解析该错误,重新连接")
lg.Println("无法解析该错误,重新连接")
break
}
}
s.stop <- true
s.linkMap = make(map[int]*lib.Link)
s.linkMap = make(map[int]*conn.Link)
go s.NewConn()
}
func (s *TRPClient) linkProcess(link *lib.Link, c *lib.Conn) {
func (s *TRPClient) linkProcess(link *conn.Link, c *conn.Conn) {
//与目标建立连接
server, err := net.DialTimeout(link.ConnType, link.Host, time.Second*3)
if err != nil {
c.WriteFail(link.Id)
lib.Println("connect to ", link.Host, "error:", err)
lg.Println("connect to ", link.Host, "error:", err)
return
}
c.WriteSuccess(link.Id)
link.Conn = lib.NewConn(server)
link.Conn = conn.NewConn(server)
buf := pool.BufPoolCopy.Get().([]byte)
for {
buf := lib.BufPoolCopy.Get().([]byte)
if n, err := server.Read(buf); err != nil {
lib.PutBufPoolCopy(buf)
s.tunnel.SendMsg([]byte(lib.IO_EOF), link)
s.tunnel.SendMsg([]byte(common.IO_EOF), link)
break
} else {
if _, err := s.tunnel.SendMsg(buf[:n], link); err != nil {
lib.PutBufPoolCopy(buf)
c.Close()
break
}
lib.PutBufPoolCopy(buf)
//if link.ConnType == utils.CONN_UDP {
// c.Close()
// break
//}
}
}
pool.PutBufPoolCopy(buf)
s.Lock()
delete(s.linkMap, link.Id)
s.Unlock()
@ -131,41 +135,50 @@ func (s *TRPClient) linkProcess(link *lib.Link, c *lib.Conn) {
//隧道模式处理
func (s *TRPClient) dealChan() {
var err error
//创建一个tcp连接
conn, err := net.Dial("tcp", s.svrAddr)
var c net.Conn
var sess *kcp.UDPSession
if s.bridgeConnType == "tcp" {
c, err = net.Dial("tcp", s.svrAddr)
} else {
sess, err = kcp.DialWithOptions(s.svrAddr, nil, 10, 3)
conn.SetUdpSession(sess)
c = sess
}
if err != nil {
lib.Println("connect to ", s.svrAddr, "error:", err)
lg.Println("connect to ", s.svrAddr, "error:", err)
return
}
//验证
if _, err := conn.Write([]byte(lib.Getverifyval(s.vKey))); err != nil {
lib.Println("connect to ", s.svrAddr, "error:", err)
if _, err := c.Write([]byte(common.Getverifyval(s.vKey))); err != nil {
lg.Println("connect to ", s.svrAddr, "error:", err)
return
}
//默认长连接保持
s.tunnel = lib.NewConn(conn)
s.tunnel.SetAlive()
s.tunnel = conn.NewConn(c)
s.tunnel.SetAlive(s.bridgeConnType)
//写标志
s.tunnel.WriteChan()
go func() {
for {
if id, err := s.tunnel.GetLen(); err != nil {
lib.Println("get msg id error")
lg.Println("get msg id error")
break
} else {
s.Lock()
if v, ok := s.linkMap[id]; ok {
s.Unlock()
if content, err := s.tunnel.GetMsgContent(v); err != nil {
lib.Println("get msg content error:", err, id)
lg.Println("get msg content error:", err, id)
pool.PutBufPoolCopy(content)
break
} else {
if len(content) == len(lib.IO_EOF) && string(content) == lib.IO_EOF {
if len(content) == len(common.IO_EOF) && string(content) == common.IO_EOF {
v.Conn.Close()
} else if v.Conn != nil {
v.Conn.Write(content)
}
pool.PutBufPoolCopy(content)
}
} else {
s.Unlock()
@ -175,5 +188,6 @@ func (s *TRPClient) dealChan() {
}()
select {
case <-s.stop:
break
}
}

View File

@ -3,8 +3,9 @@ package main
import (
"flag"
"github.com/cnlh/nps/client"
"github.com/cnlh/nps/lib"
_ "github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/daemon"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/lib/common"
"strings"
)
@ -14,20 +15,21 @@ var (
serverAddr = flag.String("server", "", "服务器地址ip:端口")
verifyKey = flag.String("vkey", "", "验证密钥")
logType = flag.String("log", "stdout", "日志输出方式stdout|file")
connType = flag.String("type", "tcp", "与服务端建立连接方式kcp|tcp")
)
func main() {
flag.Parse()
lib.InitDaemon("npc")
daemon.InitDaemon("npc", common.GetRunPath(), common.GetPidPath())
if *logType == "stdout" {
lib.InitLogFile("npc", true)
lg.InitLogFile("npc", true, common.GetLogPath())
} else {
lib.InitLogFile("npc", false)
lg.InitLogFile("npc", false, common.GetLogPath())
}
stop := make(chan int)
for _, v := range strings.Split(*verifyKey, ",") {
lib.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v)
go client.NewRPClient(*serverAddr, v).Start()
lg.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v)
go client.NewRPClient(*serverAddr, v, *connType).Start()
}
<-stop
}

View File

@ -2,8 +2,12 @@ package main
import (
"flag"
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/beego"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/daemon"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/install"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/server"
_ "github.com/cnlh/nps/web/routers"
"log"
@ -28,58 +32,65 @@ var (
func main() {
flag.Parse()
if len(os.Args) > 1 && os.Args[1] == "test" {
server.TestServerConfig()
log.Println("test ok, no error")
return
if len(os.Args) > 1 {
switch os.Args[1] {
case "test":
server.TestServerConfig()
log.Println("test ok, no error")
return
case "start", "restart", "stop", "status":
daemon.InitDaemon("nps", common.GetRunPath(), common.GetPidPath())
case "install":
install.InstallNps()
return
}
}
lib.InitDaemon("nps")
if *logType == "stdout" {
lib.InitLogFile("nps", true)
lg.InitLogFile("nps", true, common.GetLogPath())
} else {
lib.InitLogFile("nps", false)
lg.InitLogFile("nps", false, common.GetLogPath())
}
task := &lib.Tunnel{
task := &file.Tunnel{
TcpPort: *httpPort,
Mode: *rpMode,
Target: *tunnelTarget,
Config: &lib.Config{
Config: &file.Config{
U: *u,
P: *p,
Compress: *compress,
Crypt: lib.GetBoolByStr(*crypt),
Crypt: common.GetBoolByStr(*crypt),
},
Flow: &lib.Flow{},
Flow: &file.Flow{},
UseClientCnf: false,
}
if *VerifyKey != "" {
c := &lib.Client{
c := &file.Client{
Id: 0,
VerifyKey: *VerifyKey,
Addr: "",
Remark: "",
Status: true,
IsConnect: false,
Cnf: &lib.Config{},
Flow: &lib.Flow{},
Cnf: &file.Config{},
Flow: &file.Flow{},
}
c.Cnf.CompressDecode, c.Cnf.CompressEncode = lib.GetCompressType(c.Cnf.Compress)
lib.GetCsvDb().Clients[0] = c
c.Cnf.CompressDecode, c.Cnf.CompressEncode = common.GetCompressType(c.Cnf.Compress)
file.GetCsvDb().Clients[0] = c
task.Client = c
}
if *TcpPort == 0 {
p, err := beego.AppConfig.Int("tcpport")
p, err := beego.AppConfig.Int("bridgePort")
if err == nil && *rpMode == "webServer" {
*TcpPort = p
} else {
*TcpPort = 8284
}
}
lib.Println("服务端启动监听tcp服务端端口", *TcpPort)
task.Config.CompressDecode, task.Config.CompressEncode = lib.GetCompressType(task.Config.Compress)
lg.Printf("服务端启动,监听%s服务端口%d", beego.AppConfig.String("bridgeType"), *TcpPort)
task.Config.CompressDecode, task.Config.CompressEncode = common.GetCompressType(task.Config.Compress)
if *rpMode != "webServer" {
lib.GetCsvDb().Tasks[0] = task
file.GetCsvDb().Tasks[0] = task
}
beego.LoadAppConfig("ini", filepath.Join(lib.GetRunPath(), "conf", "app.conf"))
server.StartNewServer(*TcpPort, task)
beego.LoadAppConfig("ini", filepath.Join(common.GetRunPath(), "conf", "app.conf"))
server.StartNewServer(*TcpPort, task, beego.AppConfig.String("bridgeType"))
}

View File

@ -1,28 +1,33 @@
appname = nps
#web管理端口
httpport = 8081
#Web Management Port
httpport = 8080
#启动模式dev|pro
#Boot mode(dev|pro)
runmode = dev
#web管理密码
#Web Management Password
password=123
##客户端与服务端通信端口
tcpport=8284
##Communication Port between Client and Server
##If the data transfer mode is tcp, it is TCP port
##If the data transfer mode is kcp, it is UDP port
bridgePort=8284
#web api免验证IP地址
#Web API unauthenticated IP address
authip=127.0.0.1
##http代理端口,为空则不启动
##HTTP proxy port, no startup if empty
httpProxyPort=80
##https代理端口,为空则不启动
##HTTPS proxy port, no startup if empty
httpsProxyPort=
##certFile绝对路径
##certFile absolute path
pemPath=/etc/nginx/certificate.crt
##keyFile绝对路径
keyPath=/etc/nginx/private.key
##KeyFile absolute path
keyPath=/etc/nginx/private.key
##Data transmission mode(kcp or tcp)
bridgeType=tcp

View File

@ -1 +1 @@
1,ydiigrm4ghu7mym1,,true,,,0,,0,0
1,ydiigrm4ghu7mym1,测试,true,,,0,,0,0

1 1 ydiigrm4ghu7mym1 测试 true 0 0 0

View File

@ -1 +1,2 @@
a.o.com,127.0.0.1:8081,1,,,测试
a.o.com,127.0.0.1:8080,1,,,测试
b.o.com,127.0.0.1:8082,1,,,

1 a.o.com 127.0.0.1:8081 127.0.0.1:8080 1 测试
2 b.o.com 127.0.0.1:8082 1

View File

@ -1,4 +1,4 @@
9001,tunnelServer,123.206.77.88:22,,,,1,0,0,0,1,1,true,测试tcp
53,udpServer,114.114.114.114:53,,,,1,0,0,0,2,1,true,udp
0,socks5Server,,,,,1,0,0,0,3,1,true,socks5
9005,httpProxyServer,,,,,1,0,0,0,4,1,true,
9002,socks5Server,,,,,1,0,0,0,3,1,true,socks5
9001,tunnelServer,127.0.0.1:8082,,,,1,0,0,0,1,1,true,测试tcp

1 9001 53 tunnelServer udpServer 123.206.77.88:22 114.114.114.114:53 1 0 0 0 2 1 1 true 测试tcp udp
9001 tunnelServer 123.206.77.88:22 1 0 0 0 1 1 true 测试tcp
1 53 53 udpServer udpServer 114.114.114.114:53 114.114.114.114:53 1 0 0 0 2 1 2 1 true udp udp
0 socks5Server 1 0 0 0 1 3 true socks5
2 9005 9005 httpProxyServer httpProxyServer 1 0 0 0 4 1 4 1 true
3 9002 socks5Server 1 0 0 0 3 1 true socks5
4 9001 tunnelServer 127.0.0.1:8082 1 0 0 0 1 1 true 测试tcp

28
lib/common/const.go Normal file
View File

@ -0,0 +1,28 @@
package common
const (
COMPRESS_NONE_ENCODE = iota
COMPRESS_NONE_DECODE
COMPRESS_SNAPY_ENCODE
COMPRESS_SNAPY_DECODE
VERIFY_EER = "vkey"
WORK_MAIN = "main"
WORK_CHAN = "chan"
RES_SIGN = "sign"
RES_MSG = "msg0"
RES_CLOSE = "clse"
NEW_CONN = "conn" //新连接标志
NEW_TASK = "task" //新连接标志
CONN_SUCCESS = "sucs"
CONN_TCP = "tcp"
CONN_UDP = "udp"
UnauthorizedBytes = `HTTP/1.1 401 Unauthorized
Content-Type: text/plain; charset=utf-8
WWW-Authenticate: Basic realm="easyProxy"
401 Unauthorized`
IO_EOF = "PROXYEOF"
ConnectionFailBytes = `HTTP/1.1 404 Not Found
`
)

67
lib/common/run.go Normal file
View File

@ -0,0 +1,67 @@
package common
import (
"os"
"path/filepath"
"runtime"
)
//Get the currently selected configuration file directory
//For non-Windows systems, select the /etc/nps as config directory if exist, or select ./
//windows system, select the C:\Program Files\nps as config directory if exist, or select ./
func GetRunPath() string {
var path string
if path = GetInstallPath(); !FileExists(path) {
return "./"
}
return path
}
//Different systems get different installation paths
func GetInstallPath() string {
var path string
if IsWindows() {
path = `C:\Program Files\nps`
} else {
path = "/etc/nps"
}
return path
}
//Get the absolute path to the running directory
func GetAppPath() string {
if path, err := filepath.Abs(filepath.Dir(os.Args[0])); err == nil {
return path
}
return os.Args[0]
}
//Determine whether the current system is a Windows system?
func IsWindows() bool {
if runtime.GOOS == "windows" {
return true
}
return false
}
//interface log file path
func GetLogPath() string {
var path string
if IsWindows() {
path = "./"
} else {
path = "/tmp"
}
return path
}
//interface pid file path
func GetPidPath() string {
var path string
if IsWindows() {
path = "./"
} else {
path = "/tmp"
}
return path
}

View File

@ -1,45 +1,21 @@
package lib
package common
import (
"bytes"
"encoding/base64"
"encoding/binary"
"github.com/cnlh/nps/lib/crypt"
"github.com/cnlh/nps/lib/lg"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
)
const (
COMPRESS_NONE_ENCODE = iota
COMPRESS_NONE_DECODE
COMPRESS_SNAPY_ENCODE
COMPRESS_SNAPY_DECODE
VERIFY_EER = "vkey"
WORK_MAIN = "main"
WORK_CHAN = "chan"
RES_SIGN = "sign"
RES_MSG = "msg0"
RES_CLOSE = "clse"
NEW_CONN = "conn" //新连接标志
CONN_SUCCESS = "sucs"
CONN_TCP = "tcp"
CONN_UDP = "udp"
UnauthorizedBytes = `HTTP/1.1 401 Unauthorized
Content-Type: text/plain; charset=utf-8
WWW-Authenticate: Basic realm="easyProxy"
401 Unauthorized`
IO_EOF = "PROXYEOF"
ConnectionFailBytes = `HTTP/1.1 404 Not Found
`
)
//判断压缩方式
//Judging Compression Mode
func GetCompressType(compress string) (int, int) {
switch compress {
case "":
@ -47,12 +23,12 @@ func GetCompressType(compress string) (int, int) {
case "snappy":
return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
default:
Fatalln("数据压缩格式错误")
lg.Fatalln("数据压缩格式错误")
}
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
}
//通过host获取对应的ip地址
//Get the corresponding IP address through domain name
func GetHostByName(hostname string) string {
if !DomainCheck(hostname) {
return hostname
@ -68,7 +44,7 @@ func GetHostByName(hostname string) string {
return ""
}
//检查是否是域名
//Check the legality of domain
func DomainCheck(domain string) bool {
var match bool
IsLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}(/)"
@ -80,7 +56,7 @@ func DomainCheck(domain string) bool {
return match
}
//检查basic认证
//Check if the Request request is validated
func CheckAuth(r *http.Request, user, passwd string) bool {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 {
@ -122,11 +98,12 @@ func GetIntNoErrByStr(str string) int {
return i
}
//简单的一个校验值
//Get verify value
func Getverifyval(vkey string) string {
return Md5(vkey)
return crypt.Md5(vkey)
}
//Change headers and host of request
func ChangeHostAndHeader(r *http.Request, host string, header string, addr string) {
if host != "" {
r.Host = host
@ -145,6 +122,7 @@ func ChangeHostAndHeader(r *http.Request, host string, header string, addr strin
r.Header.Set("X-Real-IP", addr)
}
//Read file content by file path
func ReadAllFromFile(filePath string) ([]byte, error) {
f, err := os.Open(filePath)
if err != nil {
@ -163,53 +141,7 @@ func FileExists(name string) bool {
return true
}
func GetRunPath() string {
var path string
if path = GetInstallPath(); !FileExists(path) {
return "./"
}
return path
}
func GetInstallPath() string {
var path string
if IsWindows() {
path = `C:\Program Files\nps`
} else {
path = "/etc/nps"
}
return path
}
func GetAppPath() string {
if path, err := filepath.Abs(filepath.Dir(os.Args[0])); err == nil {
return path
}
return os.Args[0]
}
func IsWindows() bool {
if runtime.GOOS == "windows" {
return true
}
return false
}
func GetLogPath() string {
var path string
if IsWindows() {
path = "./"
} else {
path = "/tmp"
}
return path
}
func GetPidPath() string {
var path string
if IsWindows() {
path = "./"
} else {
path = "/tmp"
}
return path
}
//Judge whether the TCP port can open normally
func TestTcpPort(port int) bool {
l, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), port, ""})
defer l.Close()
@ -218,3 +150,27 @@ func TestTcpPort(port int) bool {
}
return true
}
//Judge whether the UDP port can open normally
func TestUdpPort(port int) bool {
l, err := net.ListenUDP("udp", &net.UDPAddr{net.ParseIP("0.0.0.0"), port, ""})
defer l.Close()
if err != nil {
return false
}
return true
}
//Write length and individual byte data
//Length prevents sticking
//# Characters are used to separate data
func BinaryWrite(raw *bytes.Buffer, v ...string) {
buffer := new(bytes.Buffer)
var l int32
for _, v := range v {
l += int32(len([]byte(v))) + int32(len([]byte("#")))
binary.Write(buffer, binary.LittleEndian, []byte(v))
binary.Write(buffer, binary.LittleEndian, []byte("#"))
}
binary.Write(raw, binary.LittleEndian, buffer.Bytes())
}

View File

@ -1,11 +1,15 @@
package lib
package conn
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"github.com/golang/snappy"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/kcp"
"github.com/cnlh/nps/lib/pool"
"github.com/cnlh/nps/lib/rate"
"io"
"net"
"net/http"
@ -18,126 +22,6 @@ import (
const cryptKey = "1234567812345678"
type CryptConn struct {
conn net.Conn
crypt bool
rate *Rate
}
func NewCryptConn(conn net.Conn, crypt bool, rate *Rate) *CryptConn {
c := new(CryptConn)
c.conn = conn
c.crypt = crypt
c.rate = rate
return c
}
//加密写
func (s *CryptConn) Write(b []byte) (n int, err error) {
n = len(b)
if s.crypt {
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
return
}
}
if b, err = GetLenBytes(b); err != nil {
return
}
_, err = s.conn.Write(b)
if s.rate != nil {
s.rate.Get(int64(n))
}
return
}
//解密读
func (s *CryptConn) Read(b []byte) (n int, err error) {
var lens int
var buf []byte
var rb []byte
c := NewConn(s.conn)
if lens, err = c.GetLen(); err != nil {
return
}
if buf, err = c.ReadLen(lens); err != nil {
return
}
if s.crypt {
if rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
return
}
} else {
rb = buf
}
copy(b, rb)
n = len(rb)
if s.rate != nil {
s.rate.Get(int64(n))
}
return
}
type SnappyConn struct {
w *snappy.Writer
r *snappy.Reader
crypt bool
rate *Rate
}
func NewSnappyConn(conn net.Conn, crypt bool, rate *Rate) *SnappyConn {
c := new(SnappyConn)
c.w = snappy.NewBufferedWriter(conn)
c.r = snappy.NewReader(conn)
c.crypt = crypt
c.rate = rate
return c
}
//snappy压缩写 包含加密
func (s *SnappyConn) Write(b []byte) (n int, err error) {
n = len(b)
if s.crypt {
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
Println("encode crypt error:", err)
return
}
}
if _, err = s.w.Write(b); err != nil {
return
}
if err = s.w.Flush(); err != nil {
return
}
if s.rate != nil {
s.rate.Get(int64(n))
}
return
}
//snappy压缩读 包含解密
func (s *SnappyConn) Read(b []byte) (n int, err error) {
buf := BufPool.Get().([]byte)
defer BufPool.Put(buf)
if n, err = s.r.Read(buf); err != nil {
return
}
var bs []byte
if s.crypt {
if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
Println("decode crypt error:", err)
return
}
} else {
bs = buf[:n]
}
n = len(bs)
copy(b, bs)
if s.rate != nil {
s.rate.Get(int64(n))
}
return
}
type Conn struct {
Conn net.Conn
sync.Mutex
@ -186,16 +70,16 @@ func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.
//读取指定长度内容
func (s *Conn) ReadLen(cLen int) ([]byte, error) {
if cLen > poolSize {
if cLen > pool.PoolSize {
return nil, errors.New("长度错误" + strconv.Itoa(cLen))
}
var buf []byte
if cLen <= poolSizeSmall {
buf = BufPoolSmall.Get().([]byte)[:cLen]
defer BufPoolSmall.Put(buf)
if cLen <= pool.PoolSizeSmall {
buf = pool.BufPoolSmall.Get().([]byte)[:cLen]
defer pool.BufPoolSmall.Put(buf)
} else {
buf = BufPoolMax.Get().([]byte)[:cLen]
defer BufPoolMax.Put(buf)
buf = pool.BufPoolMax.Get().([]byte)[:cLen]
defer pool.BufPoolMax.Put(buf)
}
if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
return buf, errors.New("读取指定长度错误" + err.Error())
@ -231,35 +115,64 @@ func (s *Conn) GetConnStatus() (id int, status bool, err error) {
if b, err = s.ReadLen(1); err != nil {
return
} else {
status = GetBoolByStr(string(b[0]))
status = common.GetBoolByStr(string(b[0]))
}
return
}
//设置连接为长连接
func (s *Conn) SetAlive() {
func (s *Conn) SetAlive(tp string) {
if tp == "kcp" {
s.setKcpAlive()
} else {
s.setTcpAlive()
}
}
//设置连接为长连接
func (s *Conn) setTcpAlive() {
conn := s.Conn.(*net.TCPConn)
conn.SetReadDeadline(time.Time{})
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
}
//设置连接为长连接
func (s *Conn) setKcpAlive() {
conn := s.Conn.(*kcp.UDPSession)
conn.SetReadDeadline(time.Time{})
}
//设置连接为长连接
func (s *Conn) SetReadDeadline(t time.Duration, tp string) {
if tp == "kcp" {
s.SetKcpReadDeadline(t)
} else {
s.SetTcpReadDeadline(t)
}
}
//set read dead time
func (s *Conn) SetReadDeadline(t time.Duration) {
func (s *Conn) SetTcpReadDeadline(t time.Duration) {
s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
}
//set read dead time
func (s *Conn) SetKcpReadDeadline(t time.Duration) {
s.Conn.(*kcp.UDPSession).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
}
//单独读(加密|压缩)
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *Rate) (int, error) {
if COMPRESS_SNAPY_DECODE == compress {
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *rate.Rate) (int, error) {
if common.COMPRESS_SNAPY_DECODE == compress {
return NewSnappyConn(s.Conn, crypt, rate).Read(b)
}
return NewCryptConn(s.Conn, crypt, rate).Read(b)
}
//单独写(加密|压缩)
func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, err error) {
if COMPRESS_SNAPY_ENCODE == compress {
func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *rate.Rate) (n int, err error) {
if common.COMPRESS_SNAPY_ENCODE == compress {
return NewSnappyConn(s.Conn, crypt, rate).Write(b)
}
return NewCryptConn(s.Conn, crypt, rate).Write(b)
@ -292,7 +205,7 @@ func (s *Conn) SendMsg(content []byte, link *Link) (n int, err error) {
func (s *Conn) GetMsgContent(link *Link) (content []byte, err error) {
s.Lock()
defer s.Unlock()
buf := BufPoolCopy.Get().([]byte)
buf := pool.BufPoolCopy.Get().([]byte)
if n, err := s.ReadFrom(buf, link.De, link.Crypt, link.Rate); err == nil && n > 4 {
content = buf[:n]
}
@ -310,7 +223,7 @@ func (s *Conn) SendLinkInfo(link *Link) (int, error) {
+----------+------+----------+------+----+----+------+
*/
raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, []byte(NEW_CONN))
binary.Write(raw, binary.LittleEndian, []byte(common.NEW_CONN))
binary.Write(raw, binary.LittleEndian, int32(14+len(link.Host)))
binary.Write(raw, binary.LittleEndian, int32(link.Id))
binary.Write(raw, binary.LittleEndian, []byte(link.ConnType))
@ -318,13 +231,13 @@ func (s *Conn) SendLinkInfo(link *Link) (int, error) {
binary.Write(raw, binary.LittleEndian, []byte(link.Host))
binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.En)))
binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.De)))
binary.Write(raw, binary.LittleEndian, []byte(GetStrByBool(link.Crypt)))
binary.Write(raw, binary.LittleEndian, []byte(common.GetStrByBool(link.Crypt)))
s.Lock()
defer s.Unlock()
return s.Write(raw.Bytes())
}
func (s *Conn) GetLinkInfo() (link *Link, err error) {
func (s *Conn) GetLinkInfo() (lk *Link, err error) {
s.Lock()
defer s.Unlock()
var hostLen, n int
@ -332,21 +245,69 @@ func (s *Conn) GetLinkInfo() (link *Link, err error) {
if n, err = s.GetLen(); err != nil {
return
}
link = new(Link)
lk = new(Link)
if buf, err = s.ReadLen(n); err != nil {
return
}
if link.Id, err = GetLenByBytes(buf[:4]); err != nil {
if lk.Id, err = GetLenByBytes(buf[:4]); err != nil {
return
}
link.ConnType = string(buf[4:7])
lk.ConnType = string(buf[4:7])
if hostLen, err = GetLenByBytes(buf[7:11]); err != nil {
return
} else {
link.Host = string(buf[11 : 11+hostLen])
link.En = GetIntNoErrByStr(string(buf[11+hostLen]))
link.De = GetIntNoErrByStr(string(buf[12+hostLen]))
link.Crypt = GetBoolByStr(string(buf[13+hostLen]))
lk.Host = string(buf[11 : 11+hostLen])
lk.En = common.GetIntNoErrByStr(string(buf[11+hostLen]))
lk.De = common.GetIntNoErrByStr(string(buf[12+hostLen]))
lk.Crypt = common.GetBoolByStr(string(buf[13+hostLen]))
}
return
}
//send task info
//SendTaskInfo serializes a tunnel definition and writes it to the peer as a
//NEW_TASK message: a 4-byte type tag followed by "#"-joined fields (see
//common.BinaryWrite). The write is serialized under the connection lock.
func (s *Conn) SendTaskInfo(t *file.Tunnel) (int, error) {
    /*
        The task info is formed as follows:
        +----+-----+---------+
        |type| len | content |
        +----+---------------+
        | 4  |  4  |   ...   |
        +----+---------------+
    */
    raw := bytes.NewBuffer([]byte{})
    binary.Write(raw, binary.LittleEndian, []byte(common.NEW_TASK))
    // strconv.Itoa is required here: string(t.TcpPort) would yield the rune
    // with that code point (e.g. string(8080) is not "8080"), so the peer's
    // strconv.Atoi in GetTaskInfo would always fail. The string fields need
    // no conversion at all.
    common.BinaryWrite(raw, t.Mode, strconv.Itoa(t.TcpPort), t.Target, t.Config.U, t.Config.P, common.GetStrByBool(t.Config.Crypt), t.Config.Compress, t.Remark)
    s.Lock()
    defer s.Unlock()
    return s.Write(raw.Bytes())
}
//get task info
//GetTaskInfo reads one length-prefixed NEW_TASK payload from the connection
//and rebuilds a *file.Tunnel from its "#"-separated fields. A fresh task id
//is taken from the csv store and the task is attached to client id 0.
//On error the returned tunnel is meaningless and must not be used.
func (s *Conn) GetTaskInfo() (t *file.Tunnel, err error) {
    var l int
    var b []byte
    if l, err = s.GetLen(); err != nil {
        return
    } else if b, err = s.ReadLen(l); err != nil {
        return
    } else {
        // NOTE(review): assumes the peer sent at least 7 fields (the format
        // produced by SendTaskInfo); a malformed payload would panic on
        // indexing below — confirm whether a length guard is wanted.
        arr := strings.Split(string(b), "#")
        // The named return starts out nil; it must be allocated before any
        // field is written, otherwise every call panics with a nil dereference.
        t = new(file.Tunnel)
        t.Mode = arr[0]
        t.TcpPort, _ = strconv.Atoi(arr[1])
        t.Target = arr[2]
        t.Config = new(file.Config)
        t.Config.U = arr[3]
        t.Config.P = arr[4]
        t.Config.Compress = arr[5]
        // Was assigning CompressDecode twice, which left CompressEncode at
        // its zero value; GetCompressType returns (encode, decode).
        t.Config.CompressEncode, t.Config.CompressDecode = common.GetCompressType(arr[5])
        t.Id = file.GetCsvDb().GetTaskId()
        t.Status = true
        if t.Client, err = file.GetCsvDb().GetClient(0); err != nil {
            return
        }
        t.Flow = new(file.Flow)
        t.Remark = arr[6]
        t.UseClientCnf = false
    }
    return
}
@ -388,31 +349,31 @@ func (s *Conn) Read(b []byte) (int, error) {
//write error
func (s *Conn) WriteError() (int, error) {
return s.Write([]byte(RES_MSG))
return s.Write([]byte(common.RES_MSG))
}
//write sign flag
func (s *Conn) WriteSign() (int, error) {
return s.Write([]byte(RES_SIGN))
return s.Write([]byte(common.RES_SIGN))
}
//write sign flag
func (s *Conn) WriteClose() (int, error) {
return s.Write([]byte(RES_CLOSE))
return s.Write([]byte(common.RES_CLOSE))
}
//write main
func (s *Conn) WriteMain() (int, error) {
s.Lock()
defer s.Unlock()
return s.Write([]byte(WORK_MAIN))
return s.Write([]byte(common.WORK_MAIN))
}
//write chan
func (s *Conn) WriteChan() (int, error) {
s.Lock()
defer s.Unlock()
return s.Write([]byte(WORK_CHAN))
return s.Write([]byte(common.WORK_CHAN))
}
//获取长度+内容
@ -436,3 +397,13 @@ func GetLenByBytes(buf []byte) (int, error) {
}
return int(nlen), nil
}
// SetUdpSession applies the shared tuning profile to a KCP (UDP) session.
// Server and client both call this so their window/MTU settings agree.
func SetUdpSession(sess *kcp.UDPSession) {
    sess.SetStreamMode(true)       // byte-stream semantics instead of message-oriented
    sess.SetWindowSize(1024, 1024) // send and receive windows, in segments
    sess.SetReadBuffer(64 * 1024)  // socket read buffer, bytes
    sess.SetWriteBuffer(64 * 1024) // socket write buffer, bytes
    // presumably nodelay=on, 10ms interval, fast resend after 2 dup-acks,
    // congestion control off (kcp "fast mode") — confirm against kcp-go docs
    sess.SetNoDelay(1, 10, 2, 1)
    // NOTE(review): 1600 exceeds the common 1500-byte Ethernet MTU and may
    // cause IP fragmentation — confirm this is intended
    sess.SetMtu(1600)
    sess.SetACKNoDelay(true) // flush ACKs immediately rather than batching
}

37
lib/conn/link.go Normal file
View File

@ -0,0 +1,37 @@
package conn
import (
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/rate"
"net"
)
// Link describes a single proxied stream and the resources attached to it.
type Link struct {
    Id            int    // link id
    ConnType      string // connection type tag; 3 bytes on the wire — presumably "tcp"/"udp", confirm with callers
    Host          string // target address
    En            int    // compression mode applied when encoding outgoing data
    De            int    // compression mode applied when decoding incoming data (passed as the compress arg of ReadFrom)
    Crypt         bool   // whether payloads are AES-encrypted
    Conn          *Conn
    Flow          *file.Flow   // traffic counters for this link
    UdpListener   *net.UDPConn // set for udp links; used to reply to the remote peer
    Rate          *rate.Rate   // bandwidth limiter; may be nil
    UdpRemoteAddr *net.UDPAddr // remote udp address, when applicable
}
// NewLink assembles a Link from its individual attributes; no field is
// validated or defaulted — values are stored exactly as given.
func NewLink(id int, connType string, host string, en, de int, crypt bool, c *Conn, flow *file.Flow, udpListener *net.UDPConn, rate *rate.Rate, UdpRemoteAddr *net.UDPAddr) *Link {
    l := new(Link)
    l.Id = id
    l.ConnType = connType
    l.Host = host
    l.En = en
    l.De = de
    l.Crypt = crypt
    l.Conn = c
    l.Flow = flow
    l.UdpListener = udpListener
    l.Rate = rate
    l.UdpRemoteAddr = UdpRemoteAddr
    return l
}

66
lib/conn/normal.go Normal file
View File

@ -0,0 +1,66 @@
package conn
import (
"github.com/cnlh/nps/lib/crypt"
"github.com/cnlh/nps/lib/rate"
"net"
)
// CryptConn wraps a net.Conn with length-prefixed framing, optional AES
// encryption and optional bandwidth accounting.
type CryptConn struct {
    conn  net.Conn
    crypt bool       // encrypt on write / decrypt on read when true
    rate  *rate.Rate // bandwidth limiter; nil disables accounting
}
// NewCryptConn wraps conn with optional AES encryption (crypt) and an
// optional bandwidth limiter (rate may be nil).
func NewCryptConn(conn net.Conn, crypt bool, rate *rate.Rate) *CryptConn {
    return &CryptConn{
        conn:  conn,
        crypt: crypt,
        rate:  rate,
    }
}
// Write sends b over the wrapped connection as one length-prefixed frame,
// AES-encrypting the payload first when encryption is enabled. The returned
// count is the plaintext length of b, and the rate limiter (if any) is
// charged with that same amount.
func (s *CryptConn) Write(b []byte) (n int, err error) {
    n = len(b)
    payload := b
    if s.crypt {
        if payload, err = crypt.AesEncrypt(payload, []byte(cryptKey)); err != nil {
            return
        }
    }
    var frame []byte
    if frame, err = GetLenBytes(payload); err != nil {
        return
    }
    _, err = s.conn.Write(frame)
    if s.rate != nil {
        s.rate.Get(int64(n))
    }
    return
}
// Read receives one length-prefixed frame from the wrapped connection,
// decrypts it when encryption is enabled and copies the payload into b.
// n reports the full payload length even if b is shorter, so callers must
// supply a buffer large enough for one frame.
func (s *CryptConn) Read(b []byte) (n int, err error) {
    wrapped := NewConn(s.conn)
    var frameLen int
    if frameLen, err = wrapped.GetLen(); err != nil {
        return
    }
    var payload []byte
    if payload, err = wrapped.ReadLen(frameLen); err != nil {
        return
    }
    plain := payload
    if s.crypt {
        if plain, err = crypt.AesDecrypt(payload, []byte(cryptKey)); err != nil {
            return
        }
    }
    copy(b, plain)
    n = len(plain)
    if s.rate != nil {
        s.rate.Get(int64(n))
    }
    return
}

72
lib/conn/snappy.go Normal file
View File

@ -0,0 +1,72 @@
package conn
import (
"github.com/cnlh/nps/lib/crypt"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/lib/pool"
"github.com/cnlh/nps/lib/rate"
"github.com/cnlh/nps/lib/snappy"
"log"
"net"
)
// SnappyConn wraps a net.Conn with snappy stream compression, optional AES
// encryption and optional bandwidth accounting.
type SnappyConn struct {
    w     *snappy.Writer // compressing writer over the connection
    r     *snappy.Reader // decompressing reader over the connection
    crypt bool           // encrypt on write / decrypt on read when true
    rate  *rate.Rate     // bandwidth limiter; nil disables accounting
}
// NewSnappyConn layers snappy compression over conn, with optional AES
// encryption (crypt) and an optional bandwidth limiter (rate may be nil).
func NewSnappyConn(conn net.Conn, crypt bool, rate *rate.Rate) *SnappyConn {
    return &SnappyConn{
        w:     snappy.NewBufferedWriter(conn),
        r:     snappy.NewReader(conn),
        crypt: crypt,
        rate:  rate,
    }
}
// Write AES-encrypts b when enabled, compresses it onto the snappy stream
// and flushes the writer so the chunk is delivered immediately. The returned
// count is the plaintext length of b; the rate limiter, if any, is charged
// with that amount.
func (s *SnappyConn) Write(b []byte) (n int, err error) {
    n = len(b)
    payload := b
    if s.crypt {
        if payload, err = crypt.AesEncrypt(payload, []byte(cryptKey)); err != nil {
            lg.Println("encode crypt error:", err)
            return
        }
    }
    if _, err = s.w.Write(payload); err != nil {
        return
    }
    if err = s.w.Flush(); err != nil {
        return
    }
    if s.rate != nil {
        s.rate.Get(int64(n))
    }
    return
}
// Read decompresses one chunk from the snappy stream into a pooled scratch
// buffer, AES-decrypts it when enabled, and copies the result into b.
// n reports the decoded payload length; callers must ensure it fits in b.
func (s *SnappyConn) Read(b []byte) (n int, err error) {
    scratch := pool.BufPool.Get().([]byte)
    defer pool.BufPool.Put(scratch)
    if n, err = s.r.Read(scratch); err != nil {
        return
    }
    plain := scratch[:n]
    if s.crypt {
        // NOTE(review): uses log.Println while Write uses lg.Println — kept
        // as-is to preserve behaviour (and the file's "log" import), but the
        // inconsistency looks accidental.
        if plain, err = crypt.AesDecrypt(scratch[:n], []byte(cryptKey)); err != nil {
            log.Println("decode crypt error:", err)
            return
        }
    }
    n = len(plain)
    copy(b, plain)
    if s.rate != nil {
        s.rate.Get(int64(n))
    }
    return
}

View File

@ -1,4 +1,4 @@
package lib
package crypt
import (
"bytes"
@ -37,21 +37,19 @@ func AesDecrypt(crypted, key []byte) ([]byte, error) {
blockSize := block.BlockSize()
blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
origData := make([]byte, len(crypted))
// origData := crypted
blockMode.CryptBlocks(origData, crypted)
err, origData = PKCS5UnPadding(origData)
// origData = ZeroUnPadding(origData)
return origData, err
}
//补全
//Completion when the length is insufficient
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
//去补
//Remove excess
func PKCS5UnPadding(origData []byte) (error, []byte) {
length := len(origData)
// 去掉最后一个字节 unpadding 次
@ -62,14 +60,14 @@ func PKCS5UnPadding(origData []byte) (error, []byte) {
return nil, origData[:(length - unpadding)]
}
//生成32位md5字串
//Generate 32-bit MD5 strings
func Md5(s string) string {
h := md5.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
//生成随机验证密钥
//Generating Random Verification Key
func GetRandomString(l int) string {
str := "0123456789abcdefghijklmnopqrstuvwxyz"
bytes := []byte(str)

View File

@ -1,6 +1,7 @@
package lib
package daemon
import (
"github.com/cnlh/nps/lib/common"
"io/ioutil"
"log"
"os"
@ -10,7 +11,7 @@ import (
"strings"
)
func InitDaemon(f string) {
func InitDaemon(f string, runPath string, pidPath string) {
if len(os.Args) < 2 {
return
}
@ -22,22 +23,17 @@ func InitDaemon(f string) {
args = append(args, "-log=file")
switch os.Args[1] {
case "start":
start(args, f)
start(args, f, pidPath, runPath)
os.Exit(0)
case "stop":
stop(f, args[0])
stop(f, args[0], pidPath)
os.Exit(0)
case "restart":
stop(f, args[0])
start(args, f)
os.Exit(0)
case "install":
if f == "nps" {
InstallNps()
}
stop(f, args[0], pidPath)
start(args, f, pidPath, runPath)
os.Exit(0)
case "status":
if status(f) {
if status(f, pidPath) {
log.Printf("%s is running", f)
} else {
log.Printf("%s is not running", f)
@ -46,11 +42,11 @@ func InitDaemon(f string) {
}
}
func status(f string) bool {
func status(f string, pidPath string) bool {
var cmd *exec.Cmd
b, err := ioutil.ReadFile(filepath.Join(GetPidPath(), f+".pid"))
b, err := ioutil.ReadFile(filepath.Join(pidPath, f+".pid"))
if err == nil {
if !IsWindows() {
if !common.IsWindows() {
cmd = exec.Command("/bin/sh", "-c", "ps -ax | awk '{ print $1 }' | grep "+string(b))
} else {
cmd = exec.Command("tasklist", )
@ -63,38 +59,38 @@ func status(f string) bool {
return false
}
func start(osArgs []string, f string) {
if status(f) {
func start(osArgs []string, f string, pidPath, runPath string) {
if status(f, pidPath) {
log.Printf(" %s is running", f)
return
}
cmd := exec.Command(osArgs[0], osArgs[1:]...)
cmd.Start()
if cmd.Process.Pid > 0 {
log.Println("start ok , pid:", cmd.Process.Pid, "config path:", GetRunPath())
log.Println("start ok , pid:", cmd.Process.Pid, "config path:", runPath)
d1 := []byte(strconv.Itoa(cmd.Process.Pid))
ioutil.WriteFile(filepath.Join(GetPidPath(), f+".pid"), d1, 0600)
ioutil.WriteFile(filepath.Join(pidPath, f+".pid"), d1, 0600)
} else {
log.Println("start error")
}
}
func stop(f string, p string) {
if !status(f) {
func stop(f string, p string, pidPath string) {
if !status(f, pidPath) {
log.Printf(" %s is not running", f)
return
}
var c *exec.Cmd
var err error
if IsWindows() {
if common.IsWindows() {
p := strings.Split(p, `\`)
c = exec.Command("taskkill", "/F", "/IM", p[len(p)-1])
} else {
b, err := ioutil.ReadFile(filepath.Join(GetPidPath(), f+".pid"))
b, err := ioutil.ReadFile(filepath.Join(pidPath, f+".pid"))
if err == nil {
c = exec.Command("/bin/bash", "-c", `kill -9 `+string(b))
} else {
log.Fatalln("stop error,PID file does not exist")
log.Fatalln("stop error,pid file does not exist")
}
}
err = c.Run()

19
lib/file/csv.go Normal file
View File

@ -0,0 +1,19 @@
package file
import (
"github.com/cnlh/nps/lib/common"
"sync"
)
var (
    CsvDb *Csv      // process-wide singleton store, created lazily by GetCsvDb
    once  sync.Once // guards the one-time initialization of CsvDb
)
//init csv from file
// GetCsvDb returns the process-wide Csv store. The first call creates it
// rooted at the run path and loads its backing csv files; every later call
// returns the same instance.
func GetCsvDb() *Csv {
    once.Do(initCsvDb)
    return CsvDb
}

// initCsvDb builds the singleton and loads its data; run exactly once.
func initCsvDb() {
    CsvDb = NewCsv(common.GetRunPath())
    CsvDb.Init()
}

View File

@ -1,8 +1,11 @@
package lib
package file
import (
"encoding/csv"
"errors"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/lib/rate"
"os"
"path/filepath"
"strconv"
@ -10,13 +13,10 @@ import (
"sync"
)
var (
CsvDb *Csv
once sync.Once
)
func NewCsv() *Csv {
return new(Csv)
func NewCsv(runPath string) *Csv {
return &Csv{
RunPath: runPath,
}
}
type Csv struct {
@ -24,6 +24,7 @@ type Csv struct {
Path string
Hosts []*Host //域名列表
Clients []*Client //客户端
RunPath string //存储根目录
ClientIncreaseId int //客户端id
TaskIncreaseId int //任务自增ID
sync.Mutex
@ -37,9 +38,9 @@ func (s *Csv) Init() {
func (s *Csv) StoreTasksToCsv() {
// 创建文件
csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "tasks.csv"))
csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
if err != nil {
Fatalf(err.Error())
lg.Fatalf(err.Error())
}
defer csvFile.Close()
writer := csv.NewWriter(csvFile)
@ -51,8 +52,8 @@ func (s *Csv) StoreTasksToCsv() {
task.Config.U,
task.Config.P,
task.Config.Compress,
GetStrByBool(task.Status),
GetStrByBool(task.Config.Crypt),
common.GetStrByBool(task.Status),
common.GetStrByBool(task.Config.Crypt),
strconv.Itoa(task.Config.CompressEncode),
strconv.Itoa(task.Config.CompressDecode),
strconv.Itoa(task.Id),
@ -62,7 +63,7 @@ func (s *Csv) StoreTasksToCsv() {
}
err := writer.Write(record)
if err != nil {
Fatalf(err.Error())
lg.Fatalf(err.Error())
}
}
writer.Flush()
@ -87,33 +88,33 @@ func (s *Csv) openFile(path string) ([][]string, error) {
}
func (s *Csv) LoadTaskFromCsv() {
path := filepath.Join(GetRunPath(), "conf", "tasks.csv")
path := filepath.Join(s.RunPath, "conf", "tasks.csv")
records, err := s.openFile(path)
if err != nil {
Fatalln("配置文件打开错误:", path)
lg.Fatalln("配置文件打开错误:", path)
}
var tasks []*Tunnel
// 将每一行数据保存到内存slice中
for _, item := range records {
post := &Tunnel{
TcpPort: GetIntNoErrByStr(item[0]),
TcpPort: common.GetIntNoErrByStr(item[0]),
Mode: item[1],
Target: item[2],
Config: &Config{
U: item[3],
P: item[4],
Compress: item[5],
Crypt: GetBoolByStr(item[7]),
CompressEncode: GetIntNoErrByStr(item[8]),
CompressDecode: GetIntNoErrByStr(item[9]),
Crypt: common.GetBoolByStr(item[7]),
CompressEncode: common.GetIntNoErrByStr(item[8]),
CompressDecode: common.GetIntNoErrByStr(item[9]),
},
Status: GetBoolByStr(item[6]),
Id: GetIntNoErrByStr(item[10]),
UseClientCnf: GetBoolByStr(item[12]),
Status: common.GetBoolByStr(item[6]),
Id: common.GetIntNoErrByStr(item[10]),
UseClientCnf: common.GetBoolByStr(item[12]),
Remark: item[13],
}
post.Flow = new(Flow)
if post.Client, err = s.GetClient(GetIntNoErrByStr(item[11])); err != nil {
if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[11])); err != nil {
continue
}
tasks = append(tasks, post)
@ -135,7 +136,7 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
s.Lock()
defer s.Unlock()
for _, v := range s.Clients {
if Getverifyval(v.VerifyKey) == vKey && v.Status {
if common.Getverifyval(v.VerifyKey) == vKey && v.Status {
if arr := strings.Split(addr, ":"); len(arr) > 0 {
v.Addr = arr[0]
}
@ -186,7 +187,7 @@ func (s *Csv) GetTask(id int) (v *Tunnel, err error) {
func (s *Csv) StoreHostToCsv() {
// 创建文件
csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "hosts.csv"))
csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
if err != nil {
panic(err)
}
@ -214,24 +215,24 @@ func (s *Csv) StoreHostToCsv() {
}
func (s *Csv) LoadClientFromCsv() {
path := filepath.Join(GetRunPath(), "conf", "clients.csv")
path := filepath.Join(s.RunPath, "conf", "clients.csv")
records, err := s.openFile(path)
if err != nil {
Fatalln("配置文件打开错误:", path)
lg.Fatalln("配置文件打开错误:", path)
}
var clients []*Client
// 将每一行数据保存到内存slice中
for _, item := range records {
post := &Client{
Id: GetIntNoErrByStr(item[0]),
Id: common.GetIntNoErrByStr(item[0]),
VerifyKey: item[1],
Remark: item[2],
Status: GetBoolByStr(item[3]),
RateLimit: GetIntNoErrByStr(item[8]),
Status: common.GetBoolByStr(item[3]),
RateLimit: common.GetIntNoErrByStr(item[8]),
Cnf: &Config{
U: item[4],
P: item[5],
Crypt: GetBoolByStr(item[6]),
Crypt: common.GetBoolByStr(item[6]),
Compress: item[7],
},
}
@ -239,21 +240,21 @@ func (s *Csv) LoadClientFromCsv() {
s.ClientIncreaseId = post.Id
}
if post.RateLimit > 0 {
post.Rate = NewRate(int64(post.RateLimit * 1024))
post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
post.Rate.Start()
}
post.Flow = new(Flow)
post.Flow.FlowLimit = int64(GetIntNoErrByStr(item[9]))
post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
clients = append(clients, post)
}
s.Clients = clients
}
func (s *Csv) LoadHostFromCsv() {
path := filepath.Join(GetRunPath(), "conf", "hosts.csv")
path := filepath.Join(s.RunPath, "conf", "hosts.csv")
records, err := s.openFile(path)
if err != nil {
Fatalln("配置文件打开错误:", path)
lg.Fatalln("配置文件打开错误:", path)
}
var hosts []*Host
// 将每一行数据保存到内存slice中
@ -265,7 +266,7 @@ func (s *Csv) LoadHostFromCsv() {
HostChange: item[4],
Remark: item[5],
}
if post.Client, err = s.GetClient(GetIntNoErrByStr(item[2])); err != nil {
if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
continue
}
post.Flow = new(Flow)
@ -387,11 +388,12 @@ func (s *Csv) GetClient(id int) (v *Client, err error) {
err = errors.New("未找到")
return
}
func (s *Csv) StoreClientsToCsv() {
// 创建文件
csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "clients.csv"))
csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
if err != nil {
Fatalln(err.Error())
lg.Fatalln(err.Error())
}
defer csvFile.Close()
writer := csv.NewWriter(csvFile)
@ -403,24 +405,15 @@ func (s *Csv) StoreClientsToCsv() {
strconv.FormatBool(client.Status),
client.Cnf.U,
client.Cnf.P,
GetStrByBool(client.Cnf.Crypt),
common.GetStrByBool(client.Cnf.Crypt),
client.Cnf.Compress,
strconv.Itoa(client.RateLimit),
strconv.Itoa(int(client.Flow.FlowLimit)),
}
err := writer.Write(record)
if err != nil {
Fatalln(err.Error())
lg.Fatalln(err.Error())
}
}
writer.Flush()
}
//init csv from file
func GetCsvDb() *Csv {
once.Do(func() {
CsvDb = NewCsv()
CsvDb.Init()
})
return CsvDb
}

View File

@ -1,41 +1,11 @@
package lib
package file
import (
"net"
"github.com/cnlh/nps/lib/rate"
"strings"
"sync"
)
type Link struct {
Id int //id
ConnType string //连接类型
Host string //目标
En int //加密
De int //解密
Crypt bool //加密
Conn *Conn
Flow *Flow
UdpListener *net.UDPConn
Rate *Rate
UdpRemoteAddr *net.UDPAddr
}
func NewLink(id int, connType string, host string, en, de int, crypt bool, conn *Conn, flow *Flow, udpListener *net.UDPConn, rate *Rate, UdpRemoteAddr *net.UDPAddr) *Link {
return &Link{
Id: id,
ConnType: connType,
Host: host,
En: en,
De: de,
Crypt: crypt,
Conn: conn,
Flow: flow,
UdpListener: udpListener,
Rate: rate,
UdpRemoteAddr: UdpRemoteAddr,
}
}
type Flow struct {
ExportFlow int64 //出口流量
InletFlow int64 //入口流量
@ -52,15 +22,15 @@ func (s *Flow) Add(in, out int) {
type Client struct {
Cnf *Config
Id int //id
VerifyKey string //验证密钥
Addr string //客户端ip地址
Remark string //备注
Status bool //是否开启
IsConnect bool //是否连接
RateLimit int //速度限制 /kb
Flow *Flow //流量
Rate *Rate //速度控制
Id int //id
VerifyKey string //验证密钥
Addr string //客户端ip地址
Remark string //备注
Status bool //是否开启
IsConnect bool //是否连接
RateLimit int //速度限制 /kb
Flow *Flow //流量
Rate *rate.Rate //速度控制
id int
sync.RWMutex
}
@ -74,7 +44,7 @@ func (s *Client) GetId() int {
type Tunnel struct {
Id int //Id
TcpPort int //服务端与客户端通信端口
TcpPort int //服务端监听端口
Mode string //启动方式
Target string //目标
Status bool //是否开启

View File

@ -1,8 +1,9 @@
package lib
package install
import (
"errors"
"fmt"
"github.com/cnlh/nps/lib/common"
"io"
"log"
"os"
@ -11,22 +12,22 @@ import (
)
func InstallNps() {
path := GetInstallPath()
path := common.GetInstallPath()
MkidrDirAll(path, "conf", "web/static", "web/views")
//复制文件到对应目录
if err := CopyDir(filepath.Join(GetAppPath(), "web", "views"), filepath.Join(path, "web", "views")); err != nil {
if err := CopyDir(filepath.Join(common.GetAppPath(), "web", "views"), filepath.Join(path, "web", "views")); err != nil {
log.Fatalln(err)
}
if err := CopyDir(filepath.Join(GetAppPath(), "web", "static"), filepath.Join(path, "web", "static")); err != nil {
if err := CopyDir(filepath.Join(common.GetAppPath(), "web", "static"), filepath.Join(path, "web", "static")); err != nil {
log.Fatalln(err)
}
if err := CopyDir(filepath.Join(GetAppPath(), "conf"), filepath.Join(path, "conf")); err != nil {
if err := CopyDir(filepath.Join(common.GetAppPath(), "conf"), filepath.Join(path, "conf")); err != nil {
log.Fatalln(err)
}
if !IsWindows() {
if _, err := copyFile(filepath.Join(GetAppPath(), "nps"), "/usr/bin/nps"); err != nil {
if _, err := copyFile(filepath.Join(GetAppPath(), "nps"), "/usr/local/bin/nps"); err != nil {
if !common.IsWindows() {
if _, err := copyFile(filepath.Join(common.GetAppPath(), "nps"), "/usr/bin/nps"); err != nil {
if _, err := copyFile(filepath.Join(common.GetAppPath(), "nps"), "/usr/local/bin/nps"); err != nil {
log.Fatalln(err)
} else {
os.Chmod("/usr/local/bin/nps", 0777)
@ -41,7 +42,7 @@ func InstallNps() {
log.Println("install ok!")
log.Println("Static files and configuration files in the current directory will be useless")
log.Println("The new configuration file is located in", path, "you can edit them")
if !IsWindows() {
if !common.IsWindows() {
log.Println("You can start with nps test|start|stop|restart|status anywhere")
} else {
log.Println("You can copy executable files to any directory and start working with nps.exe test|start|stop|restart|status")

785
lib/kcp/crypt.go Normal file
View File

@ -0,0 +1,785 @@
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/sha1"
"github.com/templexxx/xor"
"github.com/tjfoc/gmsm/sm4"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/salsa20"
"golang.org/x/crypto/tea"
"golang.org/x/crypto/twofish"
"golang.org/x/crypto/xtea"
)
var (
    // initialVector seeds the per-packet chaining in encrypt()/decrypt();
    // it is a fixed, non-secret constant.
    initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107}
    // saltxor is the pbkdf2 salt used by NewSimpleXORBlockCrypt's key expansion.
    saltxor = `sH3CIVoF#rWLtJo6`
)
// BlockCrypt defines encryption/decryption methods for a given byte slice.
// Notes on implementing: the data to be encrypted contains a builtin
// nonce at the first 16 bytes
type BlockCrypt interface {
    // Encrypt encrypts the whole block in src into dst.
    // Dst and src may point at the same memory.
    Encrypt(dst, src []byte)

    // Decrypt decrypts the whole block in src into dst.
    // Dst and src may point at the same memory.
    Decrypt(dst, src []byte)
}

// salsa20BlockCrypt implements BlockCrypt with the salsa20 stream cipher,
// keyed by a fixed 32-byte key.
type salsa20BlockCrypt struct {
    key [32]byte
}

// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20
func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(salsa20BlockCrypt)
    copy(c.key[:], key)
    return c, nil
}

// Encrypt XORs src[8:] with the keystream derived from the leading 8 bytes
// (used as the nonce, copied through unmodified). XOR is symmetric, so
// Decrypt performs the identical operation.
func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) {
    salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
    copy(dst[:8], src[:8])
}
func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) {
    salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
    copy(dst[:8], src[:8])
}
// sm4BlockCrypt wraps the SM4 block cipher. encbuf/decbuf are per-instance
// scratch buffers handed to encrypt()/decrypt() (decryption needs two
// blocks of scratch for its ping-pong chaining).
type sm4BlockCrypt struct {
    encbuf [sm4.BlockSize]byte
    decbuf [2 * sm4.BlockSize]byte
    block  cipher.Block
}

// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4
func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(sm4BlockCrypt)
    block, err := sm4.NewCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// twofishBlockCrypt wraps the Twofish block cipher; encbuf/decbuf are
// per-instance scratch buffers for encrypt()/decrypt().
type twofishBlockCrypt struct {
    encbuf [twofish.BlockSize]byte
    decbuf [2 * twofish.BlockSize]byte
    block  cipher.Block
}

// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish
func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(twofishBlockCrypt)
    block, err := twofish.NewCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// tripleDESBlockCrypt wraps Triple-DES; encbuf/decbuf are per-instance
// scratch buffers for encrypt()/decrypt().
type tripleDESBlockCrypt struct {
    encbuf [des.BlockSize]byte
    decbuf [2 * des.BlockSize]byte
    block  cipher.Block
}

// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES
func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(tripleDESBlockCrypt)
    block, err := des.NewTripleDESCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// cast5BlockCrypt wraps CAST-128; encbuf/decbuf are per-instance scratch
// buffers for encrypt()/decrypt().
type cast5BlockCrypt struct {
    encbuf [cast5.BlockSize]byte
    decbuf [2 * cast5.BlockSize]byte
    block  cipher.Block
}

// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128
func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(cast5BlockCrypt)
    block, err := cast5.NewCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// blowfishBlockCrypt wraps Blowfish; encbuf/decbuf are per-instance scratch
// buffers for encrypt()/decrypt().
type blowfishBlockCrypt struct {
    encbuf [blowfish.BlockSize]byte
    decbuf [2 * blowfish.BlockSize]byte
    block  cipher.Block
}

// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher)
func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(blowfishBlockCrypt)
    block, err := blowfish.NewCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// aesBlockCrypt wraps AES; encbuf/decbuf are per-instance scratch buffers
// for encrypt()/decrypt().
type aesBlockCrypt struct {
    encbuf [aes.BlockSize]byte
    decbuf [2 * aes.BlockSize]byte
    block  cipher.Block
}

// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
func NewAESBlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(aesBlockCrypt)
    block, err := aes.NewCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// teaBlockCrypt wraps TEA with 16 rounds; encbuf/decbuf are per-instance
// scratch buffers for encrypt()/decrypt().
type teaBlockCrypt struct {
    encbuf [tea.BlockSize]byte
    decbuf [2 * tea.BlockSize]byte
    block  cipher.Block
}

// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm
func NewTEABlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(teaBlockCrypt)
    block, err := tea.NewCipherWithRounds(key, 16)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// xteaBlockCrypt wraps XTEA; encbuf/decbuf are per-instance scratch buffers
// for encrypt()/decrypt().
type xteaBlockCrypt struct {
    encbuf [xtea.BlockSize]byte
    decbuf [2 * xtea.BlockSize]byte
    block  cipher.Block
}

// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA
func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(xteaBlockCrypt)
    block, err := xtea.NewCipher(key)
    if err != nil {
        return nil, err
    }
    c.block = block
    return c, nil
}

func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
// simpleXORBlockCrypt XORs packets against a fixed table expanded from the
// key with pbkdf2; obfuscation only, not real encryption.
type simpleXORBlockCrypt struct {
    xortbl []byte
}

// NewSimpleXORBlockCrypt simple xor with key expanding
func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) {
    c := new(simpleXORBlockCrypt)
    c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New)
    return c, nil
}

// XOR is symmetric, so Encrypt and Decrypt are the same operation.
func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
// noneBlockCrypt is the no-op implementation: data passes through unchanged.
type noneBlockCrypt struct{}

// NewNoneBlockCrypt does nothing but copying
func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) {
    return new(noneBlockCrypt), nil
}

func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) }
func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) }
// packet encryption with local CFB mode
// encrypt dispatches on the cipher's block size so the two common sizes
// (8 and 16 bytes) take the unrolled fast paths; buf is caller-supplied
// scratch of at least one block.
func encrypt(block cipher.Block, dst, src, buf []byte) {
    switch block.BlockSize() {
    case 8:
        encrypt8(block, dst, src, buf)
    case 16:
        encrypt16(block, dst, src, buf)
    default:
        encryptVariant(block, dst, src, buf)
    }
}
// optimized encryption for the ciphers which works in 8-bytes
//
// Chaining scheme: tbl starts as E(initialVector); each 8-byte block of dst
// is src XOR tbl, and tbl is then re-derived by encrypting the ciphertext
// block just produced. The main loop is unrolled 8x (64 bytes/iteration);
// the fallthrough switch drains the remaining 0-7 whole blocks, and the
// final BytesSrc0 XORs the sub-block tail.
func encrypt8(block cipher.Block, dst, src, buf []byte) {
    tbl := buf[:8]
    block.Encrypt(tbl, initialVector)
    n := len(src) / 8
    base := 0
    repeat := n / 8
    left := n % 8
    for i := 0; i < repeat; i++ {
        s := src[base:][0:64]
        d := dst[base:][0:64]
        // 1
        xor.BytesSrc1(d[0:8], s[0:8], tbl)
        block.Encrypt(tbl, d[0:8])
        // 2
        xor.BytesSrc1(d[8:16], s[8:16], tbl)
        block.Encrypt(tbl, d[8:16])
        // 3
        xor.BytesSrc1(d[16:24], s[16:24], tbl)
        block.Encrypt(tbl, d[16:24])
        // 4
        xor.BytesSrc1(d[24:32], s[24:32], tbl)
        block.Encrypt(tbl, d[24:32])
        // 5
        xor.BytesSrc1(d[32:40], s[32:40], tbl)
        block.Encrypt(tbl, d[32:40])
        // 6
        xor.BytesSrc1(d[40:48], s[40:48], tbl)
        block.Encrypt(tbl, d[40:48])
        // 7
        xor.BytesSrc1(d[48:56], s[48:56], tbl)
        block.Encrypt(tbl, d[48:56])
        // 8
        xor.BytesSrc1(d[56:64], s[56:64], tbl)
        block.Encrypt(tbl, d[56:64])
        base += 64
    }
    // drain the remaining whole blocks via fallthrough
    switch left {
    case 7:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 6:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 5:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 4:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 3:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 2:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 1:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 8
        fallthrough
    case 0:
        // tail shorter than one block: XOR-only, no further chaining
        xor.BytesSrc0(dst[base:], src[base:], tbl)
    }
}
// optimized encryption for the ciphers which works in 16-bytes
//
// Same chaining scheme as encrypt8, with a 16-byte block size: the loop is
// unrolled 8x (128 bytes/iteration), the fallthrough switch drains the
// remaining whole blocks, and BytesSrc0 XORs the sub-block tail.
func encrypt16(block cipher.Block, dst, src, buf []byte) {
    tbl := buf[:16]
    block.Encrypt(tbl, initialVector)
    n := len(src) / 16
    base := 0
    repeat := n / 8
    left := n % 8
    for i := 0; i < repeat; i++ {
        s := src[base:][0:128]
        d := dst[base:][0:128]
        // 1
        xor.BytesSrc1(d[0:16], s[0:16], tbl)
        block.Encrypt(tbl, d[0:16])
        // 2
        xor.BytesSrc1(d[16:32], s[16:32], tbl)
        block.Encrypt(tbl, d[16:32])
        // 3
        xor.BytesSrc1(d[32:48], s[32:48], tbl)
        block.Encrypt(tbl, d[32:48])
        // 4
        xor.BytesSrc1(d[48:64], s[48:64], tbl)
        block.Encrypt(tbl, d[48:64])
        // 5
        xor.BytesSrc1(d[64:80], s[64:80], tbl)
        block.Encrypt(tbl, d[64:80])
        // 6
        xor.BytesSrc1(d[80:96], s[80:96], tbl)
        block.Encrypt(tbl, d[80:96])
        // 7
        xor.BytesSrc1(d[96:112], s[96:112], tbl)
        block.Encrypt(tbl, d[96:112])
        // 8
        xor.BytesSrc1(d[112:128], s[112:128], tbl)
        block.Encrypt(tbl, d[112:128])
        base += 128
    }
    // drain the remaining whole blocks via fallthrough
    switch left {
    case 7:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 6:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 5:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 4:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 3:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 2:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 1:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += 16
        fallthrough
    case 0:
        // tail shorter than one block: XOR-only, no further chaining
        xor.BytesSrc0(dst[base:], src[base:], tbl)
    }
}
// encryptVariant is the generic-block-size fallback of encrypt8/encrypt16:
// the same XOR-then-encrypt chaining, unrolled 8x with a runtime blocksize.
func encryptVariant(block cipher.Block, dst, src, buf []byte) {
    blocksize := block.BlockSize()
    tbl := buf[:blocksize]
    block.Encrypt(tbl, initialVector)
    n := len(src) / blocksize
    base := 0
    repeat := n / 8
    left := n % 8
    for i := 0; i < repeat; i++ {
        // 1
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 2
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 3
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 4
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 5
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 6
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 7
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        // 8
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
    }
    // drain the remaining whole blocks via fallthrough
    switch left {
    case 7:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 6:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 5:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 4:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 3:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 2:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 1:
        xor.BytesSrc1(dst[base:], src[base:], tbl)
        block.Encrypt(tbl, dst[base:])
        base += blocksize
        fallthrough
    case 0:
        // tail shorter than one block: XOR-only, no further chaining
        xor.BytesSrc0(dst[base:], src[base:], tbl)
    }
}
// decrypt reverses the CFB-style transform produced by encrypt. It
// dispatches to a size-specialized implementation for the two common
// cipher block sizes (8 bytes, e.g. TEA/DES-class; 16 bytes, e.g.
// AES-class) and falls back to the generic variant for anything else.
func decrypt(block cipher.Block, dst, src, buf []byte) {
	bs := block.BlockSize()
	if bs == 8 {
		decrypt8(block, dst, src, buf)
		return
	}
	if bs == 16 {
		decrypt16(block, dst, src, buf)
		return
	}
	decryptVariant(block, dst, src, buf)
}
// decrypt8 is the CFB-style decryption path specialized for 8-byte cipher
// blocks. For every ciphertext block, the keystream for the NEXT block is
// computed by encrypting the ciphertext BEFORE the XOR writes plaintext
// into dst — that ordering is what makes in-place operation (dst == src)
// safe. Two scratch blocks from buf (tbl, next) are ping-ponged so no
// extra copies are needed. The main loop is unrolled to 8 blocks (64
// bytes) per iteration; the switch+fallthrough ladder handles the
// remaining 0-7 whole blocks, and xor.BytesSrc0 XORs the final partial
// block with the keystream.
func decrypt8(block cipher.Block, dst, src, buf []byte) {
	tbl := buf[0:8]
	next := buf[8:16]
	block.Encrypt(tbl, initialVector)
	n := len(src) / 8
	base := 0
	repeat := n / 8
	left := n % 8
	for i := 0; i < repeat; i++ {
		// re-slice to a fixed 64-byte window (helps bounds-check elimination)
		s := src[base:][0:64]
		d := dst[base:][0:64]
		// 1
		block.Encrypt(next, s[0:8])
		xor.BytesSrc1(d[0:8], s[0:8], tbl)
		// 2
		block.Encrypt(tbl, s[8:16])
		xor.BytesSrc1(d[8:16], s[8:16], next)
		// 3
		block.Encrypt(next, s[16:24])
		xor.BytesSrc1(d[16:24], s[16:24], tbl)
		// 4
		block.Encrypt(tbl, s[24:32])
		xor.BytesSrc1(d[24:32], s[24:32], next)
		// 5
		block.Encrypt(next, s[32:40])
		xor.BytesSrc1(d[32:40], s[32:40], tbl)
		// 6
		block.Encrypt(tbl, s[40:48])
		xor.BytesSrc1(d[40:48], s[40:48], next)
		// 7
		block.Encrypt(next, s[48:56])
		xor.BytesSrc1(d[48:56], s[48:56], tbl)
		// 8
		block.Encrypt(tbl, s[56:64])
		xor.BytesSrc1(d[56:64], s[56:64], next)
		base += 64
	}
	// tail: the explicit tbl/next swap keeps the keystream alternation
	// consistent with the unrolled loop above
	switch left {
	case 7:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 6:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 5:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 4:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 3:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 2:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 1:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 8
		fallthrough
	case 0:
		xor.BytesSrc0(dst[base:], src[base:], tbl)
	}
}
// decrypt16 is the CFB-style decryption path specialized for 16-byte
// cipher blocks (AES-class). Structure mirrors decrypt8: the keystream
// for the next block is derived from the ciphertext before the XOR writes
// plaintext to dst (so in-place dst == src is safe), two scratch blocks
// (tbl, next) from buf alternate as keystream holders, the main loop is
// unrolled to 8 blocks (128 bytes) per iteration, and the
// switch+fallthrough ladder plus xor.BytesSrc0 handle the tail.
func decrypt16(block cipher.Block, dst, src, buf []byte) {
	tbl := buf[0:16]
	next := buf[16:32]
	block.Encrypt(tbl, initialVector)
	n := len(src) / 16
	base := 0
	repeat := n / 8
	left := n % 8
	for i := 0; i < repeat; i++ {
		// fixed 128-byte window (helps bounds-check elimination)
		s := src[base:][0:128]
		d := dst[base:][0:128]
		// 1
		block.Encrypt(next, s[0:16])
		xor.BytesSrc1(d[0:16], s[0:16], tbl)
		// 2
		block.Encrypt(tbl, s[16:32])
		xor.BytesSrc1(d[16:32], s[16:32], next)
		// 3
		block.Encrypt(next, s[32:48])
		xor.BytesSrc1(d[32:48], s[32:48], tbl)
		// 4
		block.Encrypt(tbl, s[48:64])
		xor.BytesSrc1(d[48:64], s[48:64], next)
		// 5
		block.Encrypt(next, s[64:80])
		xor.BytesSrc1(d[64:80], s[64:80], tbl)
		// 6
		block.Encrypt(tbl, s[80:96])
		xor.BytesSrc1(d[80:96], s[80:96], next)
		// 7
		block.Encrypt(next, s[96:112])
		xor.BytesSrc1(d[96:112], s[96:112], tbl)
		// 8
		block.Encrypt(tbl, s[112:128])
		xor.BytesSrc1(d[112:128], s[112:128], next)
		base += 128
	}
	// tail: explicit tbl/next swap keeps the alternation consistent
	switch left {
	case 7:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 6:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 5:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 4:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 3:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 2:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 1:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += 16
		fallthrough
	case 0:
		xor.BytesSrc0(dst[base:], src[base:], tbl)
	}
}
// decryptVariant is the generic CFB-style decryption path for arbitrary
// cipher block sizes (anything but the specialized 8/16-byte paths).
// As in decrypt8/decrypt16: the next keystream block is computed from the
// ciphertext before the XOR writes plaintext (in-place safe), two scratch
// blocks (tbl, next) alternate, the main loop is unrolled 8x, and the
// switch+fallthrough ladder plus xor.BytesSrc0 handle the remaining
// whole blocks and the trailing partial block.
func decryptVariant(block cipher.Block, dst, src, buf []byte) {
	blocksize := block.BlockSize()
	tbl := buf[:blocksize]
	next := buf[blocksize:]
	block.Encrypt(tbl, initialVector)
	n := len(src) / blocksize
	base := 0
	repeat := n / 8
	left := n % 8
	for i := 0; i < repeat; i++ {
		// 1
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		base += blocksize
		// 2
		block.Encrypt(tbl, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], next)
		base += blocksize
		// 3
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		base += blocksize
		// 4
		block.Encrypt(tbl, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], next)
		base += blocksize
		// 5
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		base += blocksize
		// 6
		block.Encrypt(tbl, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], next)
		base += blocksize
		// 7
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		base += blocksize
		// 8
		block.Encrypt(tbl, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], next)
		base += blocksize
	}
	// tail: explicit tbl/next swap keeps the alternation consistent
	switch left {
	case 7:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 6:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 5:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 4:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 3:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 2:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 1:
		block.Encrypt(next, src[base:])
		xor.BytesSrc1(dst[base:], src[base:], tbl)
		tbl, next = next, tbl
		base += blocksize
		fallthrough
	case 0:
		xor.BytesSrc0(dst[base:], src[base:], tbl)
	}
}

289
lib/kcp/crypt_test.go Normal file
View File

@ -0,0 +1,289 @@
package kcp
import (
"bytes"
"crypto/aes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"hash/crc32"
"io"
"testing"
)
func TestSM4(t *testing.T) {
bc, err := NewSM4BlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestAES(t *testing.T) {
bc, err := NewAESBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestTEA(t *testing.T) {
bc, err := NewTEABlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestXOR(t *testing.T) {
bc, err := NewSimpleXORBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestBlowfish(t *testing.T) {
bc, err := NewBlowfishBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestNone(t *testing.T) {
bc, err := NewNoneBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestCast5(t *testing.T) {
bc, err := NewCast5BlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func Test3DES(t *testing.T) {
bc, err := NewTripleDESBlockCrypt(pass[:24])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestTwofish(t *testing.T) {
bc, err := NewTwofishBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestXTEA(t *testing.T) {
bc, err := NewXTEABlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestSalsa20(t *testing.T) {
bc, err := NewSalsa20BlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
// cryptTest round-trips one MTU-sized random buffer through bc and fails
// the test if decryption does not reproduce the original plaintext.
func cryptTest(t *testing.T, bc BlockCrypt) {
	plain := make([]byte, mtuLimit)
	io.ReadFull(rand.Reader, plain)
	ciphered := make([]byte, mtuLimit)
	bc.Encrypt(ciphered, plain)
	recovered := make([]byte, mtuLimit)
	bc.Decrypt(recovered, ciphered)
	if !bytes.Equal(plain, recovered) {
		t.Fail()
	}
}
func BenchmarkSM4(b *testing.B) {
bc, err := NewSM4BlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES128(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES192(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:24])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES256(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkTEA(b *testing.B) {
bc, err := NewTEABlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkXOR(b *testing.B) {
bc, err := NewSimpleXORBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkBlowfish(b *testing.B) {
bc, err := NewBlowfishBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkNone(b *testing.B) {
bc, err := NewNoneBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkCast5(b *testing.B) {
bc, err := NewCast5BlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func Benchmark3DES(b *testing.B) {
bc, err := NewTripleDESBlockCrypt(pass[:24])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkTwofish(b *testing.B) {
bc, err := NewTwofishBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkXTEA(b *testing.B) {
bc, err := NewXTEABlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkSalsa20(b *testing.B) {
bc, err := NewSalsa20BlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
// benchCrypt measures one encrypt+decrypt round-trip of an MTU-sized
// random buffer per iteration; SetBytes counts both directions.
func benchCrypt(b *testing.B, bc BlockCrypt) {
	plain := make([]byte, mtuLimit)
	io.ReadFull(rand.Reader, plain)
	ciphered := make([]byte, mtuLimit)
	recovered := make([]byte, mtuLimit)
	b.ReportAllocs()
	b.SetBytes(int64(2 * len(ciphered)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bc.Encrypt(ciphered, plain)
		bc.Decrypt(recovered, ciphered)
	}
}
func BenchmarkCRC32(b *testing.B) {
content := make([]byte, 1024)
b.SetBytes(int64(len(content)))
for i := 0; i < b.N; i++ {
crc32.ChecksumIEEE(content)
}
}
func BenchmarkCsprngSystem(b *testing.B) {
data := make([]byte, md5.Size)
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
io.ReadFull(rand.Reader, data)
}
}
func BenchmarkCsprngMD5(b *testing.B) {
var data [md5.Size]byte
b.SetBytes(md5.Size)
for i := 0; i < b.N; i++ {
data = md5.Sum(data[:])
}
}
func BenchmarkCsprngSHA1(b *testing.B) {
var data [sha1.Size]byte
b.SetBytes(sha1.Size)
for i := 0; i < b.N; i++ {
data = sha1.Sum(data[:])
}
}
func BenchmarkCsprngNonceMD5(b *testing.B) {
var ng nonceMD5
ng.Init()
b.SetBytes(md5.Size)
data := make([]byte, md5.Size)
for i := 0; i < b.N; i++ {
ng.Fill(data)
}
}
func BenchmarkCsprngNonceAES128(b *testing.B) {
var ng nonceAES128
ng.Init()
b.SetBytes(aes.BlockSize)
data := make([]byte, aes.BlockSize)
for i := 0; i < b.N; i++ {
ng.Fill(data)
}
}

52
lib/kcp/entropy.go Normal file
View File

@ -0,0 +1,52 @@
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"io"
)
// Entropy defines an entropy source used to fill per-packet nonces.
type Entropy interface {
	Init()             // one-time initialization of the generator state
	Fill(nonce []byte) // fill nonce with the next pseudo-random bytes
}
// nonceMD5 nonce generator for packet header.
// It produces a stream by repeatedly MD5-hashing its own state, mixing in
// fresh entropy from crypto/rand whenever the first state byte happens to
// be zero (on average once every 256 fills).
type nonceMD5 struct {
	seed [md5.Size]byte // current hash-chain state
}

func (n *nonceMD5) Init() { /*nothing required*/ }

// Fill writes up to md5.Size nonce bytes into nonce (shorter nonces get a
// prefix of the digest).
// NOTE(review): the rand.Reader error is ignored; on failure the previous
// seed is reused — presumably acceptable for nonce (not key) material.
func (n *nonceMD5) Fill(nonce []byte) {
	if n.seed[0] == 0 { // entropy update
		io.ReadFull(rand.Reader, n.seed[:])
	}
	n.seed = md5.Sum(n.seed[:])
	copy(nonce, n.seed[:])
}
// nonceAES128 nonce generator for packet headers.
// It iterates an AES-128 permutation over its own state (seed = E(seed)),
// keyed once at Init with random key material, and re-seeds from
// crypto/rand whenever the first state byte happens to be zero
// (on average once every 256 fills).
type nonceAES128 struct {
	seed  [aes.BlockSize]byte // current permutation state
	block cipher.Block        // AES-128 cipher fixed at Init
}

// Init draws a random key and initial seed.
// NOTE(review): errors from rand.Reader and aes.NewCipher are ignored;
// aes.NewCipher cannot fail for a 16-byte key, but a rand failure would
// leave the state zero-initialized.
func (n *nonceAES128) Init() {
	var key [16]byte //aes-128
	io.ReadFull(rand.Reader, key[:])
	io.ReadFull(rand.Reader, n.seed[:])
	block, _ := aes.NewCipher(key[:])
	n.block = block
}

// Fill writes up to aes.BlockSize nonce bytes into nonce.
func (n *nonceAES128) Fill(nonce []byte) {
	if n.seed[0] == 0 { // entropy update
		io.ReadFull(rand.Reader, n.seed[:])
	}
	n.block.Encrypt(n.seed[:], n.seed[:])
	copy(nonce, n.seed[:])
}

311
lib/kcp/fec.go Normal file
View File

@ -0,0 +1,311 @@
package kcp
import (
"encoding/binary"
"sync/atomic"
"github.com/klauspost/reedsolomon"
)
const (
fecHeaderSize = 6
fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
typeData = 0xf1
typeFEC = 0xf2
)
type (
// fecPacket is a decoded FEC packet
fecPacket struct {
seqid uint32
flag uint16
data []byte
}
// fecDecoder for decoding incoming packets
fecDecoder struct {
rxlimit int // queue size limit
dataShards int
parityShards int
shardSize int
rx []fecPacket // ordered receive queue
// caches
decodeCache [][]byte
flagCache []bool
// zeros
zeros []byte
// RS decoder
codec reedsolomon.Encoder
}
)
// newFECDecoder creates a decoder that keeps at most rxlimit ordered
// packets while reassembling shard groups of dataShards+parityShards.
// It returns nil when the shard configuration is unusable (non-positive
// counts, or a queue limit smaller than one shard group) or when the
// Reed-Solomon codec cannot be constructed.
func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
	if dataShards <= 0 || parityShards <= 0 || rxlimit < dataShards+parityShards {
		return nil
	}
	codec, err := reedsolomon.New(dataShards, parityShards)
	if err != nil {
		return nil
	}
	dec := &fecDecoder{
		rxlimit:      rxlimit,
		dataShards:   dataShards,
		parityShards: parityShards,
		shardSize:    dataShards + parityShards,
		codec:        codec,
	}
	dec.decodeCache = make([][]byte, dec.shardSize)
	dec.flagCache = make([]bool, dec.shardSize)
	dec.zeros = make([]byte, mtuLimit)
	return dec
}
// decodeBytes parses the 6-byte FEC header of a raw datagram (4-byte
// little-endian seqid followed by a 2-byte flag) and copies the remaining
// payload into a buffer taken from the shared xmitBuf pool, so the
// caller's datagram buffer can be reused immediately. Ownership of
// pkt.data transfers to the decoder, which returns it to the pool later
// (see freeRange).
func (dec *fecDecoder) decodeBytes(data []byte) fecPacket {
	var pkt fecPacket
	pkt.seqid = binary.LittleEndian.Uint32(data)
	pkt.flag = binary.LittleEndian.Uint16(data[4:])
	// allocate memory & copy
	buf := xmitBuf.Get().([]byte)[:len(data)-6]
	copy(buf, data[6:])
	pkt.data = buf
	return pkt
}
// decode inserts pkt into the ordered receive queue and tries to
// reconstruct the shard group pkt belongs to. It returns any data shards
// recovered via Reed-Solomon, or nil. The decoder takes ownership of
// pkt.data: duplicates are returned to xmitBuf immediately, consumed or
// evicted packets are recycled through freeRange.
func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
	// insertion: scan backwards since packets usually arrive near-in-order
	n := len(dec.rx) - 1
	insertIdx := 0
	for i := n; i >= 0; i-- {
		if pkt.seqid == dec.rx[i].seqid { // de-duplicate
			xmitBuf.Put(pkt.data)
			return nil
		} else if _itimediff(pkt.seqid, dec.rx[i].seqid) > 0 { // insertion
			insertIdx = i + 1
			break
		}
	}
	// insert into ordered rx queue
	if insertIdx == n+1 {
		dec.rx = append(dec.rx, pkt)
	} else {
		dec.rx = append(dec.rx, fecPacket{})
		copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
		dec.rx[insertIdx] = pkt
	}
	// shard range for current packet: groups are aligned to multiples of
	// shardSize in seqid space (see fecEncoder paws)
	shardBegin := pkt.seqid - pkt.seqid%uint32(dec.shardSize)
	shardEnd := shardBegin + uint32(dec.shardSize) - 1
	// max search range in ordered queue for current shard
	searchBegin := insertIdx - int(pkt.seqid%uint32(dec.shardSize))
	if searchBegin < 0 {
		searchBegin = 0
	}
	searchEnd := searchBegin + dec.shardSize - 1
	if searchEnd >= len(dec.rx) {
		searchEnd = len(dec.rx) - 1
	}
	// re-construct datashards
	if searchEnd-searchBegin+1 >= dec.dataShards {
		var numshard, numDataShard, first, maxlen int
		// zero caches
		shards := dec.decodeCache
		shardsflag := dec.flagCache
		for k := range dec.decodeCache {
			shards[k] = nil
			shardsflag[k] = false
		}
		// shard assembly: collect every queued packet belonging to this group
		for i := searchBegin; i <= searchEnd; i++ {
			seqid := dec.rx[i].seqid
			if _itimediff(seqid, shardEnd) > 0 {
				break
			} else if _itimediff(seqid, shardBegin) >= 0 {
				shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data
				shardsflag[seqid%uint32(dec.shardSize)] = true
				numshard++
				if dec.rx[i].flag == typeData {
					numDataShard++
				}
				if numshard == 1 {
					first = i
				}
				if len(dec.rx[i].data) > maxlen {
					maxlen = len(dec.rx[i].data)
				}
			}
		}
		if numDataShard == dec.dataShards {
			// case 1: no loss on data shards
			dec.rx = dec.freeRange(first, numshard, dec.rx)
		} else if numshard >= dec.dataShards {
			// case 2: loss on data shards, but it's recoverable from parity shards
			// pad every present shard with zeros to the group's max length,
			// as Reed-Solomon requires equal-sized shards
			for k := range shards {
				if shards[k] != nil {
					dlen := len(shards[k])
					shards[k] = shards[k][:maxlen]
					copy(shards[k][dlen:], dec.zeros)
				}
			}
			if err := dec.codec.ReconstructData(shards); err == nil {
				for k := range shards[:dec.dataShards] {
					if !shardsflag[k] {
						recovered = append(recovered, shards[k])
					}
				}
			}
			dec.rx = dec.freeRange(first, numshard, dec.rx)
		}
	}
	// keep rxlimit: evict the oldest packet when the queue overflows
	if len(dec.rx) > dec.rxlimit {
		if dec.rx[0].flag == typeData { // track the unrecoverable data
			atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
		}
		dec.rx = dec.freeRange(0, 1, dec.rx)
	}
	return
}
// freeRange returns the data buffers of n packets starting at index first
// to the xmitBuf pool, closes the gap in q, and nils the vacated tail
// slots so the garbage collector does not keep the moved buffers alive
// through stale references. Returns the shortened queue.
func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket {
	for i := first; i < first+n; i++ { // recycle buffer
		xmitBuf.Put(q[i].data)
	}
	copy(q[first:], q[first+n:])
	for i := 0; i < n; i++ { // dereference data
		q[len(q)-1-i].data = nil
	}
	return q[:len(q)-n]
}
type (
// fecEncoder for encoding outgoing packets
fecEncoder struct {
dataShards int
parityShards int
shardSize int
paws uint32 // Protect Against Wrapped Sequence numbers
next uint32 // next seqid
shardCount int // count the number of datashards collected
maxSize int // track maximum data length in datashard
headerOffset int // FEC header offset
payloadOffset int // FEC payload offset
// caches
shardCache [][]byte
encodeCache [][]byte
// zeros
zeros []byte
// RS encoder
codec reedsolomon.Encoder
}
)
// newFECEncoder creates an encoder emitting parityShards parity packets
// for every dataShards data packets. offset is the number of bytes
// (crypto header, if any) preceding the FEC header in each outgoing
// packet. Returns nil for a non-positive shard configuration or when the
// Reed-Solomon codec cannot be constructed.
func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
	if dataShards <= 0 || parityShards <= 0 {
		return nil
	}
	enc := new(fecEncoder)
	enc.dataShards = dataShards
	enc.parityShards = parityShards
	enc.shardSize = dataShards + parityShards
	// paws: a large multiple of shardSize below 2^32, so a wrapped seqid
	// always lands on a shard-group boundary
	enc.paws = (0xffffffff/uint32(enc.shardSize) - 1) * uint32(enc.shardSize)
	enc.headerOffset = offset
	enc.payloadOffset = enc.headerOffset + fecHeaderSize
	codec, err := reedsolomon.New(dataShards, parityShards)
	if err != nil {
		return nil
	}
	enc.codec = codec
	// caches
	enc.encodeCache = make([][]byte, enc.shardSize)
	enc.shardCache = make([][]byte, enc.shardSize)
	for k := range enc.shardCache {
		enc.shardCache[k] = make([]byte, mtuLimit)
	}
	enc.zeros = make([]byte, mtuLimit)
	return enc
}
// encode stamps b as a data shard (FEC header plus 2-byte payload size),
// caches a copy, and — once dataShards packets have been collected —
// generates and returns the parity shards for the group.
// notice: the contents of 'ps' will be re-written in successive calling
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
	enc.markData(b[enc.headerOffset:])
	binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
	// copy data to fec datashards
	sz := len(b)
	enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
	copy(enc.shardCache[enc.shardCount], b)
	enc.shardCount++
	// track max datashard length
	if sz > enc.maxSize {
		enc.maxSize = sz
	}
	// Generation of Reed-Solomon Erasure Code
	if enc.shardCount == enc.dataShards {
		// fill '0' into the tail of each datashard
		// (Reed-Solomon requires equal-sized shards)
		for i := 0; i < enc.dataShards; i++ {
			shard := enc.shardCache[i]
			slen := len(shard)
			copy(shard[slen:enc.maxSize], enc.zeros)
		}
		// construct equal-sized slice with stripped header
		cache := enc.encodeCache
		for k := range cache {
			cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
		}
		// encoding
		if err := enc.codec.Encode(cache); err == nil {
			ps = enc.shardCache[enc.dataShards:]
			for k := range ps {
				enc.markFEC(ps[k][enc.headerOffset:])
				ps[k] = ps[k][:enc.maxSize]
			}
		}
		// counters resetting
		enc.shardCount = 0
		enc.maxSize = 0
	}
	return
}
// markData stamps the FEC header of a data shard: the current sequence id
// plus the typeData flag. The counter deliberately wraps only in markFEC —
// parity shards are emitted last in every shard group and paws is a
// multiple of shardSize, so the modulo there keeps seqids in range.
func (enc *fecEncoder) markData(data []byte) {
	binary.LittleEndian.PutUint32(data, enc.next)
	binary.LittleEndian.PutUint16(data[4:], typeData)
	enc.next++
}
// markFEC stamps the FEC header of a parity shard and advances the
// sequence counter modulo paws (a multiple of shardSize), so wrapping can
// only occur on a shard-group boundary.
func (enc *fecEncoder) markFEC(data []byte) {
	binary.LittleEndian.PutUint32(data, enc.next)
	binary.LittleEndian.PutUint16(data[4:], typeFEC)
	enc.next = (enc.next + 1) % enc.paws
}

43
lib/kcp/fec_test.go Normal file
View File

@ -0,0 +1,43 @@
package kcp
import (
"math/rand"
"testing"
)
func BenchmarkFECDecode(b *testing.B) {
const dataSize = 10
const paritySize = 3
const payLoad = 1500
decoder := newFECDecoder(1024, dataSize, paritySize)
b.ReportAllocs()
b.SetBytes(payLoad)
for i := 0; i < b.N; i++ {
if rand.Int()%(dataSize+paritySize) == 0 { // random loss
continue
}
var pkt fecPacket
pkt.seqid = uint32(i)
if i%(dataSize+paritySize) >= dataSize {
pkt.flag = typeFEC
} else {
pkt.flag = typeData
}
pkt.data = make([]byte, payLoad)
decoder.decode(pkt)
}
}
func BenchmarkFECEncode(b *testing.B) {
const dataSize = 10
const paritySize = 3
const payLoad = 1500
b.ReportAllocs()
b.SetBytes(payLoad)
encoder := newFECEncoder(dataSize, paritySize, 0)
for i := 0; i < b.N; i++ {
data := make([]byte, payLoad)
encoder.encode(data)
}
}

1012
lib/kcp/kcp.go Normal file

File diff suppressed because it is too large Load Diff

302
lib/kcp/kcp_test.go Normal file
View File

@ -0,0 +1,302 @@
package kcp
import (
"bytes"
"container/list"
"encoding/binary"
"fmt"
"math/rand"
"sync"
"testing"
"time"
)
func iclock() int32 {
return int32(currentMs())
}
type DelayPacket struct {
_ptr []byte
_size int
_ts int32
}
func (p *DelayPacket) Init(size int, src []byte) {
p._ptr = make([]byte, size)
p._size = size
copy(p._ptr, src[:size])
}
func (p *DelayPacket) ptr() []byte { return p._ptr }
func (p *DelayPacket) size() int { return p._size }
func (p *DelayPacket) ts() int32 { return p._ts }
func (p *DelayPacket) setts(ts int32) { p._ts = ts }
type DelayTunnel struct{ *list.List }
type LatencySimulator struct {
current int32
lostrate, rttmin, rttmax, nmax int
p12 DelayTunnel
p21 DelayTunnel
r12 *rand.Rand
r21 *rand.Rand
}
// lostrate: 往返一周丢包率的百分比,默认 10%
// rttminrtt最小值默认 60
// rttmaxrtt最大值默认 125
//func (p *LatencySimulator)Init(int lostrate = 10, int rttmin = 60, int rttmax = 125, int nmax = 1000):
func (p *LatencySimulator) Init(lostrate, rttmin, rttmax, nmax int) {
p.r12 = rand.New(rand.NewSource(9))
p.r21 = rand.New(rand.NewSource(99))
p.p12 = DelayTunnel{list.New()}
p.p21 = DelayTunnel{list.New()}
p.current = iclock()
p.lostrate = lostrate / 2 // 上面数据是往返丢包率单程除以2
p.rttmin = rttmin / 2
p.rttmax = rttmax / 2
p.nmax = nmax
}
// 发送数据
// peer - 端点0/1从0发送从1接收从1发送从0接收
func (p *LatencySimulator) send(peer int, data []byte, size int) int {
rnd := 0
if peer == 0 {
rnd = p.r12.Intn(100)
} else {
rnd = p.r21.Intn(100)
}
//println("!!!!!!!!!!!!!!!!!!!!", rnd, p.lostrate, peer)
if rnd < p.lostrate {
return 0
}
pkt := &DelayPacket{}
pkt.Init(size, data)
p.current = iclock()
delay := p.rttmin
if p.rttmax > p.rttmin {
delay += rand.Int() % (p.rttmax - p.rttmin)
}
pkt.setts(p.current + int32(delay))
if peer == 0 {
p.p12.PushBack(pkt)
} else {
p.p21.PushBack(pkt)
}
return 1
}
// 接收数据
func (p *LatencySimulator) recv(peer int, data []byte, maxsize int) int32 {
var it *list.Element
if peer == 0 {
it = p.p21.Front()
if p.p21.Len() == 0 {
return -1
}
} else {
it = p.p12.Front()
if p.p12.Len() == 0 {
return -1
}
}
pkt := it.Value.(*DelayPacket)
p.current = iclock()
if p.current < pkt.ts() {
return -2
}
if maxsize < pkt.size() {
return -3
}
if peer == 0 {
p.p21.Remove(it)
} else {
p.p12.Remove(it)
}
maxsize = pkt.size()
copy(data, pkt.ptr()[:maxsize])
return int32(maxsize)
}
//=====================================================================
//=====================================================================
// 模拟网络
var vnet *LatencySimulator
// 测试用例
func test(mode int) {
// 创建模拟网络丢包率10%Rtt 60ms~125ms
vnet = &LatencySimulator{}
vnet.Init(10, 60, 125, 1000)
// 创建两个端点的 kcp对象第一个参数 conv是会话编号同一个会话需要相同
// 最后一个是 user参数用来传递标识
output1 := func(buf []byte, size int) {
if vnet.send(0, buf, size) != 1 {
}
}
output2 := func(buf []byte, size int) {
if vnet.send(1, buf, size) != 1 {
}
}
kcp1 := NewKCP(0x11223344, output1)
kcp2 := NewKCP(0x11223344, output2)
current := uint32(iclock())
slap := current + 20
index := 0
next := 0
var sumrtt uint32
count := 0
maxrtt := 0
// 配置窗口大小平均延迟200ms每20ms发送一个包
// 而考虑到丢包重发设置最大收发窗口为128
kcp1.WndSize(128, 128)
kcp2.WndSize(128, 128)
// 判断测试用例的模式
if mode == 0 {
// 默认模式
kcp1.NoDelay(0, 10, 0, 0)
kcp2.NoDelay(0, 10, 0, 0)
} else if mode == 1 {
// 普通模式,关闭流控等
kcp1.NoDelay(0, 10, 0, 1)
kcp2.NoDelay(0, 10, 0, 1)
} else {
// 启动快速模式
// 第二个参数 nodelay-启用以后若干常规加速将启动
// 第三个参数 interval为内部处理时钟默认设置为 10ms
// 第四个参数 resend为快速重传指标设置为2
// 第五个参数 为是否禁用常规流控,这里禁止
kcp1.NoDelay(1, 10, 2, 1)
kcp2.NoDelay(1, 10, 2, 1)
}
buffer := make([]byte, 2000)
var hr int32
ts1 := iclock()
for {
time.Sleep(1 * time.Millisecond)
current = uint32(iclock())
kcp1.Update()
kcp2.Update()
// 每隔 20mskcp1发送数据
for ; current >= slap; slap += 20 {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(index))
index++
binary.Write(buf, binary.LittleEndian, uint32(current))
// 发送上层协议包
kcp1.Send(buf.Bytes())
//println("now", iclock())
}
// 处理虚拟网络检测是否有udp包从p1->p2
for {
hr = vnet.recv(1, buffer, 2000)
if hr < 0 {
break
}
// 如果 p2收到udp则作为下层协议输入到kcp2
kcp2.Input(buffer[:hr], true, false)
}
// 处理虚拟网络检测是否有udp包从p2->p1
for {
hr = vnet.recv(0, buffer, 2000)
if hr < 0 {
break
}
// 如果 p1收到udp则作为下层协议输入到kcp1
kcp1.Input(buffer[:hr], true, false)
//println("@@@@", hr, r)
}
// kcp2接收到任何包都返回回去
for {
hr = int32(kcp2.Recv(buffer[:10]))
// 没有收到包就退出
if hr < 0 {
break
}
// 如果收到包就回射
buf := bytes.NewReader(buffer)
var sn uint32
binary.Read(buf, binary.LittleEndian, &sn)
kcp2.Send(buffer[:hr])
}
// kcp1收到kcp2的回射数据
for {
hr = int32(kcp1.Recv(buffer[:10]))
buf := bytes.NewReader(buffer)
// 没有收到包就退出
if hr < 0 {
break
}
var sn uint32
var ts, rtt uint32
binary.Read(buf, binary.LittleEndian, &sn)
binary.Read(buf, binary.LittleEndian, &ts)
rtt = uint32(current) - ts
if sn != uint32(next) {
// 如果收到的包不连续
//for i:=0;i<8 ;i++ {
//println("---", i, buffer[i])
//}
println("ERROR sn ", count, "<->", next, sn)
return
}
next++
sumrtt += rtt
count++
if rtt > uint32(maxrtt) {
maxrtt = int(rtt)
}
//println("[RECV] mode=", mode, " sn=", sn, " rtt=", rtt)
}
if next > 100 {
break
}
}
ts1 = iclock() - ts1
names := []string{"default", "normal", "fast"}
fmt.Printf("%s mode result (%dms):\n", names[mode], ts1)
fmt.Printf("avgrtt=%d maxrtt=%d\n", int(sumrtt/uint32(count)), maxrtt)
}
// TestNetwork runs the simulated-network echo test in all three tuning modes.
func TestNetwork(t *testing.T) {
	test(0) // default mode: TCP-like — no fast retransmit, normal congestion control
	test(1) // normal mode: congestion control disabled
	test(2) // fast mode: all accelerations enabled, congestion control disabled
}
func BenchmarkFlush(b *testing.B) {
kcp := NewKCP(1, func(buf []byte, size int) {})
kcp.snd_buf = make([]segment, 1024)
for k := range kcp.snd_buf {
kcp.snd_buf[k].xmit = 1
kcp.snd_buf[k].resendts = currentMs() + 10000
}
b.ResetTimer()
b.ReportAllocs()
var mu sync.Mutex
for i := 0; i < b.N; i++ {
mu.Lock()
kcp.flush(false)
mu.Unlock()
}
}

963
lib/kcp/sess.go Normal file
View File

@ -0,0 +1,963 @@
package kcp
import (
"crypto/rand"
"encoding/binary"
"hash/crc32"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// errTimeout is the error returned when a read/write deadline expires.
// It satisfies the net.Error shape so callers can distinguish deadline
// expiry (Timeout/Temporary both true) from fatal errors.
type errTimeout struct {
	error
}

func (errTimeout) Timeout() bool   { return true }
func (errTimeout) Temporary() bool { return true }
func (errTimeout) Error() string   { return "i/o timeout" }
const (
// 16-bytes nonce for each packet
nonceSize = 16
// 4-bytes packet checksum
crcSize = 4
// overall crypto header size
cryptHeaderSize = nonceSize + crcSize
// maximum packet size
mtuLimit = 1500
// FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory
rxFECMulti = 3
// accept backlog
acceptBacklog = 128
)
const (
errBrokenPipe = "broken pipe"
errInvalidOperation = "invalid operation"
)
var (
// a system-wide packet buffer shared among sending, receiving and FEC
// to mitigate high-frequency memory allocation for packets
xmitBuf sync.Pool
)
func init() {
xmitBuf.New = func() interface{} {
return make([]byte, mtuLimit)
}
}
type (
// UDPSession defines a KCP session implemented by UDP
UDPSession struct {
updaterIdx int // record slice index in updater
conn net.PacketConn // the underlying packet connection
kcp *KCP // KCP ARQ protocol
l *Listener // pointing to the Listener object if it's been accepted by a Listener
block BlockCrypt // block encryption object
// kcp receiving is based on packets
// recvbuf turns packets into stream
recvbuf []byte
bufptr []byte
// header extended output buffer, if has header
ext []byte
// FEC codec
fecDecoder *fecDecoder
fecEncoder *fecEncoder
// settings
remote net.Addr // remote peer address
rd time.Time // read deadline
wd time.Time // write deadline
headerSize int // the header size additional to a KCP frame
ackNoDelay bool // send ack immediately for each incoming packet(testing purpose)
writeDelay bool // delay kcp.flush() for Write() for bulk transfer
dup int // duplicate udp packets(testing purpose)
// notifications
die chan struct{} // notify current session has Closed
chReadEvent chan struct{} // notify Read() can be called without blocking
chWriteEvent chan struct{} // notify Write() can be called without blocking
chReadError chan error // notify PacketConn.Read() have an error
chWriteError chan error // notify PacketConn.Write() have an error
// nonce generator
nonce Entropy
isClosed bool // flag the session has Closed
mu sync.Mutex
}
setReadBuffer interface {
SetReadBuffer(bytes int) error
}
setWriteBuffer interface {
SetWriteBuffer(bytes int) error
}
)
// newUDPSession create a new udp session for client or server.
// l != nil means the session was accepted by a Listener (which owns the
// socket and demultiplexes incoming packets); l == nil means a client
// connection that owns conn and runs its own readLoop. Each outgoing
// packet is laid out as [crypto nonce+crc | FEC header | KCP frame],
// and headerSize/ext are sized to match.
func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn net.PacketConn, remote net.Addr, block BlockCrypt) *UDPSession {
	sess := new(UDPSession)
	sess.die = make(chan struct{})
	sess.nonce = new(nonceAES128)
	sess.nonce.Init()
	sess.chReadEvent = make(chan struct{}, 1)
	sess.chWriteEvent = make(chan struct{}, 1)
	sess.chReadError = make(chan error, 1)
	sess.chWriteError = make(chan error, 1)
	sess.remote = remote
	sess.conn = conn
	sess.l = l
	sess.block = block
	sess.recvbuf = make([]byte, mtuLimit)
	// FEC codec initialization; the encoder's FEC header is offset past the
	// crypto header when encryption is enabled
	sess.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)
	if sess.block != nil {
		sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize)
	} else {
		sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0)
	}
	// calculate additional header size introduced by FEC and encryption
	if sess.block != nil {
		sess.headerSize += cryptHeaderSize
	}
	if sess.fecEncoder != nil {
		sess.headerSize += fecHeaderSizePlus2
	}
	// we only need to allocate extended packet buffer if we have the additional header
	if sess.headerSize > 0 {
		sess.ext = make([]byte, mtuLimit)
	}
	// KCP core outputs complete frames through sess.output; runt buffers
	// below one KCP header are dropped
	sess.kcp = NewKCP(conv, func(buf []byte, size int) {
		if size >= IKCP_OVERHEAD {
			sess.output(buf[:size])
		}
	})
	// shrink the KCP MTU so frame + headers still fit in one datagram
	sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize)
	// register current session to the global updater,
	// which call sess.update() periodically.
	updater.addSession(sess)
	if sess.l == nil { // it's a client connection
		go sess.readLoop()
		atomic.AddUint64(&DefaultSnmp.ActiveOpens, 1)
	} else {
		atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1)
	}
	// update the established-connections high-water mark
	currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1)
	maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn)
	if currestab > maxconn {
		atomic.CompareAndSwapUint64(&DefaultSnmp.MaxConn, maxconn, currestab)
	}
	return sess
}
// Read implements net.Conn. Priority order on each pass: (1) drain bytes
// left over in bufptr from an earlier KCP message that did not fit into
// the caller's buffer, (2) pull the next complete message from the KCP
// core — directly into 'b' when it fits, otherwise staged through
// recvbuf, (3) block until a read event, deadline expiry, read error, or
// session close, then retry.
func (s *UDPSession) Read(b []byte) (n int, err error) {
	for {
		s.mu.Lock()
		if len(s.bufptr) > 0 { // copy from buffer into b
			n = copy(b, s.bufptr)
			s.bufptr = s.bufptr[n:]
			s.mu.Unlock()
			atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n))
			return n, nil
		}
		if s.isClosed {
			s.mu.Unlock()
			return 0, errors.New(errBrokenPipe)
		}
		if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp
			if len(b) >= size { // receive data into 'b' directly
				s.kcp.Recv(b)
				s.mu.Unlock()
				atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(size))
				return size, nil
			}
			// if necessary resize the stream buffer to guarantee a sufficient buffer space
			if cap(s.recvbuf) < size {
				s.recvbuf = make([]byte, size)
			}
			// resize the length of recvbuf to correspond to data size
			s.recvbuf = s.recvbuf[:size]
			s.kcp.Recv(s.recvbuf)
			n = copy(b, s.recvbuf)   // copy to 'b'
			s.bufptr = s.recvbuf[n:] // pointer update
			s.mu.Unlock()
			atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n))
			return n, nil
		}
		// deadline for current reading operation
		var timeout *time.Timer
		var c <-chan time.Time
		if !s.rd.IsZero() {
			if time.Now().After(s.rd) {
				s.mu.Unlock()
				return 0, errTimeout{}
			}
			delay := s.rd.Sub(time.Now())
			timeout = time.NewTimer(delay)
			c = timeout.C
		}
		s.mu.Unlock()
		// wait for read event or timeout; c is nil (never fires) when no
		// deadline is set
		select {
		case <-s.chReadEvent:
		case <-c:
		case <-s.die:
		case err = <-s.chReadError:
			if timeout != nil {
				timeout.Stop()
			}
			return n, err
		}
		if timeout != nil {
			timeout.Stop()
		}
	}
}
// Write implements net.Conn. The payload is segmented into mss-sized
// chunks and handed to the KCP core, but only while the pending send
// queue is below the send window — this bounds memory usage. Otherwise
// it blocks until a write event, deadline expiry, write error, or
// session close, then retries. flush is forced immediately when the
// queue fills or writeDelay is off.
func (s *UDPSession) Write(b []byte) (n int, err error) {
	for {
		s.mu.Lock()
		if s.isClosed {
			s.mu.Unlock()
			return 0, errors.New(errBrokenPipe)
		}
		// controls how much data will be sent to kcp core
		// to prevent the memory from exhuasting
		if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) {
			n = len(b)
			// segment the payload into mss-sized KCP messages
			for {
				if len(b) <= int(s.kcp.mss) {
					s.kcp.Send(b)
					break
				} else {
					s.kcp.Send(b[:s.kcp.mss])
					b = b[s.kcp.mss:]
				}
			}
			// flush immediately if the queue is full
			if s.kcp.WaitSnd() >= int(s.kcp.snd_wnd) || !s.writeDelay {
				s.kcp.flush(false)
			}
			s.mu.Unlock()
			atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n))
			return n, nil
		}
		// deadline for current writing operation
		var timeout *time.Timer
		var c <-chan time.Time
		if !s.wd.IsZero() {
			if time.Now().After(s.wd) {
				s.mu.Unlock()
				return 0, errTimeout{}
			}
			delay := s.wd.Sub(time.Now())
			timeout = time.NewTimer(delay)
			c = timeout.C
		}
		s.mu.Unlock()
		// wait for write event or timeout; c is nil (never fires) when no
		// deadline is set
		select {
		case <-s.chWriteEvent:
		case <-c:
		case <-s.die:
		case err = <-s.chWriteError:
			if timeout != nil {
				timeout.Stop()
			}
			return n, err
		}
		if timeout != nil {
			timeout.Stop()
		}
	}
}
// Close closes the connection.
// Ordering matters: the session is detached from the global updater and
// (for accepted sessions) from its listener BEFORE s.mu is taken, so
// those components' own locks are never acquired under s.mu.
func (s *UDPSession) Close() error {
	// remove current session from updater & listener(if necessary)
	updater.removeSession(s)
	if s.l != nil { // notify listener
		s.l.closeSession(s.remote)
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.isClosed {
		return errors.New(errBrokenPipe)
	}
	close(s.die)
	s.isClosed = true
	// adding ^uint64(0) atomically decrements CurrEstab by one
	atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0))
	// dialed sessions own their packet conn; accepted sessions share the
	// listener's conn and must not close it
	if s.l == nil { // client socket close
		return s.conn.Close()
	}
	return nil
}
// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
func (s *UDPSession) LocalAddr() net.Addr { return s.conn.LocalAddr() }

// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
func (s *UDPSession) RemoteAddr() net.Addr { return s.remote }

// SetDeadline sets both the read and write deadlines of the session.
// A zero time value disables the deadline. Pending readers/writers are
// woken so they can re-evaluate against the new deadline.
func (s *UDPSession) SetDeadline(t time.Time) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.rd = t
	s.wd = t
	s.notifyReadEvent()
	s.notifyWriteEvent()
	return nil
}

// SetReadDeadline implements the Conn SetReadDeadline method.
func (s *UDPSession) SetReadDeadline(t time.Time) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.rd = t
	s.notifyReadEvent()
	return nil
}

// SetWriteDeadline implements the Conn SetWriteDeadline method.
func (s *UDPSession) SetWriteDeadline(t time.Time) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.wd = t
	s.notifyWriteEvent()
	return nil
}
// SetWriteDelay delays write for bulk transfer until the next update interval
func (s *UDPSession) SetWriteDelay(delay bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.writeDelay = delay
}

// SetWindowSize sets the maximum send and receive window sizes (in packets).
func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.kcp.WndSize(sndwnd, rcvwnd)
}

// SetMtu sets the maximum transmission unit(not including UDP header).
// Returns false when mtu exceeds the hard mtuLimit. The per-session header
// size (crypto/FEC framing) is subtracted before configuring the KCP core.
func (s *UDPSession) SetMtu(mtu int) bool {
	if mtu > mtuLimit {
		return false
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.kcp.SetMtu(mtu - s.headerSize)
	return true
}

// SetStreamMode toggles the stream mode on/off
func (s *UDPSession) SetStreamMode(enable bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if enable {
		s.kcp.stream = 1
	} else {
		s.kcp.stream = 0
	}
}

// SetACKNoDelay changes ack flush option, set true to flush ack immediately.
func (s *UDPSession) SetACKNoDelay(nodelay bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.ackNoDelay = nodelay
}

// SetDUP duplicates udp packets for kcp output, for testing purpose only
func (s *UDPSession) SetDUP(dup int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.dup = dup
}

// SetNoDelay calls nodelay() of kcp
// https://github.com/skywind3000/kcp/blob/master/README.en.md#protocol-configuration
func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.kcp.NoDelay(nodelay, interval, resend, nc)
}
// SetDSCP sets the 6bit DSCP field of IP header, no effect if it's accepted from Listener.
// The value is shifted left by 2 to occupy the DSCP bits of the TOS byte;
// IPv4 is tried first, with the IPv6 traffic class as the fallback.
func (s *UDPSession) SetDSCP(dscp int) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.l == nil {
		if nc, ok := s.conn.(net.Conn); ok {
			if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil {
				return ipv6.NewConn(nc).SetTrafficClass(dscp)
			}
			return nil
		}
	}
	return errors.New(errInvalidOperation)
}

// SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener
func (s *UDPSession) SetReadBuffer(bytes int) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.l == nil {
		if nc, ok := s.conn.(setReadBuffer); ok {
			return nc.SetReadBuffer(bytes)
		}
	}
	return errors.New(errInvalidOperation)
}

// SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener
func (s *UDPSession) SetWriteBuffer(bytes int) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.l == nil {
		if nc, ok := s.conn.(setWriteBuffer); ok {
			return nc.SetWriteBuffer(bytes)
		}
	}
	return errors.New(errInvalidOperation)
}
// output is the post-processing stage for a packet emitted by the kcp core.
// steps:
// 0. Header extending
// 1. FEC packet generation
// 2. CRC32 integrity
// 3. Encryption
// 4. WriteTo kernel
func (s *UDPSession) output(buf []byte) {
	var ecc [][]byte

	// 0. extend buf's header space(if necessary)
	ext := buf
	if s.headerSize > 0 {
		ext = s.ext[:s.headerSize+len(buf)]
		copy(ext[s.headerSize:], buf)
	}

	// 1. FEC encoding
	if s.fecEncoder != nil {
		ecc = s.fecEncoder.encode(ext)
	}

	// 2&3. crc32 & encryption: nonce fills the first bytes, the CRC of the
	// ciphertext region follows, then the whole frame is encrypted in place
	if s.block != nil {
		s.nonce.Fill(ext[:nonceSize])
		checksum := crc32.ChecksumIEEE(ext[cryptHeaderSize:])
		binary.LittleEndian.PutUint32(ext[nonceSize:], checksum)
		s.block.Encrypt(ext, ext)

		for k := range ecc {
			s.nonce.Fill(ecc[k][:nonceSize])
			checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:])
			binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum)
			s.block.Encrypt(ecc[k], ecc[k])
		}
	}

	// 4. WriteTo kernel; each packet is sent s.dup extra times (testing aid)
	nbytes := 0
	npkts := 0
	for i := 0; i < s.dup+1; i++ {
		if n, err := s.conn.WriteTo(ext, s.remote); err == nil {
			nbytes += n
			npkts++
		} else {
			s.notifyWriteError(err)
		}
	}

	for k := range ecc {
		if n, err := s.conn.WriteTo(ecc[k], s.remote); err == nil {
			nbytes += n
			npkts++
		} else {
			s.notifyWriteError(err)
		}
	}
	atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
	atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}
// update flushes pending kcp data/acks and returns the interval until the
// next flush is due; it is invoked periodically by the global updater heap.
func (s *UDPSession) update() (interval time.Duration) {
	s.mu.Lock()
	waitsnd := s.kcp.WaitSnd()
	interval = time.Duration(s.kcp.flush(false)) * time.Millisecond
	if s.kcp.WaitSnd() < waitsnd {
		// the flush shortened the send queue, so wake a blocked writer
		s.notifyWriteEvent()
	}
	s.mu.Unlock()
	return
}
// GetConv gets conversation id of a session
func (s *UDPSession) GetConv() uint32 { return s.kcp.conv }

// notifyReadEvent signals a pending reader; the non-blocking send simply
// drops the signal when one is already queued (signals coalesce).
func (s *UDPSession) notifyReadEvent() {
	select {
	case s.chReadEvent <- struct{}{}:
	default:
	}
}

// notifyWriteEvent signals a pending writer (non-blocking, coalescing).
func (s *UDPSession) notifyWriteEvent() {
	select {
	case s.chWriteEvent <- struct{}{}:
	default:
	}
}

// notifyWriteError delivers a write error to a pending writer; the error is
// dropped when the channel already holds an undelivered one.
func (s *UDPSession) notifyWriteError(err error) {
	select {
	case s.chWriteError <- err:
	default:
	}
}
// kcpInput feeds one raw (already decrypted and CRC-verified) datagram into
// the KCP state machine — routing it through the FEC decoder first when FEC
// is enabled — and updates the global SNMP counters accordingly.
func (s *UDPSession) kcpInput(data []byte) {
	var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64

	if s.fecDecoder != nil {
		if len(data) > fecHeaderSize { // must be larger than fec header size
			f := s.fecDecoder.decodeBytes(data)
			if f.flag == typeData || f.flag == typeFEC { // header check
				if f.flag == typeFEC {
					fecParityShards++
				}
				recovers := s.fecDecoder.decode(f)

				s.mu.Lock()
				waitsnd := s.kcp.WaitSnd()
				if f.flag == typeData {
					if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 {
						kcpInErrors++
					}
				}

				// feed any FEC-recovered packets; each carries a 2-byte
				// little-endian size prefix that must be sane
				for _, r := range recovers {
					if len(r) >= 2 { // must be larger than 2bytes
						sz := binary.LittleEndian.Uint16(r)
						if int(sz) <= len(r) && sz >= 2 {
							if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 {
								fecRecovered++
							} else {
								kcpInErrors++
							}
						} else {
							fecErrs++
						}
					} else {
						fecErrs++
					}
				}

				// to notify the readers to receive the data
				if n := s.kcp.PeekSize(); n > 0 {
					s.notifyReadEvent()
				}
				// to notify the writers when queue is shorter(e.g. ACKed)
				if s.kcp.WaitSnd() < waitsnd {
					s.notifyWriteEvent()
				}
				s.mu.Unlock()
			} else {
				atomic.AddUint64(&DefaultSnmp.InErrs, 1)
			}
		} else {
			atomic.AddUint64(&DefaultSnmp.InErrs, 1)
		}
	} else {
		// no FEC: the datagram is a raw KCP segment stream
		s.mu.Lock()
		waitsnd := s.kcp.WaitSnd()
		if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 {
			kcpInErrors++
		}
		if n := s.kcp.PeekSize(); n > 0 {
			s.notifyReadEvent()
		}
		if s.kcp.WaitSnd() < waitsnd {
			s.notifyWriteEvent()
		}
		s.mu.Unlock()
	}

	atomic.AddUint64(&DefaultSnmp.InPkts, 1)
	atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data)))
	if fecParityShards > 0 {
		atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards)
	}
	if kcpInErrors > 0 {
		atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors)
	}
	if fecErrs > 0 {
		atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs)
	}
	if fecRecovered > 0 {
		atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered)
	}
}
// readLoop is the receive loop for a dialed (client) session: it reads raw
// datagrams from the packet conn, rejects packets from unexpected sources,
// optionally decrypts and CRC-checks them, and feeds valid payloads to
// kcpInput. It exits on the first read error.
func (s *UDPSession) readLoop() {
	buf := make([]byte, mtuLimit)
	var src string
	for {
		if n, addr, err := s.conn.ReadFrom(buf); err == nil {
			// make sure the packet is from the same source
			if src == "" { // set source address
				src = addr.String()
			} else if addr.String() != src {
				atomic.AddUint64(&DefaultSnmp.InErrs, 1)
				continue
			}

			if n >= s.headerSize+IKCP_OVERHEAD {
				data := buf[:n]
				dataValid := false
				if s.block != nil {
					// decrypt in place, then verify the CRC32 that sits
					// right after the nonce
					s.block.Decrypt(data, data)
					data = data[nonceSize:]
					checksum := crc32.ChecksumIEEE(data[crcSize:])
					if checksum == binary.LittleEndian.Uint32(data) {
						data = data[crcSize:]
						dataValid = true
					} else {
						atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
					}
				} else if s.block == nil {
					dataValid = true
				}

				if dataValid {
					s.kcpInput(data)
				}
			} else {
				atomic.AddUint64(&DefaultSnmp.InErrs, 1)
			}
		} else {
			// NOTE(review): this send is blocking; it assumes a Read caller
			// will eventually receive from chReadError — confirm a receiver
			// is always present, otherwise this goroutine could linger
			// after the session stops reading.
			s.chReadError <- err
			return
		}
	}
}
type (
	// Listener defines a server which will be waiting to accept incoming connections
	Listener struct {
		block        BlockCrypt     // block encryption
		dataShards   int            // FEC data shard
		parityShards int            // FEC parity shard
		fecDecoder   *fecDecoder    // FEC mock initialization
		conn         net.PacketConn // the underlying packet connection

		sessions        map[string]*UDPSession // all sessions accepted by this Listener, keyed by remote addr string
		sessionLock     sync.Mutex             // guards sessions
		chAccepts       chan *UDPSession       // Listen() backlog
		chSessionClosed chan net.Addr          // session close queue
		headerSize      int                    // the additional header to a KCP frame
		die             chan struct{}          // notify the listener has closed
		rd              atomic.Value           // read deadline for Accept()
		wd              atomic.Value           // write deadline, stored via SetWriteDeadline
	}
)
// monitor incoming data for all connections of server.
// It demultiplexes raw datagrams to existing sessions by remote address,
// creating and enqueueing new sessions on first contact.
func (l *Listener) monitor() {
	// a cache for session object last used
	var lastAddr string
	var lastSession *UDPSession

	buf := make([]byte, mtuLimit)
	for {
		if n, from, err := l.conn.ReadFrom(buf); err == nil {
			if n >= l.headerSize+IKCP_OVERHEAD {
				data := buf[:n]
				dataValid := false
				if l.block != nil {
					// decrypt in place, then verify the CRC32 behind the nonce
					l.block.Decrypt(data, data)
					data = data[nonceSize:]
					checksum := crc32.ChecksumIEEE(data[crcSize:])
					if checksum == binary.LittleEndian.Uint32(data) {
						data = data[crcSize:]
						dataValid = true
					} else {
						atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
					}
				} else if l.block == nil {
					dataValid = true
				}

				if dataValid {
					addr := from.String()
					var s *UDPSession
					var ok bool

					// the packets received from an address always come in batch,
					// cache the session for next packet, without querying map.
					if addr == lastAddr {
						s, ok = lastSession, true
					} else {
						l.sessionLock.Lock()
						if s, ok = l.sessions[addr]; ok {
							lastSession = s
							lastAddr = addr
						}
						l.sessionLock.Unlock()
					}

					if !ok { // new session
						if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue
							var conv uint32
							convValid := false
							if l.fecDecoder != nil {
								// with FEC the conv id sits behind the FEC
								// header, and only data shards carry it
								isfec := binary.LittleEndian.Uint16(data[4:])
								if isfec == typeData {
									conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:])
									convValid = true
								}
							} else {
								conv = binary.LittleEndian.Uint32(data)
								convValid = true
							}

							if convValid { // creates a new session only if the 'conv' field in kcp is accessible
								s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, from, l.block)
								s.kcpInput(data)
								l.sessionLock.Lock()
								l.sessions[addr] = s
								l.sessionLock.Unlock()
								l.chAccepts <- s
							}
						}
					} else {
						s.kcpInput(data)
					}
				}
			} else {
				atomic.AddUint64(&DefaultSnmp.InErrs, 1)
			}
		} else {
			return
		}
	}
}
// SetReadBuffer sets the socket read buffer for the Listener
func (l *Listener) SetReadBuffer(bytes int) error {
	if nc, ok := l.conn.(setReadBuffer); ok {
		return nc.SetReadBuffer(bytes)
	}
	return errors.New(errInvalidOperation)
}

// SetWriteBuffer sets the socket write buffer for the Listener
func (l *Listener) SetWriteBuffer(bytes int) error {
	if nc, ok := l.conn.(setWriteBuffer); ok {
		return nc.SetWriteBuffer(bytes)
	}
	return errors.New(errInvalidOperation)
}

// SetDSCP sets the 6bit DSCP field of IP header.
// IPv4 TOS is attempted first, falling back to the IPv6 traffic class.
func (l *Listener) SetDSCP(dscp int) error {
	if nc, ok := l.conn.(net.Conn); ok {
		if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil {
			return ipv6.NewConn(nc).SetTrafficClass(dscp)
		}
		return nil
	}
	return errors.New(errInvalidOperation)
}
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
func (l *Listener) Accept() (net.Conn, error) {
	return l.AcceptKCP()
}

// AcceptKCP accepts a KCP connection. It honors the read deadline set via
// SetReadDeadline/SetDeadline and fails fast when the listener dies.
func (l *Listener) AcceptKCP() (*UDPSession, error) {
	var timeout <-chan time.Time
	if tdeadline, ok := l.rd.Load().(time.Time); ok && !tdeadline.IsZero() {
		// time.Until is the idiomatic form of tdeadline.Sub(time.Now())
		timeout = time.After(time.Until(tdeadline))
	}

	select {
	case <-timeout:
		return nil, &errTimeout{}
	case c := <-l.chAccepts:
		return c, nil
	case <-l.die:
		return nil, errors.New(errBrokenPipe)
	}
}
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
func (l *Listener) SetDeadline(t time.Time) error {
	l.SetReadDeadline(t)
	l.SetWriteDeadline(t)
	return nil
}

// SetReadDeadline implements the Conn SetReadDeadline method.
// The value is stored atomically and consulted by AcceptKCP.
func (l *Listener) SetReadDeadline(t time.Time) error {
	l.rd.Store(t)
	return nil
}

// SetWriteDeadline implements the Conn SetWriteDeadline method.
func (l *Listener) SetWriteDeadline(t time.Time) error {
	l.wd.Store(t)
	return nil
}

// Close stops listening on the UDP address. Already Accepted connections are not closed.
func (l *Listener) Close() error {
	close(l.die)
	return l.conn.Close()
}
// closeSession notifies the listener that the session identified by its
// remote address has closed, removing it from the session table.
// It reports whether a session entry was actually removed.
func (l *Listener) closeSession(remote net.Addr) (ret bool) {
	l.sessionLock.Lock()
	defer l.sessionLock.Unlock()

	key := remote.String()
	if _, found := l.sessions[key]; !found {
		return false
	}
	delete(l.sessions, key)
	return true
}
// Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it.
func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() }

// Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp",
// with encryption and FEC disabled.
func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) }
// ListenWithOptions listens for incoming KCP packets addressed to the local
// address laddr on the network "udp" with packet encryption.
// dataShards and parityShards define the Reed-Solomon erasure coding parameters.
func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) {
	addr, err := net.ResolveUDPAddr("udp", laddr)
	if err != nil {
		return nil, errors.Wrap(err, "net.ResolveUDPAddr")
	}

	pconn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return nil, errors.Wrap(err, "net.ListenUDP")
	}

	return ServeConn(block, dataShards, parityShards, pconn)
}
// ServeConn serves KCP protocol for a single packet connection.
// The listener's headerSize accounts for the optional crypto (nonce+CRC)
// and FEC headers so frame boundaries can be validated in monitor().
func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) {
	l := new(Listener)
	l.conn = conn
	l.sessions = make(map[string]*UDPSession)
	l.chAccepts = make(chan *UDPSession, acceptBacklog)
	l.chSessionClosed = make(chan net.Addr)
	l.die = make(chan struct{})
	l.dataShards = dataShards
	l.parityShards = parityShards
	l.block = block
	l.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)

	// calculate header size
	if l.block != nil {
		l.headerSize += cryptHeaderSize
	}
	if l.fecDecoder != nil {
		l.headerSize += fecHeaderSizePlus2
	}

	go l.monitor()
	return l, nil
}
// Dial connects to the remote address "raddr" on the network "udp",
// with encryption and FEC disabled.
func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) }

// DialWithOptions connects to the remote address "raddr" on the network "udp"
// with packet encryption; dataShards and parityShards define the Reed-Solomon
// erasure coding parameters.
func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) {
	// network type detection
	udpaddr, err := net.ResolveUDPAddr("udp", raddr)
	if err != nil {
		return nil, errors.Wrap(err, "net.ResolveUDPAddr")
	}
	network := "udp4"
	if udpaddr.IP.To4() == nil {
		network = "udp"
	}

	conn, err := net.ListenUDP(network, nil)
	if err != nil {
		// BUGFIX: the failing call here is net.ListenUDP, not net.DialUDP —
		// the wrap context previously named the wrong function.
		return nil, errors.Wrap(err, "net.ListenUDP")
	}

	return NewConn(raddr, block, dataShards, parityShards, conn)
}
// NewConn establishes a session and talks KCP protocol over a packet connection.
func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
	udpaddr, err := net.ResolveUDPAddr("udp", raddr)
	if err != nil {
		return nil, errors.Wrap(err, "net.ResolveUDPAddr")
	}

	var convid uint32
	// pick a random conversation id for the new session; the read error is
	// ignored, matching the original behavior (4 bytes from crypto/rand).
	binary.Read(rand.Reader, binary.LittleEndian, &convid)
	return newUDPSession(convid, dataShards, parityShards, nil, conn, udpaddr, block), nil
}

// refTime is the monotonic reference time point captured at program startup.
var refTime time.Time = time.Now()

// currentMs returns current elapsed monotonic milliseconds since program startup.
func currentMs() uint32 { return uint32(time.Since(refTime) / time.Millisecond) }

475
lib/kcp/sess_test.go Normal file
View File

@ -0,0 +1,475 @@
package kcp
import (
"crypto/sha1"
"fmt"
"io"
"log"
"net"
"net/http"
_ "net/http/pprof"
"sync"
"testing"
"time"
"golang.org/x/crypto/pbkdf2"
)
// test endpoints, one per background server variant
const portEcho = "127.0.0.1:9999"
const portSink = "127.0.0.1:19999"
const portTinyBufferEcho = "127.0.0.1:29999"
const portListerner = "127.0.0.1:9998"

// shared test key material; pass is derived with PBKDF2 (4096 rounds, SHA-1)
var key = []byte("testkey")
var pass = pbkdf2.Key(key, []byte(portSink), 4096, 32, sha1.New)

// init starts pprof plus the three background servers shared by all tests.
func init() {
	go func() {
		log.Println(http.ListenAndServe("localhost:6060", nil))
	}()

	go echoServer()
	go sinkServer()
	go tinyBufferEchoServer()
	println("beginning tests, encryption:salsa20, fec:10/3")
}
// dialEcho opens an encrypted (salsa20) FEC 10/3 session to the echo server,
// exercising most of the session option setters along the way (coverage).
func dialEcho() (*UDPSession, error) {
	//block, _ := NewNoneBlockCrypt(pass)
	//block, _ := NewSimpleXORBlockCrypt(pass)
	//block, _ := NewTEABlockCrypt(pass[:16])
	//block, _ := NewAESBlockCrypt(pass)
	block, _ := NewSalsa20BlockCrypt(pass)
	sess, err := DialWithOptions(portEcho, block, 10, 3)
	if err != nil {
		panic(err)
	}

	sess.SetStreamMode(true)
	sess.SetStreamMode(false)
	sess.SetStreamMode(true)
	sess.SetWindowSize(1024, 1024)
	sess.SetReadBuffer(16 * 1024 * 1024)
	sess.SetWriteBuffer(16 * 1024 * 1024)
	sess.SetStreamMode(true)
	sess.SetNoDelay(1, 10, 2, 1)
	sess.SetMtu(1400)
	sess.SetMtu(1600)
	sess.SetMtu(1400)
	sess.SetACKNoDelay(true)
	sess.SetACKNoDelay(false)
	sess.SetDeadline(time.Now().Add(time.Minute))
	return sess, err
}

// dialSink opens a plaintext, non-FEC session to the sink (discard) server.
func dialSink() (*UDPSession, error) {
	sess, err := DialWithOptions(portSink, nil, 0, 0)
	if err != nil {
		panic(err)
	}

	sess.SetStreamMode(true)
	sess.SetWindowSize(1024, 1024)
	sess.SetReadBuffer(16 * 1024 * 1024)
	sess.SetWriteBuffer(16 * 1024 * 1024)
	sess.SetStreamMode(true)
	sess.SetNoDelay(1, 10, 2, 1)
	sess.SetMtu(1400)
	sess.SetACKNoDelay(false)
	sess.SetDeadline(time.Now().Add(time.Minute))
	return sess, err
}

// dialTinyBufferEcho opens an encrypted session to the tiny-buffer echo server.
func dialTinyBufferEcho() (*UDPSession, error) {
	//block, _ := NewNoneBlockCrypt(pass)
	//block, _ := NewSimpleXORBlockCrypt(pass)
	//block, _ := NewTEABlockCrypt(pass[:16])
	//block, _ := NewAESBlockCrypt(pass)
	block, _ := NewSalsa20BlockCrypt(pass)
	sess, err := DialWithOptions(portTinyBufferEcho, block, 10, 3)
	if err != nil {
		panic(err)
	}
	return sess, err
}
//////////////////////////

// listenEcho starts the encrypted FEC 10/3 listener for the echo server.
func listenEcho() (net.Listener, error) {
	//block, _ := NewNoneBlockCrypt(pass)
	//block, _ := NewSimpleXORBlockCrypt(pass)
	//block, _ := NewTEABlockCrypt(pass[:16])
	//block, _ := NewAESBlockCrypt(pass)
	block, _ := NewSalsa20BlockCrypt(pass)
	return ListenWithOptions(portEcho, block, 10, 3)
}

// listenTinyBufferEcho starts the encrypted listener for the tiny-buffer echo server.
func listenTinyBufferEcho() (net.Listener, error) {
	//block, _ := NewNoneBlockCrypt(pass)
	//block, _ := NewSimpleXORBlockCrypt(pass)
	//block, _ := NewTEABlockCrypt(pass[:16])
	//block, _ := NewAESBlockCrypt(pass)
	block, _ := NewSalsa20BlockCrypt(pass)
	return ListenWithOptions(portTinyBufferEcho, block, 10, 3)
}

// listenSink starts the plaintext, non-FEC listener for the sink server.
func listenSink() (net.Listener, error) {
	return ListenWithOptions(portSink, nil, 0, 0)
}
// echoServer accepts sessions and echoes everything back (see handleEcho).
func echoServer() {
	l, err := listenEcho()
	if err != nil {
		panic(err)
	}
	go func() {
		kcplistener := l.(*Listener)
		kcplistener.SetReadBuffer(4 * 1024 * 1024)
		kcplistener.SetWriteBuffer(4 * 1024 * 1024)
		kcplistener.SetDSCP(46)
		for {
			s, err := l.Accept()
			if err != nil {
				return
			}

			// coverage test
			s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024)
			s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024)
			go handleEcho(s.(*UDPSession))
		}
	}()
}

// sinkServer accepts sessions and discards all received data (see handleSink).
func sinkServer() {
	l, err := listenSink()
	if err != nil {
		panic(err)
	}
	go func() {
		kcplistener := l.(*Listener)
		kcplistener.SetReadBuffer(4 * 1024 * 1024)
		kcplistener.SetWriteBuffer(4 * 1024 * 1024)
		kcplistener.SetDSCP(46)
		for {
			s, err := l.Accept()
			if err != nil {
				return
			}

			go handleSink(s.(*UDPSession))
		}
	}()
}

// tinyBufferEchoServer echoes using default settings (see handleTinyBufferEcho).
func tinyBufferEchoServer() {
	l, err := listenTinyBufferEcho()
	if err != nil {
		panic(err)
	}
	go func() {
		for {
			s, err := l.Accept()
			if err != nil {
				return
			}
			go handleTinyBufferEcho(s.(*UDPSession))
		}
	}()
}
///////////////////////////

// handleEcho configures the accepted session and echoes every read back.
func handleEcho(conn *UDPSession) {
	conn.SetStreamMode(true)
	conn.SetWindowSize(4096, 4096)
	conn.SetNoDelay(1, 10, 2, 1)
	conn.SetDSCP(46)
	conn.SetMtu(1400)
	conn.SetACKNoDelay(false)
	conn.SetReadDeadline(time.Now().Add(time.Hour))
	conn.SetWriteDeadline(time.Now().Add(time.Hour))
	buf := make([]byte, 65536)
	for {
		n, err := conn.Read(buf)
		if err != nil {
			panic(err)
		}
		conn.Write(buf[:n])
	}
}

// handleSink reads and discards everything received on the session.
func handleSink(conn *UDPSession) {
	conn.SetStreamMode(true)
	conn.SetWindowSize(4096, 4096)
	conn.SetNoDelay(1, 10, 2, 1)
	conn.SetDSCP(46)
	conn.SetMtu(1400)
	conn.SetACKNoDelay(false)
	conn.SetReadDeadline(time.Now().Add(time.Hour))
	conn.SetWriteDeadline(time.Now().Add(time.Hour))
	buf := make([]byte, 65536)
	for {
		_, err := conn.Read(buf)
		if err != nil {
			panic(err)
		}
	}
}

// handleTinyBufferEcho echoes through a deliberately tiny 2-byte buffer to
// force fragmented reads on the peer side.
func handleTinyBufferEcho(conn *UDPSession) {
	conn.SetStreamMode(true)
	buf := make([]byte, 2)
	for {
		n, err := conn.Read(buf)
		if err != nil {
			panic(err)
		}
		conn.Write(buf[:n])
	}
}
///////////////////////////

// TestTimeout verifies that a Read past the deadline returns an error.
func TestTimeout(t *testing.T) {
	cli, err := dialEcho()
	if err != nil {
		panic(err)
	}
	buf := make([]byte, 10)

	//timeout
	cli.SetDeadline(time.Now().Add(time.Second))
	<-time.After(2 * time.Second)
	n, err := cli.Read(buf)
	if n != 0 || err == nil {
		t.Fail()
	}
	cli.Close()
}

// TestSendRecv round-trips 100 small messages through the echo server,
// with write delay and packet duplication enabled for coverage.
func TestSendRecv(t *testing.T) {
	cli, err := dialEcho()
	if err != nil {
		panic(err)
	}
	cli.SetWriteDelay(true)
	cli.SetDUP(1)
	const N = 100
	buf := make([]byte, 10)
	for i := 0; i < N; i++ {
		msg := fmt.Sprintf("hello%v", i)
		cli.Write([]byte(msg))
		if n, err := cli.Read(buf); err == nil {
			if string(buf[:n]) != msg {
				t.Fail()
			}
		} else {
			panic(err)
		}
	}
	cli.Close()
}

// TestTinyBufferReceiver checks byte-accurate delivery when the remote side
// echoes through a 2-byte buffer, forcing fragmented reads.
func TestTinyBufferReceiver(t *testing.T) {
	cli, err := dialTinyBufferEcho()
	if err != nil {
		panic(err)
	}
	const N = 100

	// fillBuffer writes a running byte counter into buf
	snd := byte(0)
	fillBuffer := func(buf []byte) {
		for i := 0; i < len(buf); i++ {
			buf[i] = snd
			snd++
		}
	}

	// check verifies the received bytes continue the same counter
	rcv := byte(0)
	check := func(buf []byte) bool {
		for i := 0; i < len(buf); i++ {
			if buf[i] != rcv {
				return false
			}
			rcv++
		}
		return true
	}
	sndbuf := make([]byte, 7)
	rcvbuf := make([]byte, 7)
	for i := 0; i < N; i++ {
		fillBuffer(sndbuf)
		cli.Write(sndbuf)
		if n, err := io.ReadFull(cli, rcvbuf); err == nil {
			if !check(rcvbuf[:n]) {
				t.Fail()
			}
		} else {
			panic(err)
		}
	}
	cli.Close()
}

// TestClose verifies double-close fails and that I/O after close errors out.
func TestClose(t *testing.T) {
	cli, err := dialEcho()
	if err != nil {
		panic(err)
	}
	buf := make([]byte, 10)

	cli.Close()
	if cli.Close() == nil {
		t.Fail()
	}
	n, err := cli.Write(buf)
	if n != 0 || err == nil {
		t.Fail()
	}
	n, err = cli.Read(buf)
	if n != 0 || err == nil {
		t.Fail()
	}
	cli.Close()
}
// TestParallel1024CLIENT_64BMSG_64CNT stresses the stack with 1024
// concurrent echo clients, each exchanging 64 messages of 64 bytes.
func TestParallel1024CLIENT_64BMSG_64CNT(t *testing.T) {
	var wg sync.WaitGroup
	wg.Add(1024)
	for i := 0; i < 1024; i++ {
		go parallel_client(&wg)
	}
	wg.Wait()
}

// parallel_client runs one echo exchange and signals completion on wg.
func parallel_client(wg *sync.WaitGroup) (err error) {
	cli, err := dialEcho()
	if err != nil {
		panic(err)
	}

	err = echo_tester(cli, 64, 64)
	wg.Done()
	return
}
// Echo benchmarks: round-trip throughput at increasing message sizes.
func BenchmarkEchoSpeed4K(b *testing.B) {
	speedclient(b, 4096)
}

func BenchmarkEchoSpeed64K(b *testing.B) {
	speedclient(b, 65536)
}

func BenchmarkEchoSpeed512K(b *testing.B) {
	speedclient(b, 524288)
}

func BenchmarkEchoSpeed1M(b *testing.B) {
	speedclient(b, 1048576)
}

// speedclient measures echo round-trips of nbytes-sized messages.
func speedclient(b *testing.B, nbytes int) {
	b.ReportAllocs()
	cli, err := dialEcho()
	if err != nil {
		panic(err)
	}

	if err := echo_tester(cli, nbytes, b.N); err != nil {
		b.Fail()
	}
	b.SetBytes(int64(nbytes))
}

// Sink benchmarks: one-way send throughput at increasing message sizes.
func BenchmarkSinkSpeed4K(b *testing.B) {
	sinkclient(b, 4096)
}

func BenchmarkSinkSpeed64K(b *testing.B) {
	sinkclient(b, 65536)
}

func BenchmarkSinkSpeed256K(b *testing.B) {
	sinkclient(b, 524288)
}

func BenchmarkSinkSpeed1M(b *testing.B) {
	sinkclient(b, 1048576)
}

// sinkclient measures one-way writes of nbytes-sized messages.
func sinkclient(b *testing.B, nbytes int) {
	b.ReportAllocs()
	cli, err := dialSink()
	if err != nil {
		panic(err)
	}

	sink_tester(cli, nbytes, b.N)
	b.SetBytes(int64(nbytes))
}
func echo_tester(cli net.Conn, msglen, msgcount int) error {
buf := make([]byte, msglen)
for i := 0; i < msgcount; i++ {
// send packet
if _, err := cli.Write(buf); err != nil {
return err
}
// receive packet
nrecv := 0
for {
n, err := cli.Read(buf)
if err != nil {
return err
} else {
nrecv += n
if nrecv == msglen {
break
}
}
}
}
return nil
}
// sink_tester writes msgcount messages of msglen bytes without expecting any
// reply; it returns the first write error, if any.
func sink_tester(cli *UDPSession, msglen, msgcount int) error {
	// sender
	buf := make([]byte, msglen)
	for i := 0; i < msgcount; i++ {
		if _, err := cli.Write(buf); err != nil {
			return err
		}
	}
	return nil
}

// TestSNMP exercises the SNMP snapshot/reset API for coverage.
func TestSNMP(t *testing.T) {
	t.Log(DefaultSnmp.Copy())
	t.Log(DefaultSnmp.Header())
	t.Log(DefaultSnmp.ToSlice())
	DefaultSnmp.Reset()
	t.Log(DefaultSnmp.ToSlice())
}

// TestListenerClose verifies Accept honors deadlines and that closing an
// unknown session is reported as false.
func TestListenerClose(t *testing.T) {
	l, err := ListenWithOptions(portListerner, nil, 10, 3)
	if err != nil {
		t.Fail()
	}
	l.SetReadDeadline(time.Now().Add(time.Second))
	l.SetWriteDeadline(time.Now().Add(time.Second))
	l.SetDeadline(time.Now().Add(time.Second))
	time.Sleep(2 * time.Second)
	if _, err := l.Accept(); err == nil {
		t.Fail()
	}

	l.Close()
	fakeaddr, _ := net.ResolveUDPAddr("udp6", "127.0.0.1:1111")
	if l.closeSession(fakeaddr) {
		t.Fail()
	}
}

164
lib/kcp/snmp.go Normal file
View File

@ -0,0 +1,164 @@
package kcp
import (
"fmt"
"sync/atomic"
)
// Snmp defines network statistics indicators; all fields are manipulated
// with sync/atomic and may be read concurrently via Copy().
type Snmp struct {
	BytesSent        uint64 // bytes sent from upper level
	BytesReceived    uint64 // bytes received to upper level
	MaxConn          uint64 // max number of connections ever reached
	ActiveOpens      uint64 // accumulated active open connections
	PassiveOpens     uint64 // accumulated passive open connections
	CurrEstab        uint64 // current number of established connections
	InErrs           uint64 // UDP read errors reported from net.PacketConn
	InCsumErrors     uint64 // checksum errors from CRC32
	KCPInErrors      uint64 // packet input errors reported from KCP
	InPkts           uint64 // incoming packets count
	OutPkts          uint64 // outgoing packets count
	InSegs           uint64 // incoming KCP segments
	OutSegs          uint64 // outgoing KCP segments
	InBytes          uint64 // UDP bytes received
	OutBytes         uint64 // UDP bytes sent
	RetransSegs      uint64 // accumulated retransmitted segments
	FastRetransSegs  uint64 // accumulated fast retransmitted segments
	EarlyRetransSegs uint64 // accumulated early retransmitted segments
	LostSegs         uint64 // number of segs inferred as lost
	RepeatSegs       uint64 // number of segs duplicated
	FECRecovered     uint64 // correct packets recovered from FEC
	FECErrs          uint64 // incorrect packets recovered from FEC
	FECParityShards  uint64 // FEC segments received
	FECShortShards   uint64 // number of data shards that's not enough for recovery
}

// newSnmp returns a zeroed statistics collector.
func newSnmp() *Snmp {
	return new(Snmp)
}
// Header returns all field names.
// NOTE: the column order here must stay in lockstep with ToSlice below
// (the FEC columns come last, ordered ParityShards/Errs/Recovered/ShortShards).
func (s *Snmp) Header() []string {
	return []string{
		"BytesSent",
		"BytesReceived",
		"MaxConn",
		"ActiveOpens",
		"PassiveOpens",
		"CurrEstab",
		"InErrs",
		"InCsumErrors",
		"KCPInErrors",
		"InPkts",
		"OutPkts",
		"InSegs",
		"OutSegs",
		"InBytes",
		"OutBytes",
		"RetransSegs",
		"FastRetransSegs",
		"EarlyRetransSegs",
		"LostSegs",
		"RepeatSegs",
		"FECParityShards",
		"FECErrs",
		"FECRecovered",
		"FECShortShards",
	}
}
// ToSlice returns current snmp info as a string slice, one element per
// counter, in the same column order as Header. It formats an atomic
// snapshot taken via Copy.
func (s *Snmp) ToSlice() []string {
	snmp := s.Copy()
	return []string{
		fmt.Sprint(snmp.BytesSent),
		fmt.Sprint(snmp.BytesReceived),
		fmt.Sprint(snmp.MaxConn),
		fmt.Sprint(snmp.ActiveOpens),
		fmt.Sprint(snmp.PassiveOpens),
		fmt.Sprint(snmp.CurrEstab),
		fmt.Sprint(snmp.InErrs),
		fmt.Sprint(snmp.InCsumErrors),
		fmt.Sprint(snmp.KCPInErrors),
		fmt.Sprint(snmp.InPkts),
		fmt.Sprint(snmp.OutPkts),
		fmt.Sprint(snmp.InSegs),
		fmt.Sprint(snmp.OutSegs),
		fmt.Sprint(snmp.InBytes),
		fmt.Sprint(snmp.OutBytes),
		fmt.Sprint(snmp.RetransSegs),
		fmt.Sprint(snmp.FastRetransSegs),
		fmt.Sprint(snmp.EarlyRetransSegs),
		fmt.Sprint(snmp.LostSegs),
		fmt.Sprint(snmp.RepeatSegs),
		fmt.Sprint(snmp.FECParityShards),
		fmt.Sprint(snmp.FECErrs),
		fmt.Sprint(snmp.FECRecovered),
		fmt.Sprint(snmp.FECShortShards),
	}
}
// Copy makes a snapshot of the current counters using atomic loads.
// Each field is read individually, so the snapshot is per-field consistent
// but not a single atomic unit across all counters.
func (s *Snmp) Copy() *Snmp {
	d := newSnmp()
	d.BytesSent = atomic.LoadUint64(&s.BytesSent)
	d.BytesReceived = atomic.LoadUint64(&s.BytesReceived)
	d.MaxConn = atomic.LoadUint64(&s.MaxConn)
	d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens)
	d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens)
	d.CurrEstab = atomic.LoadUint64(&s.CurrEstab)
	d.InErrs = atomic.LoadUint64(&s.InErrs)
	d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors)
	d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors)
	d.InPkts = atomic.LoadUint64(&s.InPkts)
	d.OutPkts = atomic.LoadUint64(&s.OutPkts)
	d.InSegs = atomic.LoadUint64(&s.InSegs)
	d.OutSegs = atomic.LoadUint64(&s.OutSegs)
	d.InBytes = atomic.LoadUint64(&s.InBytes)
	d.OutBytes = atomic.LoadUint64(&s.OutBytes)
	d.RetransSegs = atomic.LoadUint64(&s.RetransSegs)
	d.FastRetransSegs = atomic.LoadUint64(&s.FastRetransSegs)
	d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs)
	d.LostSegs = atomic.LoadUint64(&s.LostSegs)
	d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs)
	d.FECParityShards = atomic.LoadUint64(&s.FECParityShards)
	d.FECErrs = atomic.LoadUint64(&s.FECErrs)
	d.FECRecovered = atomic.LoadUint64(&s.FECRecovered)
	d.FECShortShards = atomic.LoadUint64(&s.FECShortShards)
	return d
}
// Reset sets all counters back to zero using atomic stores.
func (s *Snmp) Reset() {
	atomic.StoreUint64(&s.BytesSent, 0)
	atomic.StoreUint64(&s.BytesReceived, 0)
	atomic.StoreUint64(&s.MaxConn, 0)
	atomic.StoreUint64(&s.ActiveOpens, 0)
	atomic.StoreUint64(&s.PassiveOpens, 0)
	atomic.StoreUint64(&s.CurrEstab, 0)
	atomic.StoreUint64(&s.InErrs, 0)
	atomic.StoreUint64(&s.InCsumErrors, 0)
	atomic.StoreUint64(&s.KCPInErrors, 0)
	atomic.StoreUint64(&s.InPkts, 0)
	atomic.StoreUint64(&s.OutPkts, 0)
	atomic.StoreUint64(&s.InSegs, 0)
	atomic.StoreUint64(&s.OutSegs, 0)
	atomic.StoreUint64(&s.InBytes, 0)
	atomic.StoreUint64(&s.OutBytes, 0)
	atomic.StoreUint64(&s.RetransSegs, 0)
	atomic.StoreUint64(&s.FastRetransSegs, 0)
	atomic.StoreUint64(&s.EarlyRetransSegs, 0)
	atomic.StoreUint64(&s.LostSegs, 0)
	atomic.StoreUint64(&s.RepeatSegs, 0)
	atomic.StoreUint64(&s.FECParityShards, 0)
	atomic.StoreUint64(&s.FECErrs, 0)
	atomic.StoreUint64(&s.FECRecovered, 0)
	atomic.StoreUint64(&s.FECShortShards, 0)
}
// DefaultSnmp is the global KCP connection statistics collector
var DefaultSnmp *Snmp

// init allocates the global collector before any session is created.
func init() {
	DefaultSnmp = newSnmp()
}

104
lib/kcp/updater.go Normal file
View File

@ -0,0 +1,104 @@
package kcp
import (
"container/heap"
"sync"
"time"
)
// updater is the process-wide scheduler that drives kcp.flush() for every
// open session via its background updateTask goroutine.
var updater updateHeap

func init() {
	updater.init()
	go updater.updateTask()
}

// entry contains a session update info
type entry struct {
	ts time.Time   // next time the session is due for a flush
	s  *UDPSession // the session to flush
}

// a global heap managed kcp.flush() caller; min-heap ordered by due time
type updateHeap struct {
	entries  []entry
	mu       sync.Mutex
	chWakeUp chan struct{}
}
// heap.Interface implementation. Swap/Push/Pop keep each session's
// updaterIdx in sync with its slot so removeSession can delete by index;
// -1 marks a session that is no longer queued.
func (h *updateHeap) Len() int           { return len(h.entries) }
func (h *updateHeap) Less(i, j int) bool { return h.entries[i].ts.Before(h.entries[j].ts) }
func (h *updateHeap) Swap(i, j int) {
	h.entries[i], h.entries[j] = h.entries[j], h.entries[i]
	h.entries[i].s.updaterIdx = i
	h.entries[j].s.updaterIdx = j
}

func (h *updateHeap) Push(x interface{}) {
	h.entries = append(h.entries, x.(entry))
	n := len(h.entries)
	h.entries[n-1].s.updaterIdx = n - 1
}

func (h *updateHeap) Pop() interface{} {
	n := len(h.entries)
	x := h.entries[n-1]
	h.entries[n-1].s.updaterIdx = -1
	h.entries[n-1] = entry{} // manual set nil for GC
	h.entries = h.entries[0 : n-1]
	return x
}
// init prepares the wake-up channel; the capacity of 1 lets wakeup() act
// as a non-blocking edge trigger.
func (h *updateHeap) init() {
	h.chWakeUp = make(chan struct{}, 1)
}
// addSession schedules s for an immediate flush (deadline = now) and wakes
// the update loop so the new deadline is taken into account.
func (h *updateHeap) addSession(s *UDPSession) {
	h.mu.Lock()
	heap.Push(h, entry{time.Now(), s})
	h.mu.Unlock()
	h.wakeup()
}
// removeSession stops tracking s. An updaterIdx of -1 means the session is
// not currently in the heap, so removal is a no-op.
func (h *updateHeap) removeSession(s *UDPSession) {
	h.mu.Lock()
	if s.updaterIdx != -1 {
		heap.Remove(h, s.updaterIdx)
	}
	h.mu.Unlock()
}
// wakeup nudges updateTask without blocking; if a wake-up signal is
// already pending the extra one is dropped.
func (h *updateHeap) wakeup() {
	select {
	case h.chWakeUp <- struct{}{}:
	default:
	}
}
// updateTask is the single background goroutine that drives kcp flushes
// for all registered sessions. It sleeps until the earliest scheduled
// entry is due (or until wakeup() reports that the heap changed), flushes
// every due session, reschedules each one, and re-arms the timer for the
// new earliest deadline. It never returns.
func (h *updateHeap) updateTask() {
	var timer <-chan time.Time
	for {
		select {
		case <-timer:
		case <-h.chWakeUp:
		}
		h.mu.Lock()
		hlen := h.Len()
		// The heap root is always the earliest entry; keep servicing the
		// root until the earliest deadline is in the future.
		for i := 0; i < hlen; i++ {
			entry := &h.entries[0]
			now := time.Now() // one clock read per entry instead of two
			if now.After(entry.ts) {
				// update() flushes the session and reports how long until
				// it next needs servicing.
				interval := entry.s.update()
				entry.ts = now.Add(interval)
				heap.Fix(h, 0)
			} else {
				break
			}
		}
		if hlen > 0 {
			// time.Until is the idiomatic form of ts.Sub(time.Now()).
			timer = time.After(time.Until(h.entries[0].ts))
		}
		h.mu.Unlock()
	}
}

View File

@ -1,4 +1,4 @@
package lib
package lg
import (
"log"
@ -9,10 +9,10 @@ import (
var Log *log.Logger
func InitLogFile(f string, isStdout bool) {
func InitLogFile(f string, isStdout bool, logPath string) {
var prefix string
if !isStdout {
logFile, err := os.OpenFile(filepath.Join(GetLogPath(), f+"_log.txt"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766)
logFile, err := os.OpenFile(filepath.Join(logPath, f+"_log.txt"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766)
if err != nil {
log.Fatalln("open file error !", err)
}

View File

@ -1,43 +0,0 @@
package lib
import (
"sync"
)
// Legacy buffer pools (package lib); superseded in this commit by the
// exported lib/pool package.

// poolSize is the element size of BufPool/BufPoolMax buffers.
const poolSize = 64 * 1024

// poolSizeSmall is the element size of BufPoolSmall buffers — presumably
// sized for short control messages; TODO confirm against callers.
const poolSizeSmall = 100

// poolSizeUdp is the element size of BufPoolUdp buffers — presumably a
// 1500-byte MTU minus IP/UDP headers; TODO confirm.
const poolSizeUdp = 1472

// poolSizeCopy is the element size of BufPoolCopy buffers.
const poolSizeCopy = 32 * 1024

// BufPool hands out poolSize-byte scratch buffers.
var BufPool = sync.Pool{
	New: func() interface{} {
		return make([]byte, poolSize)
	},
}

// BufPoolUdp hands out buffers sized for one UDP payload.
var BufPoolUdp = sync.Pool{
	New: func() interface{} {
		return make([]byte, poolSizeUdp)
	},
}

// BufPoolMax shares the poolSize element size with BufPool.
var BufPoolMax = sync.Pool{
	New: func() interface{} {
		return make([]byte, poolSize)
	},
}

// BufPoolSmall hands out short buffers.
var BufPoolSmall = sync.Pool{
	New: func() interface{} {
		return make([]byte, poolSizeSmall)
	},
}

// BufPoolCopy hands out buffers for copy loops.
var BufPoolCopy = sync.Pool{
	New: func() interface{} {
		return make([]byte, poolSizeCopy)
	},
}

// PutBufPoolCopy returns buf to BufPoolCopy, re-sliced to full capacity;
// buffers with any other capacity are silently dropped.
func PutBufPoolCopy(buf []byte) {
	if cap(buf) == poolSizeCopy {
		BufPoolCopy.Put(buf[:poolSizeCopy])
	}
}

49
lib/pool/pool.go Normal file
View File

@ -0,0 +1,49 @@
package pool
import (
"sync"
)
// Shared byte-slice pools used across the proxy to avoid per-connection
// allocations. The Put helpers only accept slices whose capacity matches
// the pool's element size; foreign buffers are silently discarded.

const (
	PoolSize      = 64 * 1024
	PoolSizeSmall = 100
	PoolSizeUdp   = 1472
	PoolSizeCopy  = 32 * 1024
)

var (
	// BufPool hands out PoolSize-byte scratch buffers.
	BufPool = sync.Pool{New: func() interface{} { return make([]byte, PoolSize) }}
	// BufPoolUdp hands out buffers sized for a single UDP payload.
	BufPoolUdp = sync.Pool{New: func() interface{} { return make([]byte, PoolSizeUdp) }}
	// BufPoolMax shares the PoolSize element size with BufPool.
	BufPoolMax = sync.Pool{New: func() interface{} { return make([]byte, PoolSize) }}
	// BufPoolSmall hands out short buffers.
	BufPoolSmall = sync.Pool{New: func() interface{} { return make([]byte, PoolSizeSmall) }}
	// BufPoolCopy hands out buffers for copy loops.
	BufPoolCopy = sync.Pool{New: func() interface{} { return make([]byte, PoolSizeCopy) }}
)

// PutBufPoolCopy returns buf to BufPoolCopy, re-sliced to full capacity.
// Buffers with any other capacity are dropped.
func PutBufPoolCopy(buf []byte) {
	if cap(buf) != PoolSizeCopy {
		return
	}
	BufPoolCopy.Put(buf[:PoolSizeCopy])
}

// PutBufPoolUdp returns buf to BufPoolUdp, re-sliced to full capacity.
// Buffers with any other capacity are dropped.
func PutBufPoolUdp(buf []byte) {
	if cap(buf) != PoolSizeUdp {
		return
	}
	BufPoolUdp.Put(buf[:PoolSizeUdp])
}

View File

@ -1,4 +1,4 @@
package lib
package rate
import (
"sync/atomic"

237
lib/snappy/decode.go Normal file
View File

@ -0,0 +1,237 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
"io"
)
var (
	// ErrCorrupt reports that the input is invalid.
	ErrCorrupt = errors.New("snappy: corrupt input")
	// ErrTooLarge reports that the uncompressed length is too large.
	ErrTooLarge = errors.New("snappy: decoded block is too large")
	// ErrUnsupported reports that the input isn't supported.
	ErrUnsupported = errors.New("snappy: unsupported input")

	errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
)

// DecodedLen returns the length of the decoded block.
func DecodedLen(src []byte) (int, error) {
	blockLen, _, err := decodedLen(src)
	return blockLen, err
}

// decodedLen returns the length of the decoded block and the number of
// bytes that the uvarint length header occupied at the front of src.
func decodedLen(src []byte) (blockLen, headerLen int, err error) {
	length, consumed := binary.Uvarint(src)
	if consumed <= 0 || length > 0xffffffff {
		return 0, 0, ErrCorrupt
	}
	// On 32-bit platforms, reject lengths that do not fit in an int.
	const wordSize = 32 << (^uint(0) >> 32 & 1)
	if wordSize == 32 && length > 0x7fffffff {
		return 0, 0, ErrTooLarge
	}
	return int(length), consumed, nil
}
// Error codes returned by the low-level decode function (implemented in
// decode_other.go / decode_amd64.s); 0 means success.
const (
	decodeErrCodeCorrupt                  = 1
	decodeErrCodeUnsupportedLiteralLength = 2
)

// Decode returns the decoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire decoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
func Decode(dst, src []byte) ([]byte, error) {
	dLen, s, err := decodedLen(src)
	if err != nil {
		return nil, err
	}
	// Reuse dst when it is large enough; otherwise allocate exactly dLen.
	if dLen <= len(dst) {
		dst = dst[:dLen]
	} else {
		dst = make([]byte, dLen)
	}
	// src[s:] skips the varint length header already parsed above.
	switch decode(dst, src[s:]) {
	case 0:
		return dst, nil
	case decodeErrCodeUnsupportedLiteralLength:
		return nil, errUnsupportedLiteralLength
	}
	return nil, ErrCorrupt
}
// NewReader returns a new Reader that decompresses from r, using the framing
// format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
func NewReader(r io.Reader) *Reader {
	return &Reader{
		r:       r,
		decoded: make([]byte, maxBlockSize),
		buf:     make([]byte, maxEncodedLenOfMaxBlockSize+checksumSize),
	}
}

// Reader is an io.Reader that can read Snappy-compressed bytes.
type Reader struct {
	r       io.Reader
	err     error  // sticky: once set, every subsequent Read fails with it
	decoded []byte // scratch holding one decompressed chunk
	buf     []byte // scratch holding one compressed chunk (incl. checksum)
	// decoded[i:j] contains decoded bytes that have not yet been passed on.
	i, j       int
	readHeader bool // whether the stream identifier chunk has been seen
}
// Reset discards any buffered data, resets all state, and switches the Snappy
// reader to read from r. This permits reusing a Reader rather than allocating
// a new one. The decoded/buf scratch buffers are kept, so Reset allocates
// nothing.
func (r *Reader) Reset(reader io.Reader) {
	r.r = reader
	r.err = nil
	r.i = 0
	r.j = 0
	r.readHeader = false
}
// readFull reads exactly len(p) bytes from the underlying reader into p.
// On failure it records the error on r — mapping an unexpected EOF, or a
// bare EOF when allowEOF is false, to ErrCorrupt — and returns false.
func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
		if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
			r.err = ErrCorrupt
		}
		return false
	}
	return true
}
// Read satisfies the io.Reader interface. It decompresses one framed chunk
// at a time into r.decoded and hands out the buffered bytes across calls.
func (r *Reader) Read(p []byte) (int, error) {
	if r.err != nil {
		return 0, r.err
	}
	for {
		// Drain any decoded bytes left over from a previous chunk first.
		if r.i < r.j {
			n := copy(p, r.decoded[r.i:r.j])
			r.i += n
			return n, nil
		}
		// Chunk header: 1 type byte + 3-byte little-endian body length.
		if !r.readFull(r.buf[:4], true) {
			return 0, r.err
		}
		chunkType := r.buf[0]
		if !r.readHeader {
			if chunkType != chunkTypeStreamIdentifier {
				r.err = ErrCorrupt
				return 0, r.err
			}
			r.readHeader = true
		}
		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
		if chunkLen > len(r.buf) {
			r.err = ErrUnsupported
			return 0, r.err
		}
		// The chunk types are specified at
		// https://github.com/google/snappy/blob/master/framing_format.txt
		switch chunkType {
		case chunkTypeCompressedData:
			// Section 4.2. Compressed data (chunk type 0x00).
			if chunkLen < checksumSize {
				r.err = ErrCorrupt
				return 0, r.err
			}
			buf := r.buf[:chunkLen]
			if !r.readFull(buf, false) {
				return 0, r.err
			}
			// Little-endian checksum of the *uncompressed* data.
			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
			buf = buf[checksumSize:]
			n, err := DecodedLen(buf)
			if err != nil {
				r.err = err
				return 0, r.err
			}
			if n > len(r.decoded) {
				r.err = ErrCorrupt
				return 0, r.err
			}
			if _, err := Decode(r.decoded, buf); err != nil {
				r.err = err
				return 0, r.err
			}
			if crc(r.decoded[:n]) != checksum {
				r.err = ErrCorrupt
				return 0, r.err
			}
			r.i, r.j = 0, n
			continue
		case chunkTypeUncompressedData:
			// Section 4.3. Uncompressed data (chunk type 0x01).
			if chunkLen < checksumSize {
				r.err = ErrCorrupt
				return 0, r.err
			}
			buf := r.buf[:checksumSize]
			if !r.readFull(buf, false) {
				return 0, r.err
			}
			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
			// Read directly into r.decoded instead of via r.buf.
			n := chunkLen - checksumSize
			if n > len(r.decoded) {
				r.err = ErrCorrupt
				return 0, r.err
			}
			if !r.readFull(r.decoded[:n], false) {
				return 0, r.err
			}
			if crc(r.decoded[:n]) != checksum {
				r.err = ErrCorrupt
				return 0, r.err
			}
			r.i, r.j = 0, n
			continue
		case chunkTypeStreamIdentifier:
			// Section 4.1. Stream identifier (chunk type 0xff).
			if chunkLen != len(magicBody) {
				r.err = ErrCorrupt
				return 0, r.err
			}
			if !r.readFull(r.buf[:len(magicBody)], false) {
				return 0, r.err
			}
			for i := 0; i < len(magicBody); i++ {
				if r.buf[i] != magicBody[i] {
					r.err = ErrCorrupt
					return 0, r.err
				}
			}
			continue
		}
		if chunkType <= 0x7f {
			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
			r.err = ErrUnsupported
			return 0, r.err
		}
		// Section 4.4 Padding (chunk type 0xfe).
		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
		if !r.readFull(r.buf[:chunkLen], false) {
			return 0, r.err
		}
	}
}

View File

@ -0,0 +1,14 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
package snappy
// decode has the same semantics as in decode_other.go.
// The implementation for this build configuration lives in decode_amd64.s.
//
//go:noescape
func decode(dst, src []byte) int

490
lib/snappy/decode_amd64.s Normal file
View File

@ -0,0 +1,490 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The asm code generally follows the pure Go code in decode_other.go, except
// where marked with a "!!!".
// func decode(dst, src []byte) int
//
// All local variables fit into registers. The non-zero stack size is only to
// spill registers and push args when issuing a CALL. The register allocation:
// - AX scratch
// - BX scratch
// - CX length or x
// - DX offset
// - SI &src[s]
// - DI &dst[d]
// + R8 dst_base
// + R9 dst_len
// + R10 dst_base + dst_len
// + R11 src_base
// + R12 src_len
// + R13 src_base + src_len
// - R14 used by doCopy
// - R15 used by doCopy
//
// The registers R8-R13 (marked with a "+") are set at the start of the
// function, and after a CALL returns, and are not otherwise modified.
//
// The d variable is implicitly DI - R8, and len(dst)-d is R10 - DI.
// The s variable is implicitly SI - R11, and len(src)-s is R13 - SI.
TEXT ·decode(SB), NOSPLIT, $48-56
// Initialize SI, DI and R8-R13.
MOVQ dst_base+0(FP), R8
MOVQ dst_len+8(FP), R9
MOVQ R8, DI
MOVQ R8, R10
ADDQ R9, R10
MOVQ src_base+24(FP), R11
MOVQ src_len+32(FP), R12
MOVQ R11, SI
MOVQ R11, R13
ADDQ R12, R13
loop:
// for s < len(src)
CMPQ SI, R13
JEQ end
// CX = uint32(src[s])
//
// switch src[s] & 0x03
MOVBLZX (SI), CX
MOVL CX, BX
ANDL $3, BX
CMPL BX, $1
JAE tagCopy
// ----------------------------------------
// The code below handles literal tags.
// case tagLiteral:
// x := uint32(src[s] >> 2)
// switch
SHRL $2, CX
CMPL CX, $60
JAE tagLit60Plus
// case x < 60:
// s++
INCQ SI
doLit:
// This is the end of the inner "switch", when we have a literal tag.
//
// We assume that CX == x and x fits in a uint32, where x is the variable
// used in the pure Go decode_other.go code.
// length = int(x) + 1
//
// Unlike the pure Go code, we don't need to check if length <= 0 because
// CX can hold 64 bits, so the increment cannot overflow.
INCQ CX
// Prepare to check if copying length bytes will run past the end of dst or
// src.
//
// AX = len(dst) - d
// BX = len(src) - s
MOVQ R10, AX
SUBQ DI, AX
MOVQ R13, BX
SUBQ SI, BX
// !!! Try a faster technique for short (16 or fewer bytes) copies.
//
// if length > 16 || len(dst)-d < 16 || len(src)-s < 16 {
// goto callMemmove // Fall back on calling runtime·memmove.
// }
//
// The C++ snappy code calls this TryFastAppend. It also checks len(src)-s
// against 21 instead of 16, because it cannot assume that all of its input
// is contiguous in memory and so it needs to leave enough source bytes to
// read the next tag without refilling buffers, but Go's Decode assumes
// contiguousness (the src argument is a []byte).
CMPQ CX, $16
JGT callMemmove
CMPQ AX, $16
JLT callMemmove
CMPQ BX, $16
JLT callMemmove
// !!! Implement the copy from src to dst as a 16-byte load and store.
// (Decode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only length bytes, but that's
// OK. If the input is a valid Snappy encoding then subsequent iterations
// will fix up the overrun. Otherwise, Decode returns a nil []byte (and a
// non-nil error), so the overrun will be ignored.
//
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
MOVOU 0(SI), X0
MOVOU X0, 0(DI)
// d += length
// s += length
ADDQ CX, DI
ADDQ CX, SI
JMP loop
callMemmove:
// if length > len(dst)-d || length > len(src)-s { etc }
CMPQ CX, AX
JGT errCorrupt
CMPQ CX, BX
JGT errCorrupt
// copy(dst[d:], src[s:s+length])
//
// This means calling runtime·memmove(&dst[d], &src[s], length), so we push
// DI, SI and CX as arguments. Coincidentally, we also need to spill those
// three registers to the stack, to save local variables across the CALL.
MOVQ DI, 0(SP)
MOVQ SI, 8(SP)
MOVQ CX, 16(SP)
MOVQ DI, 24(SP)
MOVQ SI, 32(SP)
MOVQ CX, 40(SP)
CALL runtime·memmove(SB)
// Restore local variables: unspill registers from the stack and
// re-calculate R8-R13.
MOVQ 24(SP), DI
MOVQ 32(SP), SI
MOVQ 40(SP), CX
MOVQ dst_base+0(FP), R8
MOVQ dst_len+8(FP), R9
MOVQ R8, R10
ADDQ R9, R10
MOVQ src_base+24(FP), R11
MOVQ src_len+32(FP), R12
MOVQ R11, R13
ADDQ R12, R13
// d += length
// s += length
ADDQ CX, DI
ADDQ CX, SI
JMP loop
tagLit60Plus:
// !!! This fragment does the
//
// s += x - 58; if uint(s) > uint(len(src)) { etc }
//
// checks. In the asm version, we code it once instead of once per switch case.
ADDQ CX, SI
SUBQ $58, SI
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// case x == 60:
CMPL CX, $61
JEQ tagLit61
JA tagLit62Plus
// x = uint32(src[s-1])
MOVBLZX -1(SI), CX
JMP doLit
tagLit61:
// case x == 61:
// x = uint32(src[s-2]) | uint32(src[s-1])<<8
MOVWLZX -2(SI), CX
JMP doLit
tagLit62Plus:
CMPL CX, $62
JA tagLit63
// case x == 62:
// x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
MOVWLZX -3(SI), CX
MOVBLZX -1(SI), BX
SHLL $16, BX
ORL BX, CX
JMP doLit
tagLit63:
// case x == 63:
// x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
MOVL -4(SI), CX
JMP doLit
// The code above handles literal tags.
// ----------------------------------------
// The code below handles copy tags.
tagCopy4:
// case tagCopy4:
// s += 5
ADDQ $5, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// length = 1 + int(src[s-5])>>2
SHRQ $2, CX
INCQ CX
// offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
MOVLQZX -4(SI), DX
JMP doCopy
tagCopy2:
// case tagCopy2:
// s += 3
ADDQ $3, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// length = 1 + int(src[s-3])>>2
SHRQ $2, CX
INCQ CX
// offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
MOVWQZX -2(SI), DX
JMP doCopy
tagCopy:
// We have a copy tag. We assume that:
// - BX == src[s] & 0x03
// - CX == src[s]
CMPQ BX, $2
JEQ tagCopy2
JA tagCopy4
// case tagCopy1:
// s += 2
ADDQ $2, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
MOVQ CX, DX
ANDQ $0xe0, DX
SHLQ $3, DX
MOVBQZX -1(SI), BX
ORQ BX, DX
// length = 4 + int(src[s-2])>>2&0x7
SHRQ $2, CX
ANDQ $7, CX
ADDQ $4, CX
doCopy:
// This is the end of the outer "switch", when we have a copy tag.
//
// We assume that:
// - CX == length && CX > 0
// - DX == offset
// if offset <= 0 { etc }
CMPQ DX, $0
JLE errCorrupt
// if d < offset { etc }
MOVQ DI, BX
SUBQ R8, BX
CMPQ BX, DX
JLT errCorrupt
// if length > len(dst)-d { etc }
MOVQ R10, BX
SUBQ DI, BX
CMPQ CX, BX
JGT errCorrupt
// forwardCopy(dst[d:d+length], dst[d-offset:]); d += length
//
// Set:
// - R14 = len(dst)-d
// - R15 = &dst[d-offset]
MOVQ R10, R14
SUBQ DI, R14
MOVQ DI, R15
SUBQ DX, R15
// !!! Try a faster technique for short (16 or fewer bytes) forward copies.
//
// First, try using two 8-byte load/stores, similar to the doLit technique
// above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is
// still OK if offset >= 8. Note that this has to be two 8-byte load/stores
// and not one 16-byte load/store, and the first store has to be before the
// second load, due to the overlap if offset is in the range [8, 16).
//
// if length > 16 || offset < 8 || len(dst)-d < 16 {
// goto slowForwardCopy
// }
// copy 16 bytes
// d += length
CMPQ CX, $16
JGT slowForwardCopy
CMPQ DX, $8
JLT slowForwardCopy
CMPQ R14, $16
JLT slowForwardCopy
MOVQ 0(R15), AX
MOVQ AX, 0(DI)
MOVQ 8(R15), BX
MOVQ BX, 8(DI)
ADDQ CX, DI
JMP loop
slowForwardCopy:
// !!! If the forward copy is longer than 16 bytes, or if offset < 8, we
// can still try 8-byte load stores, provided we can overrun up to 10 extra
// bytes. As above, the overrun will be fixed up by subsequent iterations
// of the outermost loop.
//
// The C++ snappy code calls this technique IncrementalCopyFastPath. Its
// commentary says:
//
// ----
//
// The main part of this loop is a simple copy of eight bytes at a time
// until we've copied (at least) the requested amount of bytes. However,
// if d and d-offset are less than eight bytes apart (indicating a
// repeating pattern of length < 8), we first need to expand the pattern in
// order to get the correct results. For instance, if the buffer looks like
// this, with the eight-byte <d-offset> and <d> patterns marked as
// intervals:
//
// abxxxxxxxxxxxx
// [------] d-offset
// [------] d
//
// a single eight-byte copy from <d-offset> to <d> will repeat the pattern
// once, after which we can move <d> two bytes without moving <d-offset>:
//
// ababxxxxxxxxxx
// [------] d-offset
// [------] d
//
// and repeat the exercise until the two no longer overlap.
//
// This allows us to do very well in the special case of one single byte
// repeated many times, without taking a big hit for more general cases.
//
// The worst case of extra writing past the end of the match occurs when
// offset == 1 and length == 1; the last copy will read from byte positions
// [0..7] and write to [4..11], whereas it was only supposed to write to
// position 1. Thus, ten excess bytes.
//
// ----
//
// That "10 byte overrun" worst case is confirmed by Go's
// TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy
// and finishSlowForwardCopy algorithm.
//
// if length > len(dst)-d-10 {
// goto verySlowForwardCopy
// }
SUBQ $10, R14
CMPQ CX, R14
JGT verySlowForwardCopy
makeOffsetAtLeast8:
// !!! As above, expand the pattern so that offset >= 8 and we can use
// 8-byte load/stores.
//
// for offset < 8 {
// copy 8 bytes from dst[d-offset:] to dst[d:]
// length -= offset
// d += offset
// offset += offset
// // The two previous lines together means that d-offset, and therefore
// // R15, is unchanged.
// }
CMPQ DX, $8
JGE fixUpSlowForwardCopy
MOVQ (R15), BX
MOVQ BX, (DI)
SUBQ DX, CX
ADDQ DX, DI
ADDQ DX, DX
JMP makeOffsetAtLeast8
fixUpSlowForwardCopy:
// !!! Add length (which might be negative now) to d (implied by DI being
// &dst[d]) so that d ends up at the right place when we jump back to the
// top of the loop. Before we do that, though, we save DI to AX so that, if
// length is positive, copying the remaining length bytes will write to the
// right place.
MOVQ DI, AX
ADDQ CX, DI
finishSlowForwardCopy:
// !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative
// length means that we overrun, but as above, that will be fixed up by
// subsequent iterations of the outermost loop.
CMPQ CX, $0
JLE loop
MOVQ (R15), BX
MOVQ BX, (AX)
ADDQ $8, R15
ADDQ $8, AX
SUBQ $8, CX
JMP finishSlowForwardCopy
verySlowForwardCopy:
// verySlowForwardCopy is a simple implementation of forward copy. In C
// parlance, this is a do/while loop instead of a while loop, since we know
// that length > 0. In Go syntax:
//
// for {
// dst[d] = dst[d - offset]
// d++
// length--
// if length == 0 {
// break
// }
// }
MOVB (R15), BX
MOVB BX, (DI)
INCQ R15
INCQ DI
DECQ CX
JNZ verySlowForwardCopy
JMP loop
// The code above handles copy tags.
// ----------------------------------------
end:
// This is the end of the "for s < len(src)".
//
// if d != len(dst) { etc }
CMPQ DI, R10
JNE errCorrupt
// return 0
MOVQ $0, ret+48(FP)
RET
errCorrupt:
// return decodeErrCodeCorrupt
MOVQ $1, ret+48(FP)
RET

101
lib/snappy/decode_other.go Normal file
View File

@ -0,0 +1,101 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 appengine !gc noasm
package snappy
// decode writes the decoding of src to dst. It assumes that the varint-encoded
// length of the decompressed bytes has already been read, and that len(dst)
// equals that length.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
func decode(dst, src []byte) int {
	var d, s, offset, length int
	for s < len(src) {
		// The low two bits of each tag byte select literal vs. copy encoding.
		switch src[s] & 0x03 {
		case tagLiteral:
			x := uint32(src[s] >> 2)
			// x < 60 encodes the length inline; 60..63 mean the length
			// follows in 1..4 extra little-endian bytes.
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-1])
			case x == 61:
				s += 3
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-2]) | uint32(src[s-1])<<8
			case x == 62:
				s += 4
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
			case x == 63:
				s += 5
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
			}
			length = int(x) + 1
			if length <= 0 {
				return decodeErrCodeUnsupportedLiteralLength
			}
			if length > len(dst)-d || length > len(src)-s {
				return decodeErrCodeCorrupt
			}
			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue
		case tagCopy1:
			s += 2
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				return decodeErrCodeCorrupt
			}
			length = 4 + int(src[s-2])>>2&0x7
			offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
		case tagCopy2:
			s += 3
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-3])>>2
			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
		case tagCopy4:
			s += 5
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-5])>>2
			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
		}
		if offset <= 0 || d < offset || length > len(dst)-d {
			return decodeErrCodeCorrupt
		}
		// Copy from an earlier sub-slice of dst to a later sub-slice. Unlike
		// the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		for end := d + length; d != end; d++ {
			dst[d] = dst[d-offset]
		}
	}
	if d != len(dst) {
		return decodeErrCodeCorrupt
	}
	return 0
}

285
lib/snappy/encode.go Normal file
View File

@ -0,0 +1,285 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
"io"
)
// Encode returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
func Encode(dst, src []byte) []byte {
	if n := MaxEncodedLen(len(src)); n < 0 {
		panic(ErrTooLarge)
	} else if len(dst) < n {
		dst = make([]byte, n)
	}
	// The block starts with the varint-encoded length of the decompressed bytes.
	d := binary.PutUvarint(dst, uint64(len(src)))
	// Compress the input in independent maxBlockSize-sized pieces.
	for len(src) > 0 {
		p := src
		src = nil
		if len(p) > maxBlockSize {
			p, src = p[:maxBlockSize], p[maxBlockSize:]
		}
		// A piece too small to contain a copy tag is emitted as one literal.
		if len(p) < minNonLiteralBlockSize {
			d += emitLiteral(dst[d:], p)
		} else {
			d += encodeBlock(dst[d:], p)
		}
	}
	return dst[:d]
}

// inputMargin is the minimum number of extra input bytes to keep, inside
// encodeBlock's inner loop. On some architectures, this margin lets us
// implement a fast path for emitLiteral, where the copy of short (<= 16 byte)
// literals can be implemented as a single load to and store from a 16-byte
// register. That literal's actual length can be as short as 1 byte, so this
// can copy up to 15 bytes too much, but that's OK as subsequent iterations of
// the encoding loop will fix up the copy overrun, and this inputMargin ensures
// that we don't overrun the dst and src buffers.
const inputMargin = 16 - 1

// minNonLiteralBlockSize is the minimum size of the input to encodeBlock that
// could be encoded with a copy tag. This is the minimum with respect to the
// algorithm used by encodeBlock, not a minimum enforced by the file format.
//
// The encoded output must start with at least a 1 byte literal, as there are
// no previous bytes to copy. A minimal (1 byte) copy after that, generated
// from an emitCopy call in encodeBlock's main loop, would require at least
// another inputMargin bytes, for the reason above: we want any emitLiteral
// calls inside encodeBlock's main loop to use the fast path if possible, which
// requires being able to overrun by inputMargin bytes. Thus,
// minNonLiteralBlockSize equals 1 + 1 + inputMargin.
//
// The C++ code doesn't use this exact threshold, but it could, as discussed at
// https://groups.google.com/d/topic/snappy-compression/oGbhsdIJSJ8/discussion
// The difference between Go (2+inputMargin) and C++ (inputMargin) is purely an
// optimization. It should not affect the encoded form. This is tested by
// TestSameEncodingAsCppShortCopies.
const minNonLiteralBlockSize = 1 + 1 + inputMargin
// MaxEncodedLen returns the maximum length of a snappy block, given its
// uncompressed length.
//
// It will return a negative value if srcLen is too large to encode.
func MaxEncodedLen(srcLen int) int {
	uncompressed := uint64(srcLen)
	if uncompressed > 0xffffffff {
		return -1
	}
	// Worst-case bound: a trailing literal run blows up by at most 62/60
	// (a 60-byte literal needs a tag byte plus one length byte), and the
	// worst item is a one-byte literal followed by a five-byte copy that
	// encoded only five bytes of input — i.e. 7 output bytes per 6 input
	// bytes. Hence the n + n/6 term, plus 32 bytes of fixed overhead.
	bound := 32 + uncompressed + uncompressed/6
	if bound > 0xffffffff {
		return -1
	}
	return int(bound)
}
// errClosed is returned by Write/Flush/Close once Close has been called.
var errClosed = errors.New("snappy: Writer is closed")

// NewWriter returns a new Writer that compresses to w.
//
// The Writer returned does not buffer writes. There is no need to Flush or
// Close such a Writer.
//
// Deprecated: the Writer returned is not suitable for many small writes, only
// for few large writes. Use NewBufferedWriter instead, which is efficient
// regardless of the frequency and shape of the writes, and remember to Close
// that Writer when done.
func NewWriter(w io.Writer) *Writer {
	return &Writer{
		w:    w,
		obuf: make([]byte, obufLen),
	}
}

// NewBufferedWriter returns a new Writer that compresses to w, using the
// framing format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
//
// The Writer returned buffers writes. Users must call Close to guarantee all
// data has been forwarded to the underlying io.Writer. They may also call
// Flush zero or more times before calling Close.
func NewBufferedWriter(w io.Writer) *Writer {
	return &Writer{
		w:    w,
		ibuf: make([]byte, 0, maxBlockSize),
		obuf: make([]byte, obufLen),
	}
}
// Writer is an io.Writer that can write Snappy-compressed bytes.
type Writer struct {
	w   io.Writer
	err error // sticky: once set, all further operations fail with it
	// ibuf is a buffer for the incoming (uncompressed) bytes.
	//
	// Its use is optional. For backwards compatibility, Writers created by the
	// NewWriter function have ibuf == nil, do not buffer incoming bytes, and
	// therefore do not need to be Flush'ed or Close'd.
	ibuf []byte
	// obuf is a buffer for the outgoing (compressed) bytes.
	obuf []byte
	// wroteStreamHeader is whether we have written the stream header.
	wroteStreamHeader bool
}
// Reset discards the writer's state and switches the Snappy writer to write to
// w. This permits reusing a Writer rather than allocating a new one. The
// allocated ibuf/obuf buffers are kept; only the buffered length is cleared.
func (w *Writer) Reset(writer io.Writer) {
	w.w = writer
	w.err = nil
	if w.ibuf != nil {
		w.ibuf = w.ibuf[:0]
	}
	w.wroteStreamHeader = false
}
// Write satisfies the io.Writer interface. In buffered mode it accumulates
// p into ibuf, compressing and flushing whole blocks as the buffer fills.
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
	if w.ibuf == nil {
		// Do not buffer incoming bytes. This does not perform or compress well
		// if the caller of Writer.Write writes many small slices. This
		// behavior is therefore deprecated, but still supported for backwards
		// compatibility with code that doesn't explicitly Flush or Close.
		return w.write(p)
	}
	// The remainder of this method is based on bufio.Writer.Write from the
	// standard library.
	for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err == nil {
		var n int
		if len(w.ibuf) == 0 {
			// Large write, empty buffer.
			// Write directly from p to avoid copy.
			n, _ = w.write(p)
		} else {
			// Top up the buffer and flush it as one compressed block.
			n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
			w.ibuf = w.ibuf[:len(w.ibuf)+n]
			w.Flush()
		}
		nRet += n
		p = p[n:]
	}
	if w.err != nil {
		return nRet, w.err
	}
	// Whatever remains fits in the buffer; stash it for a later flush.
	n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
	w.ibuf = w.ibuf[:len(w.ibuf)+n]
	nRet += n
	return nRet, nil
}
// write compresses p one maxBlockSize chunk at a time and emits each chunk
// in the snappy framing format (stream header once, then per-chunk header,
// checksum and body). It reports the number of *uncompressed* bytes consumed.
func (w *Writer) write(p []byte) (nRet int, errRet error) {
	if w.err != nil {
		return 0, w.err
	}
	for len(p) > 0 {
		obufStart := len(magicChunk)
		if !w.wroteStreamHeader {
			w.wroteStreamHeader = true
			copy(w.obuf, magicChunk)
			obufStart = 0
		}
		var uncompressed []byte
		if len(p) > maxBlockSize {
			uncompressed, p = p[:maxBlockSize], p[maxBlockSize:]
		} else {
			uncompressed, p = p, nil
		}
		checksum := crc(uncompressed)
		// Compress the buffer, discarding the result if the improvement
		// isn't at least 12.5%.
		compressed := Encode(w.obuf[obufHeaderLen:], uncompressed)
		chunkType := uint8(chunkTypeCompressedData)
		chunkLen := 4 + len(compressed)
		obufEnd := obufHeaderLen + len(compressed)
		if len(compressed) >= len(uncompressed)-len(uncompressed)/8 {
			chunkType = chunkTypeUncompressedData
			chunkLen = 4 + len(uncompressed)
			obufEnd = obufHeaderLen
		}
		// Fill in the per-chunk header that comes before the body.
		w.obuf[len(magicChunk)+0] = chunkType
		w.obuf[len(magicChunk)+1] = uint8(chunkLen >> 0)
		w.obuf[len(magicChunk)+2] = uint8(chunkLen >> 8)
		w.obuf[len(magicChunk)+3] = uint8(chunkLen >> 16)
		w.obuf[len(magicChunk)+4] = uint8(checksum >> 0)
		w.obuf[len(magicChunk)+5] = uint8(checksum >> 8)
		w.obuf[len(magicChunk)+6] = uint8(checksum >> 16)
		w.obuf[len(magicChunk)+7] = uint8(checksum >> 24)
		if _, err := w.w.Write(w.obuf[obufStart:obufEnd]); err != nil {
			w.err = err
			return nRet, err
		}
		if chunkType == chunkTypeUncompressedData {
			// Uncompressed bodies are written straight from the input slice.
			if _, err := w.w.Write(uncompressed); err != nil {
				w.err = err
				return nRet, err
			}
		}
		nRet += len(uncompressed)
	}
	return nRet, nil
}
// Flush flushes the Writer to its underlying io.Writer.
func (w *Writer) Flush() error {
	if w.err != nil {
		return w.err
	}
	if len(w.ibuf) == 0 {
		return nil
	}
	// write records any failure in w.err, so the bare call is safe here.
	w.write(w.ibuf)
	w.ibuf = w.ibuf[:0]
	return w.err
}
// Close flushes any buffered data and then marks the Writer as closed, so
// that later calls observe errClosed. It returns the first error that
// occurred on this Writer, or nil.
func (w *Writer) Close() error {
	w.Flush()
	if w.err != nil {
		return w.err
	}
	w.err = errClosed
	return nil
}

View File

@ -0,0 +1,29 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
package snappy
// emitLiteral has the same semantics as in encode_other.go.
//
//go:noescape
func emitLiteral(dst, lit []byte) int
// emitCopy has the same semantics as in encode_other.go.
//
//go:noescape
func emitCopy(dst []byte, offset, length int) int
// extendMatch has the same semantics as in encode_other.go.
//
//go:noescape
func extendMatch(src []byte, i, j int) int
// encodeBlock has the same semantics as in encode_other.go.
//
//go:noescape
func encodeBlock(dst, src []byte) (d int)

730
lib/snappy/encode_amd64.s Normal file
View File

@ -0,0 +1,730 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The XXX lines assemble on Go 1.4, 1.5 and 1.7, but not 1.6, due to a
// Go toolchain regression. See https://github.com/golang/go/issues/15426 and
// https://github.com/golang/snappy/issues/29
//
// As a workaround, the package was built with a known good assembler, and
// those instructions were disassembled by "objdump -d" to yield the
// 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
// style comments, in AT&T asm syntax. Note that rsp here is a physical
// register, not Go/asm's SP pseudo-register (see https://golang.org/doc/asm).
// The instructions were then encoded as "BYTE $0x.." sequences, which assemble
// fine on Go 1.6.
// The asm code generally follows the pure Go code in encode_other.go, except
// where marked with a "!!!".
// ----------------------------------------------------------------------------
// func emitLiteral(dst, lit []byte) int
//
// All local variables fit into registers. The register allocation:
// - AX len(lit)
// - BX n
// - DX return value
// - DI &dst[i]
// - R10 &lit[0]
//
// The 24 bytes of stack space is to call runtime·memmove.
//
// The unusual register allocation of local variables, such as R10 for the
// source pointer, matches the allocation used at the call site in encodeBlock,
// which makes it easier to manually inline this function.
TEXT ·emitLiteral(SB), NOSPLIT, $24-56
MOVQ dst_base+0(FP), DI
MOVQ lit_base+24(FP), R10
MOVQ lit_len+32(FP), AX
MOVQ AX, DX
MOVL AX, BX
SUBL $1, BX
CMPL BX, $60
JLT oneByte
CMPL BX, $256
JLT twoBytes
threeBytes:
MOVB $0xf4, 0(DI)
MOVW BX, 1(DI)
ADDQ $3, DI
ADDQ $3, DX
JMP memmove
twoBytes:
MOVB $0xf0, 0(DI)
MOVB BX, 1(DI)
ADDQ $2, DI
ADDQ $2, DX
JMP memmove
oneByte:
SHLB $2, BX
MOVB BX, 0(DI)
ADDQ $1, DI
ADDQ $1, DX
memmove:
MOVQ DX, ret+48(FP)
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// DI, R10 and AX as arguments.
MOVQ DI, 0(SP)
MOVQ R10, 8(SP)
MOVQ AX, 16(SP)
CALL runtime·memmove(SB)
RET
// ----------------------------------------------------------------------------
// func emitCopy(dst []byte, offset, length int) int
//
// All local variables fit into registers. The register allocation:
// - AX length
// - SI &dst[0]
// - DI &dst[i]
// - R11 offset
//
// The unusual register allocation of local variables, such as R11 for the
// offset, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·emitCopy(SB), NOSPLIT, $0-48
MOVQ dst_base+0(FP), DI
MOVQ DI, SI
MOVQ offset+24(FP), R11
MOVQ length+32(FP), AX
loop0:
// for length >= 68 { etc }
CMPL AX, $68
JLT step1
// Emit a length 64 copy, encoded as 3 bytes.
MOVB $0xfe, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $64, AX
JMP loop0
step1:
// if length > 64 { etc }
CMPL AX, $64
JLE step2
// Emit a length 60 copy, encoded as 3 bytes.
MOVB $0xee, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $60, AX
step2:
// if length >= 12 || offset >= 2048 { goto step3 }
CMPL AX, $12
JGE step3
CMPL R11, $2048
JGE step3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(DI)
SHRL $8, R11
SHLB $5, R11
SUBB $4, AX
SHLB $2, AX
ORB AX, R11
ORB $1, R11
MOVB R11, 0(DI)
ADDQ $2, DI
// Return the number of bytes written.
SUBQ SI, DI
MOVQ DI, ret+40(FP)
RET
step3:
// Emit the remaining copy, encoded as 3 bytes.
SUBL $1, AX
SHLB $2, AX
ORB $2, AX
MOVB AX, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
// Return the number of bytes written.
SUBQ SI, DI
MOVQ DI, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func extendMatch(src []byte, i, j int) int
//
// All local variables fit into registers. The register allocation:
// - DX &src[0]
// - SI &src[j]
// - R13 &src[len(src) - 8]
// - R14 &src[len(src)]
// - R15 &src[i]
//
// The unusual register allocation of local variables, such as R15 for a source
// pointer, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·extendMatch(SB), NOSPLIT, $0-48
MOVQ src_base+0(FP), DX
MOVQ src_len+8(FP), R14
MOVQ i+24(FP), R15
MOVQ j+32(FP), SI
ADDQ DX, R14
ADDQ DX, R15
ADDQ DX, SI
MOVQ R14, R13
SUBQ $8, R13
cmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMPQ SI, R13
JA cmp1
MOVQ (R15), AX
MOVQ (SI), BX
CMPQ AX, BX
JNE bsf
ADDQ $8, R15
ADDQ $8, SI
JMP cmp8
bsf:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs. The BSF instruction finds the
// least significant 1 bit, the amd64 architecture is little-endian, and
// the shift by 3 converts a bit index to a byte index.
XORQ AX, BX
BSFQ BX, BX
SHRQ $3, BX
ADDQ BX, SI
// Convert from &src[ret] to ret.
SUBQ DX, SI
MOVQ SI, ret+40(FP)
RET
cmp1:
// In src's tail, compare 1 byte at a time.
CMPQ SI, R14
JAE extendMatchEnd
MOVB (R15), AX
MOVB (SI), BX
CMPB AX, BX
JNE extendMatchEnd
ADDQ $1, R15
ADDQ $1, SI
JMP cmp1
extendMatchEnd:
// Convert from &src[ret] to ret.
SUBQ DX, SI
MOVQ SI, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func encodeBlock(dst, src []byte) (d int)
//
// All local variables fit into registers, other than "var table". The register
// allocation:
// - AX . .
// - BX . .
// - CX 56 shift (note that amd64 shifts by non-immediates must use CX).
// - DX 64 &src[0], tableSize
// - SI 72 &src[s]
// - DI 80 &dst[d]
// - R9 88 sLimit
// - R10 . &src[nextEmit]
// - R11 96 prevHash, currHash, nextHash, offset
// - R12 104 &src[base], skip
// - R13 . &src[nextS], &src[len(src) - 8]
// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x
// - R15 112 candidate
//
// The second column (56, 64, etc) is the stack offset to spill the registers
// when calling other functions. We could pack this slightly tighter, but it's
// simpler to have a dedicated spill map independent of the function called.
//
// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An
// extra 56 bytes, to call other functions, and an extra 64 bytes, to spill
// local variables (registers) during calls gives 32768 + 56 + 64 = 32888.
TEXT ·encodeBlock(SB), 0, $32888-56
MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), R14
// shift, tableSize := uint32(32-8), 1<<8
MOVQ $24, CX
MOVQ $256, DX
calcShift:
// for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
// shift--
// }
CMPQ DX, $16384
JGE varTable
CMPQ DX, R14
JGE varTable
SUBQ $1, CX
SHLQ $1, DX
JMP calcShift
varTable:
// var table [maxTableSize]uint16
//
// In the asm code, unlike the Go code, we can zero-initialize only the
// first tableSize elements. Each uint16 element is 2 bytes and each MOVOU
// writes 16 bytes, so we can do only tableSize/8 writes instead of the
// 2048 writes that would zero-initialize all of table's 32768 bytes.
SHRQ $3, DX
LEAQ table-32768(SP), BX
PXOR X0, X0
memclr:
MOVOU X0, 0(BX)
ADDQ $16, BX
SUBQ $1, DX
JNZ memclr
// !!! DX = &src[0]
MOVQ SI, DX
// sLimit := len(src) - inputMargin
MOVQ R14, R9
SUBQ $15, R9
// !!! Pre-emptively spill CX, DX and R9 to the stack. Their values don't
// change for the rest of the function.
MOVQ CX, 56(SP)
MOVQ DX, 64(SP)
MOVQ R9, 88(SP)
// nextEmit := 0
MOVQ DX, R10
// s := 1
ADDQ $1, SI
// nextHash := hash(load32(src, s), shift)
MOVL 0(SI), R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
outer:
// for { etc }
// skip := 32
MOVQ $32, R12
// nextS := s
MOVQ SI, R13
// candidate := 0
MOVQ $0, R15
inner0:
// for { etc }
// s := nextS
MOVQ R13, SI
// bytesBetweenHashLookups := skip >> 5
MOVQ R12, R14
SHRQ $5, R14
// nextS = s + bytesBetweenHashLookups
ADDQ R14, R13
// skip += bytesBetweenHashLookups
ADDQ R14, R12
// if nextS > sLimit { goto emitRemainder }
MOVQ R13, AX
SUBQ DX, AX
CMPQ AX, R9
JA emitRemainder
// candidate = int(table[nextHash])
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
BYTE $0x4e
BYTE $0x0f
BYTE $0xb7
BYTE $0x7c
BYTE $0x5c
BYTE $0x78
// table[nextHash] = uint16(s)
MOVQ SI, AX
SUBQ DX, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// nextHash = hash(load32(src, nextS), shift)
MOVL 0(R13), R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// if load32(src, s) != load32(src, candidate) { continue } break
MOVL 0(SI), AX
MOVL (DX)(R15*1), BX
CMPL AX, BX
JNE inner0
fourByteMatch:
// As per the encode_other.go code:
//
// A 4-byte match has been found. We'll later see etc.
// !!! Jump to a fast path for short (<= 16 byte) literals. See the comment
// on inputMargin in encode.go.
MOVQ SI, AX
SUBQ R10, AX
CMPQ AX, $16
JLE emitLiteralFastPath
// ----------------------------------------
// Begin inline of the emitLiteral call.
//
// d += emitLiteral(dst[d:], src[nextEmit:s])
MOVL AX, BX
SUBL $1, BX
CMPL BX, $60
JLT inlineEmitLiteralOneByte
CMPL BX, $256
JLT inlineEmitLiteralTwoBytes
inlineEmitLiteralThreeBytes:
MOVB $0xf4, 0(DI)
MOVW BX, 1(DI)
ADDQ $3, DI
JMP inlineEmitLiteralMemmove
inlineEmitLiteralTwoBytes:
MOVB $0xf0, 0(DI)
MOVB BX, 1(DI)
ADDQ $2, DI
JMP inlineEmitLiteralMemmove
inlineEmitLiteralOneByte:
SHLB $2, BX
MOVB BX, 0(DI)
ADDQ $1, DI
inlineEmitLiteralMemmove:
// Spill local variables (registers) onto the stack; call; unspill.
//
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// DI, R10 and AX as arguments.
MOVQ DI, 0(SP)
MOVQ R10, 8(SP)
MOVQ AX, 16(SP)
ADDQ AX, DI // Finish the "d +=" part of "d += emitLiteral(etc)".
MOVQ SI, 72(SP)
MOVQ DI, 80(SP)
MOVQ R15, 112(SP)
CALL runtime·memmove(SB)
MOVQ 56(SP), CX
MOVQ 64(SP), DX
MOVQ 72(SP), SI
MOVQ 80(SP), DI
MOVQ 88(SP), R9
MOVQ 112(SP), R15
JMP inner1
inlineEmitLiteralEnd:
// End inline of the emitLiteral call.
// ----------------------------------------
emitLiteralFastPath:
// !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2".
MOVB AX, BX
SUBB $1, BX
SHLB $2, BX
MOVB BX, (DI)
ADDQ $1, DI
// !!! Implement the copy from lit to dst as a 16-byte load and store.
// (Encode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only len(lit) bytes, but that's
// OK. Subsequent iterations will fix up the overrun.
//
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
MOVOU 0(R10), X0
MOVOU X0, 0(DI)
ADDQ AX, DI
inner1:
// for { etc }
// base := s
MOVQ SI, R12
// !!! offset := base - candidate
MOVQ R12, R11
SUBQ R15, R11
SUBQ DX, R11
// ----------------------------------------
// Begin inline of the extendMatch call.
//
// s = extendMatch(src, candidate+4, s+4)
// !!! R14 = &src[len(src)]
MOVQ src_len+32(FP), R14
ADDQ DX, R14
// !!! R13 = &src[len(src) - 8]
MOVQ R14, R13
SUBQ $8, R13
// !!! R15 = &src[candidate + 4]
ADDQ $4, R15
ADDQ DX, R15
// !!! s += 4
ADDQ $4, SI
inlineExtendMatchCmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMPQ SI, R13
JA inlineExtendMatchCmp1
MOVQ (R15), AX
MOVQ (SI), BX
CMPQ AX, BX
JNE inlineExtendMatchBSF
ADDQ $8, R15
ADDQ $8, SI
JMP inlineExtendMatchCmp8
inlineExtendMatchBSF:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs. The BSF instruction finds the
// least significant 1 bit, the amd64 architecture is little-endian, and
// the shift by 3 converts a bit index to a byte index.
XORQ AX, BX
BSFQ BX, BX
SHRQ $3, BX
ADDQ BX, SI
JMP inlineExtendMatchEnd
inlineExtendMatchCmp1:
// In src's tail, compare 1 byte at a time.
CMPQ SI, R14
JAE inlineExtendMatchEnd
MOVB (R15), AX
MOVB (SI), BX
CMPB AX, BX
JNE inlineExtendMatchEnd
ADDQ $1, R15
ADDQ $1, SI
JMP inlineExtendMatchCmp1
inlineExtendMatchEnd:
// End inline of the extendMatch call.
// ----------------------------------------
// ----------------------------------------
// Begin inline of the emitCopy call.
//
// d += emitCopy(dst[d:], base-candidate, s-base)
// !!! length := s - base
MOVQ SI, AX
SUBQ R12, AX
inlineEmitCopyLoop0:
// for length >= 68 { etc }
CMPL AX, $68
JLT inlineEmitCopyStep1
// Emit a length 64 copy, encoded as 3 bytes.
MOVB $0xfe, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $64, AX
JMP inlineEmitCopyLoop0
inlineEmitCopyStep1:
// if length > 64 { etc }
CMPL AX, $64
JLE inlineEmitCopyStep2
// Emit a length 60 copy, encoded as 3 bytes.
MOVB $0xee, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $60, AX
inlineEmitCopyStep2:
// if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 }
CMPL AX, $12
JGE inlineEmitCopyStep3
CMPL R11, $2048
JGE inlineEmitCopyStep3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(DI)
SHRL $8, R11
SHLB $5, R11
SUBB $4, AX
SHLB $2, AX
ORB AX, R11
ORB $1, R11
MOVB R11, 0(DI)
ADDQ $2, DI
JMP inlineEmitCopyEnd
inlineEmitCopyStep3:
// Emit the remaining copy, encoded as 3 bytes.
SUBL $1, AX
SHLB $2, AX
ORB $2, AX
MOVB AX, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
inlineEmitCopyEnd:
// End inline of the emitCopy call.
// ----------------------------------------
// nextEmit = s
MOVQ SI, R10
// if s >= sLimit { goto emitRemainder }
MOVQ SI, AX
SUBQ DX, AX
CMPQ AX, R9
JAE emitRemainder
// As per the encode_other.go code:
//
// We could immediately etc.
// x := load64(src, s-1)
MOVQ -1(SI), R14
// prevHash := hash(uint32(x>>0), shift)
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// table[prevHash] = uint16(s-1)
MOVQ SI, AX
SUBQ DX, AX
SUBQ $1, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// currHash := hash(uint32(x>>8), shift)
SHRQ $8, R14
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// candidate = int(table[currHash])
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
BYTE $0x4e
BYTE $0x0f
BYTE $0xb7
BYTE $0x7c
BYTE $0x5c
BYTE $0x78
// table[currHash] = uint16(s)
ADDQ $1, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// if uint32(x>>8) == load32(src, candidate) { continue }
MOVL (DX)(R15*1), BX
CMPL R14, BX
JEQ inner1
// nextHash = hash(uint32(x>>16), shift)
SHRQ $8, R14
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// s++
ADDQ $1, SI
// break out of the inner1 for loop, i.e. continue the outer loop.
JMP outer
emitRemainder:
// if nextEmit < len(src) { etc }
MOVQ src_len+32(FP), AX
ADDQ DX, AX
CMPQ R10, AX
JEQ encodeBlockEnd
// d += emitLiteral(dst[d:], src[nextEmit:])
//
// Push args.
MOVQ DI, 0(SP)
MOVQ $0, 8(SP) // Unnecessary, as the callee ignores it, but conservative.
MOVQ $0, 16(SP) // Unnecessary, as the callee ignores it, but conservative.
MOVQ R10, 24(SP)
SUBQ R10, AX
MOVQ AX, 32(SP)
MOVQ AX, 40(SP) // Unnecessary, as the callee ignores it, but conservative.
// Spill local variables (registers) onto the stack; call; unspill.
MOVQ DI, 80(SP)
CALL ·emitLiteral(SB)
MOVQ 80(SP), DI
// Finish the "d +=" part of "d += emitLiteral(etc)".
ADDQ 48(SP), DI
encodeBlockEnd:
MOVQ dst_base+0(FP), AX
SUBQ AX, DI
MOVQ DI, d+48(FP)
RET

238
lib/snappy/encode_other.go Normal file
View File

@ -0,0 +1,238 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 appengine !gc noasm
package snappy
// load32 returns the little-endian uint32 stored at b[i:i+4].
func load32(b []byte, i int) uint32 {
	v := b[i : i+4 : len(b)] // full slice expression lets the compiler elide bounds checks below
	return uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16 | uint32(v[3])<<24
}
// load64 returns the little-endian uint64 stored at b[i:i+8].
func load64(b []byte, i int) uint64 {
	v := b[i : i+8 : len(b)] // full slice expression lets the compiler elide bounds checks below
	var x uint64
	for k := 7; k >= 0; k-- {
		x = x<<8 | uint64(v[k])
	}
	return x
}
// emitLiteral writes a literal chunk and returns the number of bytes written.
//
// It assumes that:
//	dst is long enough to hold the encoded bytes
//	1 <= len(lit) && len(lit) <= 65536
func emitLiteral(dst, lit []byte) int {
	n := uint(len(lit) - 1)
	var hdr int
	if n < 60 {
		// Short literal: the length fits in the tag byte itself.
		dst[0] = uint8(n)<<2 | tagLiteral
		hdr = 1
	} else if n < 1<<8 {
		// One extra little-endian length byte follows the tag.
		dst[0] = 60<<2 | tagLiteral
		dst[1] = uint8(n)
		hdr = 2
	} else {
		// Two extra little-endian length bytes follow the tag.
		dst[0] = 61<<2 | tagLiteral
		dst[1] = uint8(n)
		dst[2] = uint8(n >> 8)
		hdr = 3
	}
	return hdr + copy(dst[hdr:], lit)
}
// emitCopy writes a copy chunk and returns the number of bytes written.
//
// It assumes that:
//	dst is long enough to hold the encoded bytes
//	1 <= offset && offset <= 65535
//	4 <= length && length <= 65535
func emitCopy(dst []byte, offset, length int) int {
	var i int
	// The maximum length for a single tagCopy1 or tagCopy2 op is 64 bytes.
	// While at least 68 (= 64 + 4) bytes remain, peel off length-64 tagCopy2
	// ops. The threshold is 4 above 64 (and the op below emits 60, 4 below
	// 64) because the minimum length for a tagCopy1 op is 4 bytes: a length
	// 67 copy is shorter as a length 60 tagCopy2 plus a length 7 tagCopy1
	// (3+2 bytes) than as a length 64 tagCopy2 plus a length 3 tagCopy2
	// (3+3 bytes, since a length 3 copy cannot be a tagCopy1).
	for ; length >= 68; length -= 64 {
		dst[i] = 63<<2 | tagCopy2
		dst[i+1] = uint8(offset)
		dst[i+2] = uint8(offset >> 8)
		i += 3
	}
	if length > 64 {
		// Emit a length 60 copy so that the remainder lands in [5, 64].
		dst[i] = 59<<2 | tagCopy2
		dst[i+1] = uint8(offset)
		dst[i+2] = uint8(offset >> 8)
		i += 3
		length -= 60
	}
	if length >= 12 || offset >= 2048 {
		// The remainder needs the 3-byte tagCopy2 form.
		dst[i] = uint8(length-1)<<2 | tagCopy2
		dst[i+1] = uint8(offset)
		dst[i+2] = uint8(offset >> 8)
		return i + 3
	}
	// The remainder fits the compact 2-byte tagCopy1 form: 3 high bits of
	// the offset and (length - 4) share the tag byte.
	dst[i] = uint8(offset>>8)<<5 | uint8(length-4)<<2 | tagCopy1
	dst[i+1] = uint8(offset)
	return i + 2
}
// extendMatch returns the largest k such that k <= len(src) and that
// src[i:i+k-j] and src[j:k] have the same contents.
//
// It assumes that:
//	0 <= i && i < j && j <= len(src)
func extendMatch(src []byte, i, j int) int {
	// Walk both cursors forward while the bytes keep matching; because
	// i < j the match may overlap itself (run-length style).
	for j < len(src) && src[i] == src[j] {
		i++
		j++
	}
	return j
}
// hash folds u into a table index: multiply by a fixed odd constant and
// keep the top (32 - shift) bits of the product.
func hash(u, shift uint32) uint32 {
	const multiplier = 0x1e35a7bd
	return (u * multiplier) >> shift
}
// encodeBlock encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// It also assumes that:
//	len(dst) >= MaxEncodedLen(len(src)) &&
//	minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlock(dst, src []byte) (d int) {
	// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
	// The table element type is uint16, as s < sLimit and sLimit < len(src)
	// and len(src) <= maxBlockSize and maxBlockSize == 65536.
	const (
		maxTableSize = 1 << 14
		// tableMask is redundant, but helps the compiler eliminate bounds
		// checks.
		tableMask = maxTableSize - 1
	)
	shift := uint32(32 - 8)
	for tableSize := 1 << 8; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
		shift--
	}
	// In Go, all array elements are zero-initialized, so there is no advantage
	// to a smaller tableSize per se. However, it matches the C++ algorithm,
	// and in the asm versions of this code, we can get away with zeroing only
	// the first tableSize elements.
	var table [maxTableSize]uint16

	// sLimit is when to stop looking for offset/length copies. The inputMargin
	// lets us use a fast path for emitLiteral in the main loop, while we are
	// looking for copies.
	sLimit := len(src) - inputMargin

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := 0

	// The encoded form must start with a literal, as there are no previous
	// bytes to copy, so we start looking for hash matches at s == 1.
	s := 1
	nextHash := hash(load32(src, s), shift)

	for {
		// Copied from the C++ snappy implementation:
		//
		// Heuristic match skipping: If 32 bytes are scanned with no matches
		// found, start looking only at every other byte. If 32 more bytes are
		// scanned (or skipped), look at every third byte, etc.. When a match
		// is found, immediately go back to looking at every byte. This is a
		// small loss (~5% performance, ~0.1% density) for compressible data
		// due to more bookkeeping, but for non-compressible data (such as
		// JPEG) it's a huge win since the compressor quickly "realizes" the
		// data is incompressible and doesn't bother looking for matches
		// everywhere.
		//
		// The "skip" variable keeps track of how many bytes there are since
		// the last match; dividing it by 32 (ie. right-shifting by five) gives
		// the number of bytes to move ahead for each iteration.
		skip := 32

		nextS := s
		candidate := 0
		for {
			s = nextS
			bytesBetweenHashLookups := skip >> 5
			nextS = s + bytesBetweenHashLookups
			skip += bytesBetweenHashLookups
			if nextS > sLimit {
				goto emitRemainder
			}
			// Probe-and-replace: read the previous position that hashed to
			// this slot as the match candidate, then claim the slot for s.
			candidate = int(table[nextHash&tableMask])
			table[nextHash&tableMask] = uint16(s)
			nextHash = hash(load32(src, nextS), shift)
			if load32(src, s) == load32(src, candidate) {
				break
			}
		}

		// A 4-byte match has been found. We'll later see if more than 4 bytes
		// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
		// them as literal bytes.
		d += emitLiteral(dst[d:], src[nextEmit:s])

		// Call emitCopy, and then see if another emitCopy could be our next
		// move. Repeat until we find no match for the input immediately after
		// what was consumed by the last emitCopy call.
		//
		// If we exit this loop normally then we need to call emitLiteral next,
		// though we don't yet know how big the literal will be. We handle that
		// by proceeding to the next iteration of the main loop. We also can
		// exit this loop via goto if we get close to exhausting the input.
		for {
			// Invariant: we have a 4-byte match at s, and no need to emit any
			// literal bytes prior to s.
			base := s

			// Extend the 4-byte match as long as possible.
			//
			// This is an inlined version of:
			//	s = extendMatch(src, candidate+4, s+4)
			s += 4
			for i := candidate + 4; s < len(src) && src[i] == src[s]; i, s = i+1, s+1 {
			}

			d += emitCopy(dst[d:], base-candidate, s-base)
			nextEmit = s
			if s >= sLimit {
				goto emitRemainder
			}

			// We could immediately start working at s now, but to improve
			// compression we first update the hash table at s-1 and at s. If
			// another emitCopy is not our next move, also calculate nextHash
			// at s+1. At least on GOARCH=amd64, these three hash calculations
			// are faster as one load64 call (with some shifts) instead of
			// three load32 calls.
			x := load64(src, s-1)
			prevHash := hash(uint32(x>>0), shift)
			table[prevHash&tableMask] = uint16(s - 1)
			currHash := hash(uint32(x>>8), shift)
			candidate = int(table[currHash&tableMask])
			table[currHash&tableMask] = uint16(s)
			if uint32(x>>8) != load32(src, candidate) {
				nextHash = hash(uint32(x>>16), shift)
				s++
				break
			}
		}
	}

emitRemainder:
	// Whatever input remains after the last match is emitted as one final
	// literal chunk.
	if nextEmit < len(src) {
		d += emitLiteral(dst[d:], src[nextEmit:])
	}
	return d
}

1965
lib/snappy/golden_test.go Normal file

File diff suppressed because it is too large Load Diff

98
lib/snappy/snappy.go Normal file
View File

@ -0,0 +1,98 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package snappy implements the Snappy compression format. It aims for very
// high speeds and reasonable compression.
//
// There are actually two Snappy formats: block and stream. They are related,
// but different: trying to decompress block-compressed data as a Snappy stream
// will fail, and vice versa. The block format is the Decode and Encode
// functions and the stream format is the Reader and Writer types.
//
// The block format, the more common case, is used when the complete size (the
// number of bytes) of the original data is known upfront, at the time
// compression starts. The stream format, also known as the framing format, is
// for when that isn't always true.
//
// The canonical, C++ implementation is at https://github.com/google/snappy and
// it only implements the block format.
package snappy
import (
"hash/crc32"
)
/*
Each encoded block begins with the varint-encoded length of the decoded data,
followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
first byte of each chunk is broken into its 2 least and 6 most significant bits
called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag.
Zero means a literal tag. All other values mean a copy tag.
For literal tags:
- If m < 60, the next 1 + m bytes are literal bytes.
- Otherwise, let n be the little-endian unsigned integer denoted by the next
m - 59 bytes. The next 1 + n bytes after that are literal bytes.
For copy tags, length bytes are copied from offset bytes ago, in the style of
Lempel-Ziv compression algorithms. In particular:
- For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12).
The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10
of the offset. The next byte is bits 0-7 of the offset.
- For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65).
The length is 1 + m. The offset is the little-endian unsigned integer
denoted by the next 2 bytes.
- For l == 3, this tag is a legacy format that is no longer issued by most
encoders. Nonetheless, the offset ranges in [0, 1<<32) and the length in
[1, 65). The length is 1 + m. The offset is the little-endian unsigned
integer denoted by the next 4 bytes.
*/
const (
tagLiteral = 0x00
tagCopy1 = 0x01
tagCopy2 = 0x02
tagCopy4 = 0x03
)
const (
checksumSize = 4
chunkHeaderSize = 4
magicChunk = "\xff\x06\x00\x00" + magicBody
magicBody = "sNaPpY"
// maxBlockSize is the maximum size of the input to encodeBlock. It is not
// part of the wire format per se, but some parts of the encoder assume
// that an offset fits into a uint16.
//
// Also, for the framing format (Writer type instead of Encode function),
// https://github.com/google/snappy/blob/master/framing_format.txt says
// that "the uncompressed data in a chunk must be no longer than 65536
// bytes".
maxBlockSize = 65536
// maxEncodedLenOfMaxBlockSize equals MaxEncodedLen(maxBlockSize), but is
// hard coded to be a const instead of a variable, so that obufLen can also
// be a const. Their equivalence is confirmed by
// TestMaxEncodedLenOfMaxBlockSize.
maxEncodedLenOfMaxBlockSize = 76490
obufHeaderLen = len(magicChunk) + checksumSize + chunkHeaderSize
obufLen = obufHeaderLen + maxEncodedLenOfMaxBlockSize
)
const (
chunkTypeCompressedData = 0x00
chunkTypeUncompressedData = 0x01
chunkTypePadding = 0xfe
chunkTypeStreamIdentifier = 0xff
)
// crcTable backs the CRC-32C (Castagnoli polynomial) checksum required by
// the snappy framing format.
var crcTable = crc32.MakeTable(crc32.Castagnoli)

// crc implements the masked checksum specified in section 3 of
// https://github.com/google/snappy/blob/master/framing_format.txt:
// the raw CRC-32C is rotated right by 15 bits and offset by a constant.
func crc(b []byte) uint32 {
	sum := crc32.Update(0, crcTable, b)
	return uint32(sum>>15|sum<<17) + 0xa282ead8
}

1353
lib/snappy/snappy_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,10 @@ package server
import (
"errors"
"github.com/cnlh/nps/bridge"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/pool"
"net"
"net/http"
"sync"
@ -13,8 +16,8 @@ import (
type server struct {
id int
bridge *bridge.Bridge
task *lib.Tunnel
config *lib.Config
task *file.Tunnel
config *file.Config
errorContent []byte
sync.Mutex
}
@ -26,7 +29,7 @@ func (s *server) FlowAdd(in, out int64) {
s.task.Flow.InletFlow += in
}
func (s *server) FlowAddHost(host *lib.Host, in, out int64) {
func (s *server) FlowAddHost(host *file.Host, in, out int64) {
s.Lock()
defer s.Unlock()
host.Flow.ExportFlow += out
@ -36,7 +39,7 @@ func (s *server) FlowAddHost(host *lib.Host, in, out int64) {
//热更新配置
func (s *server) ResetConfig() bool {
//获取最新数据
task, err := lib.GetCsvDb().GetTask(s.task.Id)
task, err := file.GetCsvDb().GetTask(s.task.Id)
if err != nil {
return false
}
@ -45,7 +48,7 @@ func (s *server) ResetConfig() bool {
}
s.task.UseClientCnf = task.UseClientCnf
//使用客户端配置
client, err := lib.GetCsvDb().GetClient(s.task.Client.Id)
client, err := file.GetCsvDb().GetClient(s.task.Client.Id)
if s.task.UseClientCnf {
if err == nil {
s.config.U = client.Cnf.U
@ -62,11 +65,11 @@ func (s *server) ResetConfig() bool {
}
}
s.task.Client.Rate = client.Rate
s.config.CompressDecode, s.config.CompressEncode = lib.GetCompressType(s.config.Compress)
s.config.CompressDecode, s.config.CompressEncode = common.GetCompressType(s.config.Compress)
return true
}
func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Conn, flow *lib.Flow) {
func (s *server) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) {
if rb != nil {
if _, err := tunnel.SendMsg(rb, link); err != nil {
c.Close()
@ -74,32 +77,32 @@ func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Co
}
flow.Add(len(rb), 0)
}
buf := pool.BufPoolCopy.Get().([]byte)
for {
buf := lib.BufPoolCopy.Get().([]byte)
if n, err := c.Read(buf); err != nil {
tunnel.SendMsg([]byte(lib.IO_EOF), link)
tunnel.SendMsg([]byte(common.IO_EOF), link)
break
} else {
if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
lib.PutBufPoolCopy(buf)
c.Close()
break
}
lib.PutBufPoolCopy(buf)
flow.Add(n, 0)
}
}
pool.PutBufPoolCopy(buf)
}
func (s *server) writeConnFail(c net.Conn) {
c.Write([]byte(lib.ConnectionFailBytes))
c.Write([]byte(common.ConnectionFailBytes))
c.Write(s.errorContent)
}
//权限认证
func (s *server) auth(r *http.Request, c *lib.Conn, u, p string) error {
if u != "" && p != "" && !lib.CheckAuth(r, u, p) {
c.Write([]byte(lib.UnauthorizedBytes))
func (s *server) auth(r *http.Request, c *conn.Conn, u, p string) error {
if u != "" && p != "" && !common.CheckAuth(r, u, p) {
c.Write([]byte(common.UnauthorizedBytes))
c.Close()
return errors.New("401 Unauthorized")
}

View File

@ -3,9 +3,13 @@ package server
import (
"bufio"
"crypto/tls"
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib/beego"
"github.com/cnlh/nps/bridge"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/lg"
"github.com/cnlh/nps/lib/common"
"log"
"net/http"
"net/http/httputil"
"path/filepath"
@ -22,7 +26,7 @@ type httpServer struct {
stop chan bool
}
func NewHttp(bridge *bridge.Bridge, c *lib.Tunnel) *httpServer {
func NewHttp(bridge *bridge.Bridge, c *file.Tunnel) *httpServer {
httpPort, _ := beego.AppConfig.Int("httpProxyPort")
httpsPort, _ := beego.AppConfig.Int("httpsProxyPort")
pemPath := beego.AppConfig.String("pemPath")
@ -44,33 +48,33 @@ func NewHttp(bridge *bridge.Bridge, c *lib.Tunnel) *httpServer {
func (s *httpServer) Start() error {
var err error
var http, https *http.Server
if s.errorContent, err = lib.ReadAllFromFile(filepath.Join(lib.GetRunPath(), "web", "static", "page", "error.html")); err != nil {
if s.errorContent, err = common.ReadAllFromFile(filepath.Join(common.GetRunPath(), "web", "static", "page", "error.html")); err != nil {
s.errorContent = []byte("easyProxy 404")
}
if s.httpPort > 0 {
http = s.NewServer(s.httpPort)
go func() {
lib.Println("启动http监听,端口为", s.httpPort)
lg.Println("启动http监听,端口为", s.httpPort)
err := http.ListenAndServe()
if err != nil {
lib.Fatalln(err)
lg.Fatalln(err)
}
}()
}
if s.httpsPort > 0 {
if !lib.FileExists(s.pemPath) {
lib.Fatalf("ssl certFile文件%s不存在", s.pemPath)
if !common.FileExists(s.pemPath) {
lg.Fatalf("ssl certFile文件%s不存在", s.pemPath)
}
if !lib.FileExists(s.keyPath) {
lib.Fatalf("ssl keyFile文件%s不存在", s.keyPath)
if !common.FileExists(s.keyPath) {
lg.Fatalf("ssl keyFile文件%s不存在", s.keyPath)
}
https = s.NewServer(s.httpsPort)
go func() {
lib.Println("启动https监听,端口为", s.httpsPort)
lg.Println("启动https监听,端口为", s.httpsPort)
err := https.ListenAndServeTLS(s.pemPath, s.keyPath)
if err != nil {
lib.Fatalln(err)
lg.Fatalln(err)
}
}()
}
@ -96,40 +100,41 @@ func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
return
}
conn, _, err := hijacker.Hijack()
c, _, err := hijacker.Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
}
s.process(lib.NewConn(conn), r)
s.process(conn.NewConn(c), r)
}
func (s *httpServer) process(c *lib.Conn, r *http.Request) {
func (s *httpServer) process(c *conn.Conn, r *http.Request) {
//多客户端域名代理
var (
isConn = true
link *lib.Link
host *lib.Host
tunnel *lib.Conn
lk *conn.Link
host *file.Host
tunnel *conn.Conn
err error
)
for {
//首次获取conn
if isConn {
if host, err = GetInfoByHost(r.Host); err != nil {
lib.Printf("the host %s is not found !", r.Host)
lg.Printf("the host %s is not found !", r.Host)
break
}
//流量限制
if host.Client.Flow.FlowLimit > 0 && (host.Client.Flow.FlowLimit<<20) < (host.Client.Flow.ExportFlow+host.Client.Flow.InletFlow) {
break
}
host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = lib.GetCompressType(host.Client.Cnf.Compress)
host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = common.GetCompressType(host.Client.Cnf.Compress)
//权限控制
if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
break
}
link = lib.NewLink(host.Client.GetId(), lib.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.CompressEncode, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, c, host.Flow, nil, host.Client.Rate, nil)
if tunnel, err = s.bridge.SendLinkInfo(host.Client.Id, link); err != nil {
lk = conn.NewLink(host.Client.GetId(), common.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.CompressEncode, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, c, host.Flow, nil, host.Client.Rate, nil)
if tunnel, err = s.bridge.SendLinkInfo(host.Client.Id, lk); err != nil {
log.Println(err)
break
}
isConn = false
@ -140,13 +145,13 @@ func (s *httpServer) process(c *lib.Conn, r *http.Request) {
}
}
//根据设定修改header和host
lib.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String())
common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String())
b, err := httputil.DumpRequest(r, true)
if err != nil {
break
}
host.Flow.Add(len(b), 0)
if _, err := tunnel.SendMsg(b, link); err != nil {
if _, err := tunnel.SendMsg(b, lk); err != nil {
c.Close()
break
}
@ -155,7 +160,7 @@ func (s *httpServer) process(c *lib.Conn, r *http.Request) {
if isConn {
s.writeConnFail(c.Conn)
} else {
tunnel.SendMsg([]byte(lib.IO_EOF), link)
tunnel.SendMsg([]byte(common.IO_EOF), lk)
}
c.Close()

View File

@ -3,7 +3,8 @@ package server
import (
"errors"
"github.com/cnlh/nps/bridge"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/lg"
"reflect"
"strings"
)
@ -11,44 +12,41 @@ import (
var (
Bridge *bridge.Bridge
RunList map[int]interface{} //运行中的任务
startFinish chan bool
)
func init() {
RunList = make(map[int]interface{})
startFinish = make(chan bool)
}
//从csv文件中恢复任务
func InitFromCsv() {
for _, v := range lib.GetCsvDb().Tasks {
for _, v := range file.GetCsvDb().Tasks {
if v.Status {
lib.Println("启动模式:", v.Mode, "监听端口:", v.TcpPort)
lg.Println("启动模式:", v.Mode, "监听端口:", v.TcpPort)
AddTask(v)
}
}
}
//start a new server
func StartNewServer(bridgePort int, cnf *lib.Tunnel) {
Bridge = bridge.NewTunnel(bridgePort, RunList)
func StartNewServer(bridgePort int, cnf *file.Tunnel, bridgeType string) {
Bridge = bridge.NewTunnel(bridgePort, RunList, bridgeType)
if err := Bridge.StartTunnel(); err != nil {
lib.Fatalln("服务端开启失败", err)
lg.Fatalln("服务端开启失败", err)
}
if svr := NewMode(Bridge, cnf); svr != nil {
RunList[cnf.Id] = svr
err := reflect.ValueOf(svr).MethodByName("Start").Call(nil)[0]
if err.Interface() != nil {
lib.Fatalln(err)
lg.Fatalln(err)
}
} else {
lib.Fatalln("启动模式不正确")
lg.Fatalln("启动模式不正确")
}
}
//new a server by mode name
func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} {
func NewMode(Bridge *bridge.Bridge, c *file.Tunnel) interface{} {
switch c.Mode {
case "tunnelServer":
return NewTunnelModeServer(ProcessTunnel, Bridge, c)
@ -60,17 +58,15 @@ func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} {
return NewUdpModeServer(Bridge, c)
case "webServer":
InitFromCsv()
t := &lib.Tunnel{
t := &file.Tunnel{
TcpPort: 0,
Mode: "httpHostServer",
Target: "",
Config: &lib.Config{},
Config: &file.Config{},
Status: true,
}
AddTask(t)
return NewWebServer(Bridge)
case "hostServer":
return NewHostServer(c)
case "httpHostServer":
return NewHttp(Bridge, c)
}
@ -81,11 +77,11 @@ func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} {
func StopServer(id int) error {
if v, ok := RunList[id]; ok {
reflect.ValueOf(v).MethodByName("Close").Call(nil)
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
if t, err := file.GetCsvDb().GetTask(id); err != nil {
return err
} else {
t.Status = false
lib.GetCsvDb().UpdateTask(t)
file.GetCsvDb().UpdateTask(t)
}
return nil
}
@ -93,13 +89,13 @@ func StopServer(id int) error {
}
//add task
func AddTask(t *lib.Tunnel) error {
func AddTask(t *file.Tunnel) error {
if svr := NewMode(Bridge, t); svr != nil {
RunList[t.Id] = svr
go func() {
err := reflect.ValueOf(svr).MethodByName("Start").Call(nil)[0]
if err.Interface() != nil {
lib.Fatalln("服务端", t.Id, "启动失败,错误:", err)
lg.Fatalln("客户端", t.Id, "启动失败,错误:", err)
delete(RunList, t.Id)
}
}()
@ -111,12 +107,12 @@ func AddTask(t *lib.Tunnel) error {
//start task
func StartTask(id int) error {
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
if t, err := file.GetCsvDb().GetTask(id); err != nil {
return err
} else {
AddTask(t)
t.Status = true
lib.GetCsvDb().UpdateTask(t)
file.GetCsvDb().UpdateTask(t)
}
return nil
}
@ -126,12 +122,12 @@ func DelTask(id int) error {
if err := StopServer(id); err != nil {
return err
}
return lib.GetCsvDb().DelTask(id)
return file.GetCsvDb().DelTask(id)
}
//get key by host from x
func GetInfoByHost(host string) (h *lib.Host, err error) {
for _, v := range lib.GetCsvDb().Hosts {
func GetInfoByHost(host string) (h *file.Host, err error) {
for _, v := range file.GetCsvDb().Hosts {
s := strings.Split(host, ":")
if s[0] == v.Host {
h = v
@ -143,10 +139,10 @@ func GetInfoByHost(host string) (h *lib.Host, err error) {
}
//get task list by page num
func GetTunnel(start, length int, typeVal string, clientId int) ([]*lib.Tunnel, int) {
list := make([]*lib.Tunnel, 0)
func GetTunnel(start, length int, typeVal string, clientId int) ([]*file.Tunnel, int) {
list := make([]*file.Tunnel, 0)
var cnt int
for _, v := range lib.GetCsvDb().Tasks {
for _, v := range file.GetCsvDb().Tasks {
if (typeVal != "" && v.Mode != typeVal) || (typeVal == "" && clientId != v.Client.Id) {
continue
}
@ -171,13 +167,13 @@ func GetTunnel(start, length int, typeVal string, clientId int) ([]*lib.Tunnel,
}
//获取客户端列表
func GetClientList(start, length int) (list []*lib.Client, cnt int) {
list, cnt = lib.GetCsvDb().GetClientList(start, length)
func GetClientList(start, length int) (list []*file.Client, cnt int) {
list, cnt = file.GetCsvDb().GetClientList(start, length)
dealClientData(list)
return
}
func dealClientData(list []*lib.Client) {
func dealClientData(list []*file.Client) {
for _, v := range list {
if _, ok := Bridge.Client[v.Id]; ok {
v.IsConnect = true
@ -186,13 +182,13 @@ func dealClientData(list []*lib.Client) {
}
v.Flow.InletFlow = 0
v.Flow.ExportFlow = 0
for _, h := range lib.GetCsvDb().Hosts {
for _, h := range file.GetCsvDb().Hosts {
if h.Client.Id == v.Id {
v.Flow.InletFlow += h.Flow.InletFlow
v.Flow.ExportFlow += h.Flow.ExportFlow
}
}
for _, t := range lib.GetCsvDb().Tasks {
for _, t := range file.GetCsvDb().Tasks {
if t.Client.Id == v.Id {
v.Flow.InletFlow += t.Flow.InletFlow
v.Flow.ExportFlow += t.Flow.ExportFlow
@ -204,14 +200,14 @@ func dealClientData(list []*lib.Client) {
//根据客户端id删除其所属的所有隧道和域名
func DelTunnelAndHostByClientId(clientId int) {
for _, v := range lib.GetCsvDb().Tasks {
for _, v := range file.GetCsvDb().Tasks {
if v.Client.Id == clientId {
DelTask(v.Id)
}
}
for _, v := range lib.GetCsvDb().Hosts {
for _, v := range file.GetCsvDb().Hosts {
if v.Client.Id == clientId {
lib.GetCsvDb().DelHost(v.Host)
file.GetCsvDb().DelHost(v.Host)
}
}
}
@ -223,9 +219,9 @@ func DelClientConnect(clientId int) {
func GetDashboardData() map[string]int {
data := make(map[string]int)
data["hostCount"] = len(lib.GetCsvDb().Hosts)
data["clientCount"] = len(lib.GetCsvDb().Clients)
list := lib.GetCsvDb().Clients
data["hostCount"] = len(file.GetCsvDb().Hosts)
data["clientCount"] = len(file.GetCsvDb().Clients)
list := file.GetCsvDb().Clients
dealClientData(list)
c := 0
var in, out int64
@ -239,7 +235,7 @@ func GetDashboardData() map[string]int {
data["clientOnlineCount"] = c
data["inletFlowCount"] = int(in)
data["exportFlowCount"] = int(out)
for _, v := range lib.GetCsvDb().Tasks {
for _, v := range file.GetCsvDb().Tasks {
switch v.Mode {
case "tunnelServer":
data["tunnelServerCount"] += 1

View File

@ -4,7 +4,10 @@ import (
"encoding/binary"
"errors"
"github.com/cnlh/nps/bridge"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/lg"
"io"
"net"
"strconv"
@ -65,7 +68,7 @@ func (s *Sock5ModeServer) handleRequest(c net.Conn) {
_, err := io.ReadFull(c, header)
if err != nil {
lib.Println("illegal request", err)
lg.Println("illegal request", err)
c.Close()
return
}
@ -135,18 +138,18 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) {
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
var ltype string
if command == associateMethod {
ltype = lib.CONN_UDP
ltype = common.CONN_UDP
} else {
ltype = lib.CONN_TCP
ltype = common.CONN_TCP
}
link := lib.NewLink(s.task.Client.GetId(), ltype, addr, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, lib.NewConn(c), s.task.Flow, nil, s.task.Client.Rate, nil)
link := conn.NewLink(s.task.Client.GetId(), ltype, addr, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, conn.NewConn(c), s.task.Flow, nil, s.task.Client.Rate, nil)
if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil {
c.Close()
return
} else {
s.sendReply(c, succeeded)
s.linkCopy(link, lib.NewConn(c), nil, tunnel, s.task.Flow)
s.linkCopy(link, conn.NewConn(c), nil, tunnel, s.task.Flow)
}
return
}
@ -162,7 +165,7 @@ func (s *Sock5ModeServer) handleBind(c net.Conn) {
//udp
func (s *Sock5ModeServer) handleUDP(c net.Conn) {
lib.Println("UDP Associate")
lg.Println("UDP Associate")
/*
+----+------+------+----------+----------+----------+
|RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
@ -175,7 +178,7 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) {
// relay udp datagram silently, without any notification to the requesting client
if buf[2] != 0 {
// does not support fragmentation, drop it
lib.Println("does not support fragmentation, drop")
lg.Println("does not support fragmentation, drop")
dummy := make([]byte, maxUDPPacketSize)
c.Read(dummy)
}
@ -187,13 +190,13 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) {
func (s *Sock5ModeServer) handleConn(c net.Conn) {
buf := make([]byte, 2)
if _, err := io.ReadFull(c, buf); err != nil {
lib.Println("negotiation err", err)
lg.Println("negotiation err", err)
c.Close()
return
}
if version := buf[0]; version != 5 {
lib.Println("only support socks5, request from: ", c.RemoteAddr())
lg.Println("only support socks5, request from: ", c.RemoteAddr())
c.Close()
return
}
@ -201,7 +204,7 @@ func (s *Sock5ModeServer) handleConn(c net.Conn) {
methods := make([]byte, nMethods)
if len, err := c.Read(methods); len != int(nMethods) || err != nil {
lib.Println("wrong method")
lg.Println("wrong method")
c.Close()
return
}
@ -210,7 +213,7 @@ func (s *Sock5ModeServer) handleConn(c net.Conn) {
c.Write(buf)
if err := s.Auth(c); err != nil {
c.Close()
lib.Println("验证失败:", err)
lg.Println("验证失败:", err)
return
}
} else {
@ -269,7 +272,7 @@ func (s *Sock5ModeServer) Start() error {
if strings.Contains(err.Error(), "use of closed network connection") {
break
}
lib.Fatalln("accept error: ", err)
lg.Fatalln("accept error: ", err)
}
if !s.ResetConfig() {
conn.Close()
@ -286,11 +289,11 @@ func (s *Sock5ModeServer) Close() error {
}
//new
func NewSock5ModeServer(bridge *bridge.Bridge, task *lib.Tunnel) *Sock5ModeServer {
func NewSock5ModeServer(bridge *bridge.Bridge, task *file.Tunnel) *Sock5ModeServer {
s := new(Sock5ModeServer)
s.bridge = bridge
s.task = task
s.config = lib.DeepCopyConfig(task.Config)
s.config = file.DeepCopyConfig(task.Config)
if s.config.U != "" && s.config.P != "" {
s.isVerify = true
} else {

View File

@ -2,9 +2,12 @@ package server
import (
"errors"
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib/beego"
"github.com/cnlh/nps/bridge"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/lg"
"net"
"path/filepath"
"strings"
@ -17,12 +20,12 @@ type TunnelModeServer struct {
}
//tcp|http|host
func NewTunnelModeServer(process process, bridge *bridge.Bridge, task *lib.Tunnel) *TunnelModeServer {
func NewTunnelModeServer(process process, bridge *bridge.Bridge, task *file.Tunnel) *TunnelModeServer {
s := new(TunnelModeServer)
s.bridge = bridge
s.process = process
s.task = task
s.config = lib.DeepCopyConfig(task.Config)
s.config = file.DeepCopyConfig(task.Config)
return s
}
@ -34,22 +37,22 @@ func (s *TunnelModeServer) Start() error {
return err
}
for {
conn, err := s.listener.AcceptTCP()
c, err := s.listener.AcceptTCP()
if err != nil {
if strings.Contains(err.Error(), "use of closed network connection") {
break
}
lib.Println(err)
lg.Println(err)
continue
}
go s.process(lib.NewConn(conn), s)
go s.process(conn.NewConn(c), s)
}
return nil
}
//与客户端建立通道
func (s *TunnelModeServer) dealClient(c *lib.Conn, cnf *lib.Config, addr string, method string, rb []byte) error {
link := lib.NewLink(s.task.Client.GetId(), lib.CONN_TCP, addr, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil)
func (s *TunnelModeServer) dealClient(c *conn.Conn, cnf *file.Config, addr string, method string, rb []byte) error {
link := conn.NewLink(s.task.Client.GetId(), common.CONN_TCP, addr, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil)
if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil {
c.Close()
@ -73,13 +76,13 @@ type WebServer struct {
//开始
func (s *WebServer) Start() error {
p, _ := beego.AppConfig.Int("httpport")
if !lib.TestTcpPort(p) {
lib.Fatalln("web管理端口", p, "被占用!")
if !common.TestTcpPort(p) {
lg.Fatalln("web管理端口", p, "被占用!")
}
beego.BConfig.WebConfig.Session.SessionOn = true
lib.Println("web管理启动访问端口为", beego.AppConfig.String("httpport"))
beego.SetStaticPath("/static", filepath.Join(lib.GetRunPath(), "web", "static"))
beego.SetViewsPath(filepath.Join(lib.GetRunPath(), "web", "views"))
lg.Println("web管理启动访问端口为", p)
beego.SetStaticPath("/static", filepath.Join(common.GetRunPath(), "web", "static"))
beego.SetViewsPath(filepath.Join(common.GetRunPath(), "web", "views"))
beego.Run()
return errors.New("web管理启动失败")
}
@ -91,32 +94,10 @@ func NewWebServer(bridge *bridge.Bridge) *WebServer {
return s
}
//host
type HostServer struct {
server
}
//开始
func (s *HostServer) Start() error {
return nil
}
func NewHostServer(task *lib.Tunnel) *HostServer {
s := new(HostServer)
s.task = task
s.config = lib.DeepCopyConfig(task.Config)
return s
}
//close
func (s *HostServer) Close() error {
return nil
}
type process func(c *lib.Conn, s *TunnelModeServer) error
type process func(c *conn.Conn, s *TunnelModeServer) error
//tcp隧道模式
func ProcessTunnel(c *lib.Conn, s *TunnelModeServer) error {
func ProcessTunnel(c *conn.Conn, s *TunnelModeServer) error {
if !s.ResetConfig() {
c.Close()
return errors.New("流量超出")
@ -125,7 +106,7 @@ func ProcessTunnel(c *lib.Conn, s *TunnelModeServer) error {
}
//http代理模式
func ProcessHttp(c *lib.Conn, s *TunnelModeServer) error {
func ProcessHttp(c *conn.Conn, s *TunnelModeServer) error {
if !s.ResetConfig() {
c.Close()
return errors.New("流量超出")

View File

@ -1,54 +1,78 @@
package server
import (
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/beego"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/file"
"log"
"strconv"
)
func TestServerConfig() {
var postArr []int
for _, v := range lib.GetCsvDb().Tasks {
isInArr(&postArr, v.TcpPort, v.Remark)
var postTcpArr []int
var postUdpArr []int
for _, v := range file.GetCsvDb().Tasks {
if v.Mode == "udpServer" {
isInArr(&postUdpArr, v.TcpPort, v.Remark, "udp")
} else {
isInArr(&postTcpArr, v.TcpPort, v.Remark, "tcp")
}
}
p, err := beego.AppConfig.Int("httpport")
if err != nil {
log.Fatalln("Getting web management port error :", err)
} else {
isInArr(&postArr, p, "WebmManagement port")
isInArr(&postTcpArr, p, "Web Management port", "tcp")
}
if p := beego.AppConfig.String("bridgePort"); p != "" {
if port, err := strconv.Atoi(p); err != nil {
log.Fatalln("get Server and client communication portserror:", err)
} else if beego.AppConfig.String("bridgeType") == "kcp" {
isInArr(&postUdpArr, port, "Server and client communication ports", "udp")
} else {
isInArr(&postTcpArr, port, "Server and client communication ports", "tcp")
}
}
if p := beego.AppConfig.String("httpProxyPort"); p != "" {
if port, err := strconv.Atoi(p); err != nil {
log.Fatalln("get http port error:", err)
} else {
isInArr(&postArr, port, "https port")
isInArr(&postTcpArr, port, "https port", "tcp")
}
}
if p := beego.AppConfig.String("httpsProxyPort"); p != "" {
if port, err := strconv.Atoi(p); err != nil {
log.Fatalln("get https port error", err)
} else {
if !lib.FileExists(beego.AppConfig.String("pemPath")) {
if !common.FileExists(beego.AppConfig.String("pemPath")) {
log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath"))
}
if !lib.FileExists(beego.AppConfig.String("ketPath")) {
if !common.FileExists(beego.AppConfig.String("ketPath")) {
log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath"))
}
isInArr(&postArr, port, "http port")
isInArr(&postTcpArr, port, "http port", "tcp")
}
}
}
func isInArr(arr *[]int, val int, remark string) {
func isInArr(arr *[]int, val int, remark string, tp string) {
for _, v := range *arr {
if v == val {
log.Fatalf("the port %d is reused,remark: %s", val, remark)
}
}
if !lib.TestTcpPort(val) {
log.Fatalf("open the %d port error ,remark: %s", val, remark)
if tp == "tcp" {
if !common.TestTcpPort(val) {
log.Fatalf("open the %d port error ,remark: %s", val, remark)
}
} else {
if !common.TestUdpPort(val) {
log.Fatalf("open the %d port error ,remark: %s", val, remark)
}
}
*arr = append(*arr, val)
return
}

View File

@ -2,7 +2,10 @@ package server
import (
"github.com/cnlh/nps/bridge"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/pool"
"net"
"strings"
)
@ -10,15 +13,15 @@ import (
type UdpModeServer struct {
server
listener *net.UDPConn
udpMap map[string]*lib.Conn
udpMap map[string]*conn.Conn
}
func NewUdpModeServer(bridge *bridge.Bridge, task *lib.Tunnel) *UdpModeServer {
func NewUdpModeServer(bridge *bridge.Bridge, task *file.Tunnel) *UdpModeServer {
s := new(UdpModeServer)
s.bridge = bridge
s.udpMap = make(map[string]*lib.Conn)
s.udpMap = make(map[string]*conn.Conn)
s.task = task
s.config = lib.DeepCopyConfig(task.Config)
s.config = file.DeepCopyConfig(task.Config)
return s
}
@ -29,7 +32,7 @@ func (s *UdpModeServer) Start() error {
if err != nil {
return err
}
buf := lib.BufPoolUdp.Get().([]byte)
buf := pool.BufPoolUdp.Get().([]byte)
for {
n, addr, err := s.listener.ReadFromUDP(buf)
if err != nil {
@ -47,13 +50,14 @@ func (s *UdpModeServer) Start() error {
}
// process forwards one received UDP datagram to the client through the
// bridge, accounting its size in the task flow. Errors from SendLinkInfo
// drop the datagram silently.
// (Fix: removed the stale lib.NewLink line interleaved with the new
// conn.NewLink call.)
func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) {
	link := conn.NewLink(s.task.Client.GetId(), common.CONN_UDP, s.task.Target, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, nil, s.task.Flow, s.listener, s.task.Client.Rate, addr)
	if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil {
		return
	} else {
		s.task.Flow.Add(len(data), 0)
		tunnel.SendMsg(data, link)
		// NOTE(review): data appears to be (a slice of) the caller's pooled
		// read buffer (see Start); returning it to the pool here while the
		// caller keeps reusing it looks unsafe — confirm buffer ownership.
		pool.PutBufPoolUdp(data)
	}
}

View File

@ -1,8 +1,8 @@
package controllers
import (
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/beego"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/server"
"strconv"
"strings"
@ -40,7 +40,7 @@ func (s *BaseController) display(tpl ...string) {
}
ip := s.Ctx.Request.Host
if strings.LastIndex(ip, ":") > 0 {
arr := strings.Split(lib.GetHostByName(ip), ":")
arr := strings.Split(common.GetHostByName(ip), ":")
s.Data["ip"] = arr[0]
}
s.Data["p"] = server.Bridge.TunnelPort

View File

@ -1,7 +1,9 @@
package controllers
import (
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/crypt"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/rate"
"github.com/cnlh/nps/server"
)
@ -28,29 +30,29 @@ func (s *ClientController) Add() {
s.SetInfo("新增")
s.display()
} else {
t := &lib.Client{
VerifyKey: lib.GetRandomString(16),
Id: lib.GetCsvDb().GetClientId(),
t := &file.Client{
VerifyKey: crypt.GetRandomString(16),
Id: file.GetCsvDb().GetClientId(),
Status: true,
Remark: s.GetString("remark"),
Cnf: &lib.Config{
Cnf: &file.Config{
U: s.GetString("u"),
P: s.GetString("p"),
Compress: s.GetString("compress"),
Crypt: s.GetBoolNoErr("crypt"),
},
RateLimit: s.GetIntNoErr("rate_limit"),
Flow: &lib.Flow{
Flow: &file.Flow{
ExportFlow: 0,
InletFlow: 0,
FlowLimit: int64(s.GetIntNoErr("flow_limit")),
},
}
if t.RateLimit > 0 {
t.Rate = lib.NewRate(int64(t.RateLimit * 1024))
t.Rate = rate.NewRate(int64(t.RateLimit * 1024))
t.Rate.Start()
}
lib.GetCsvDb().NewClient(t)
file.GetCsvDb().NewClient(t)
s.AjaxOk("添加成功")
}
}
@ -58,7 +60,7 @@ func (s *ClientController) GetClient() {
if s.Ctx.Request.Method == "POST" {
id := s.GetIntNoErr("id")
data := make(map[string]interface{})
if c, err := lib.GetCsvDb().GetClient(id); err != nil {
if c, err := file.GetCsvDb().GetClient(id); err != nil {
data["code"] = 0
} else {
data["code"] = 1
@ -74,7 +76,7 @@ func (s *ClientController) Edit() {
id := s.GetIntNoErr("id")
if s.Ctx.Request.Method == "GET" {
s.Data["menu"] = "client"
if c, err := lib.GetCsvDb().GetClient(id); err != nil {
if c, err := file.GetCsvDb().GetClient(id); err != nil {
s.error()
} else {
s.Data["c"] = c
@ -82,7 +84,7 @@ func (s *ClientController) Edit() {
s.SetInfo("修改")
s.display()
} else {
if c, err := lib.GetCsvDb().GetClient(id); err != nil {
if c, err := file.GetCsvDb().GetClient(id); err != nil {
s.error()
} else {
c.Remark = s.GetString("remark")
@ -96,12 +98,12 @@ func (s *ClientController) Edit() {
c.Rate.Stop()
}
if c.RateLimit > 0 {
c.Rate = lib.NewRate(int64(c.RateLimit * 1024))
c.Rate = rate.NewRate(int64(c.RateLimit * 1024))
c.Rate.Start()
} else {
c.Rate = nil
}
lib.GetCsvDb().UpdateClient(c)
file.GetCsvDb().UpdateClient(c)
}
s.AjaxOk("修改成功")
}
@ -110,7 +112,7 @@ func (s *ClientController) Edit() {
//更改状态
func (s *ClientController) ChangeStatus() {
id := s.GetIntNoErr("id")
if client, err := lib.GetCsvDb().GetClient(id); err == nil {
if client, err := file.GetCsvDb().GetClient(id); err == nil {
client.Status = s.GetBoolNoErr("status")
if client.Status == false {
server.DelClientConnect(client.Id)
@ -123,7 +125,7 @@ func (s *ClientController) ChangeStatus() {
//删除客户端
func (s *ClientController) Del() {
id := s.GetIntNoErr("id")
if err := lib.GetCsvDb().DelClient(id); err != nil {
if err := file.GetCsvDb().DelClient(id); err != nil {
s.AjaxErr("删除失败")
}
server.DelTunnelAndHostByClientId(id)

View File

@ -1,7 +1,7 @@
package controllers
import (
"github.com/cnlh/nps/lib"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/server"
)
@ -72,27 +72,27 @@ func (s *IndexController) Add() {
s.SetInfo("新增")
s.display()
} else {
t := &lib.Tunnel{
t := &file.Tunnel{
TcpPort: s.GetIntNoErr("port"),
Mode: s.GetString("type"),
Target: s.GetString("target"),
Config: &lib.Config{
Config: &file.Config{
U: s.GetString("u"),
P: s.GetString("p"),
Compress: s.GetString("compress"),
Crypt: s.GetBoolNoErr("crypt"),
},
Id: lib.GetCsvDb().GetTaskId(),
Id: file.GetCsvDb().GetTaskId(),
UseClientCnf: s.GetBoolNoErr("use_client"),
Status: true,
Remark: s.GetString("remark"),
Flow: &lib.Flow{},
Flow: &file.Flow{},
}
var err error
if t.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
if t.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
s.AjaxErr(err.Error())
}
lib.GetCsvDb().NewTask(t)
file.GetCsvDb().NewTask(t)
if err := server.AddTask(t); err != nil {
s.AjaxErr(err.Error())
} else {
@ -103,7 +103,7 @@ func (s *IndexController) Add() {
func (s *IndexController) GetOneTunnel() {
id := s.GetIntNoErr("id")
data := make(map[string]interface{})
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
if t, err := file.GetCsvDb().GetTask(id); err != nil {
data["code"] = 0
} else {
data["code"] = 1
@ -115,7 +115,7 @@ func (s *IndexController) GetOneTunnel() {
func (s *IndexController) Edit() {
id := s.GetIntNoErr("id")
if s.Ctx.Request.Method == "GET" {
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
if t, err := file.GetCsvDb().GetTask(id); err != nil {
s.error()
} else {
s.Data["t"] = t
@ -123,7 +123,7 @@ func (s *IndexController) Edit() {
s.SetInfo("修改")
s.display()
} else {
if t, err := lib.GetCsvDb().GetTask(id); err != nil {
if t, err := file.GetCsvDb().GetTask(id); err != nil {
s.error()
} else {
t.TcpPort = s.GetIntNoErr("port")
@ -137,10 +137,10 @@ func (s *IndexController) Edit() {
t.Config.Crypt = s.GetBoolNoErr("crypt")
t.UseClientCnf = s.GetBoolNoErr("use_client")
t.Remark = s.GetString("remark")
if t.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
if t.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
s.AjaxErr("修改失败")
}
lib.GetCsvDb().UpdateTask(t)
file.GetCsvDb().UpdateTask(t)
}
s.AjaxOk("修改成功")
}
@ -179,7 +179,7 @@ func (s *IndexController) HostList() {
} else {
start, length := s.GetAjaxParams()
clientId := s.GetIntNoErr("client_id")
list, cnt := lib.GetCsvDb().GetHost(start, length, clientId)
list, cnt := file.GetCsvDb().GetHost(start, length, clientId)
s.AjaxTable(list, cnt, cnt)
}
}
@ -200,7 +200,7 @@ func (s *IndexController) GetHost() {
func (s *IndexController) DelHost() {
host := s.GetString("host")
if err := lib.GetCsvDb().DelHost(host); err != nil {
if err := file.GetCsvDb().DelHost(host); err != nil {
s.AjaxErr("删除失败")
}
s.AjaxOk("删除成功")
@ -213,19 +213,19 @@ func (s *IndexController) AddHost() {
s.SetInfo("新增")
s.display("index/hadd")
} else {
h := &lib.Host{
h := &file.Host{
Host: s.GetString("host"),
Target: s.GetString("target"),
HeaderChange: s.GetString("header"),
HostChange: s.GetString("hostchange"),
Remark: s.GetString("remark"),
Flow: &lib.Flow{},
Flow: &file.Flow{},
}
var err error
if h.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
if h.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
s.AjaxErr("添加失败")
}
lib.GetCsvDb().NewHost(h)
file.GetCsvDb().NewHost(h)
s.AjaxOk("添加成功")
}
}
@ -251,9 +251,9 @@ func (s *IndexController) EditHost() {
h.HostChange = s.GetString("hostchange")
h.Remark = s.GetString("remark")
h.TargetArr = nil
lib.GetCsvDb().UpdateHost(h)
file.GetCsvDb().UpdateHost(h)
var err error
if h.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
if h.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil {
s.AjaxErr("修改失败")
}
}

View File

@ -1,7 +1,7 @@
package controllers
import (
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib/beego"
)
type LoginController struct {

View File

@ -1,7 +1,7 @@
package routers
import (
"github.com/astaxie/beego"
"github.com/cnlh/nps/lib/beego"
"github.com/cnlh/nps/web/controllers"
)