From 97330bfbdc936a377791f72a11cc5264077c0fe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=B2=B3?= Date: Fri, 15 Mar 2019 14:03:49 +0800 Subject: [PATCH] MUX optimization --- README.md | 143 +++++++++++++++++++++++++++---------- bridge/bridge.go | 67 ++++++++++++++++- client/client.go | 5 +- client/control.go | 2 +- client/health.go | 111 +++++++++++++--------------- cmd/npc/npc.go | 8 ++- cmd/nps/nps.go | 4 +- conf/clients.csv | 4 +- conf/hosts.csv | 5 +- conf/npc.conf | 32 +++++---- conf/nps.conf | 8 +-- conf/tasks.csv | 5 +- lib/common/util.go | 38 +++++++++- lib/config/config.go | 52 ++++++++++---- lib/conn/conn.go | 35 ++++++++- lib/file/file.go | 55 +++++++++++++- lib/file/obj.go | 34 ++++++--- lib/mux/conn.go | 114 ++++++++++++++++------------- lib/mux/map.go | 6 ++ lib/mux/mux.go | 141 ++++++++++++++++-------------------- lib/mux/mux_test.go | 26 +++++-- lib/mux/pmux.go | 7 +- lib/mux/pmux_test.go | 10 +-- lib/mux/queue.go | 82 +++++++++++++++++++++ lib/pool/pool.go | 10 ++- server/proxy/base.go | 5 +- server/proxy/http.go | 15 +++- server/proxy/socks5.go | 2 +- server/proxy/tcp.go | 10 ++- server/proxy/udp.go | 1 + server/server.go | 21 +++++- web/controllers/base.go | 5 +- web/views/index/index.html | 14 ++-- 33 files changed, 749 insertions(+), 328 deletions(-) create mode 100644 lib/mux/queue.go diff --git a/README.md b/README.md index b0e27fb..4d676b4 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ nps是一款轻量级、高性能、功能强大的**内网穿透**代理服务 * [编译安装](#源码安装) * [release安装](#release安装) * [使用示例(以web主控模式为主)](#使用示例) - * [统一准备工作](#统一准备工作) + * [统一准备工作](#统一准备工作(必做)) * [http|https域名解析](#域名解析) * [内网ssh连接即tcp隧道](#tcp隧道) * [内网dns解析即udp隧道](#udp隧道) @@ -89,6 +89,10 @@ nps是一款轻量级、高性能、功能强大的**内网穿透**代理服务 * [URL路由](#URL路由) * [限制ip访问](#限制ip访问) * [客户端最大连接数限制](#客户端最大连接数) + * [端口复用](#端口复用) + * [环境变量渲染](#环境变量渲染) + * [健康检查](#健康检查) + * [相关说明](#相关说明) * [流量统计](#流量统计) * [热更新支持](#热更新支持) @@ -124,7 +128,7 @@ nps是一款轻量级、高性能、功能强大的**内网穿透**代理服务 ## 使用示例 ### 统一准备工作(必做) -- 开启服务端,假设公网服务器ip为1.1.1.1,配置文件中`bridgePort`为8284,配置文件中`httpport`为8080 +- 开启服务端,假设公网服务器ip为1.1.1.1,配置文件中`bridgePort`为8284,配置文件中`web_port`为8080 - 访问1.1.1.1:8080 - 在客户端管理中创建一个客户端,记录下验证密钥 - 内网客户端运行(windows使用cmd运行加.exe) @@ -230,7 +234,7 @@ port=1000 想通过访问机器1的2001端口---->访问到内网2机器的22端口 **使用步骤** -- 在`nps.conf`中设置`serverIp`和`p2pPort` +- 在`nps.conf`中设置`p2p_ip`和`p2p_port` - 在刚才刚才创建的客户端中添加一条p2p代理,并设置唯一密钥p2pssh - 在需要连接的机器上(即机器1)以配置文件模式启动客户端,内容如下 @@ -291,23 +295,23 @@ port=2001 名称 | 含义 ---|--- -httpport | web管理端口 -password | web界面管理密码 -username | web界面管理账号 -bridgePort | 服务端客户端通信端口 -pemPath | ssl certFile绝对路径 -keyPath | ssl keyFile绝对路径 -httpsProxyPort | 域名代理https代理监听端口 -httpProxyPort | 域名代理http代理监听端口 -authKey|web api密钥 -bridgeType|客户端与服务端连接方式kcp或tcp -publicVkey|客户端以配置文件模式启动时的密钥,设置为空表示关闭客户端配置文件连接模式 -ipLimit|是否限制ip访问,true或false或忽略 -flowStoreInterval|服务端流量数据持久化间隔,单位分钟,忽略表示不持久化 -logLevel|日志输出级别 -cryptKey | 获取服务端authKey时的aes加密密钥,16位 -serverIp| 服务端Ip,使用p2p模式必填 -p2pPort|p2p模式开启的udp端口 +web_port | web管理端口 +web_password | web界面管理密码 +web_username | web界面管理账号 +bridge_port | 服务端客户端通信端口 +pem_path | ssl certFile绝对路径 +key_path | ssl keyFile绝对路径 +https_proxy_port | 域名代理https代理监听端口 +http_proxy_port | 域名代理http代理监听端口 +auth_key|web api密钥 +bridge_type|客户端与服务端连接方式kcp或tcp +public_vkey|客户端以配置文件模式启动时的密钥,设置为空表示关闭客户端配置文件连接模式 +ip_limit|是否限制ip访问,true或false或忽略 +flow_store_interval|服务端流量数据持久化间隔,单位分钟,忽略表示不持久化 +log_level|日志输出级别 +auth_crypt_key | 获取服务端authKey时的aes加密密钥,16位 +p2p_ip| 服务端Ip,使用p2p模式必填 +p2p_port|p2p模式开启的udp端口 ### 使用https @@ -351,7 +355,7 @@ server { ``` ### 关闭代理 -如需关闭http代理可在配置文件中将httpProxyPort设置为空,如需关闭https代理可在配置文件中将httpsProxyPort设置为空。 +如需关闭http代理可在配置文件中将http_proxy_port设置为空,如需关闭https代理可在配置文件中将https_proxy_port设置为空。 ### 将nps安装到系统 如果需要长期并且方便的运行nps服务端,可将nps安装到操作系统中,可执行命令 @@ -371,17 +375,17 @@ nps.exe test|start|stop|restart|status ``` ### 流量数据持久化 -服务端支持将流量数据持久化,默认情况下是关闭的,如果有需求可以设置`nps.conf`中的`flowStoreInterval`参数,单位为分钟 +服务端支持将流量数据持久化,默认情况下是关闭的,如果有需求可以设置`nps.conf`中的`flow_store_interval`参数,单位为分钟 **注意:** nps不会持久化通过公钥连接的客户端 ### 自定义客户端连接密钥 web上可以自定义客户端连接的密钥,但是必须具有唯一性 ### 关闭公钥访问 -可以将`nps.conf`中的`publicVkey`设置为空或者删除 +可以将`nps.conf`中的`public_vkey`设置为空或者删除 ### 关闭web管理 -可以将`nps.conf`中的`httpport`设置为空或者删除 +可以将`nps.conf`中的`web_port`设置为空或者删除 ## 客户端 @@ -616,17 +620,14 @@ LevelInformational->6 LevelDebug->7 由于是内网穿透,内网客户端与服务端之间的隧道存在大量的数据交换,为节省流量,加快传输速度,由此本程序支持SNNAPY形式的压缩。 -- 所有模式均支持数据压缩,可以与加密同时使用 -- 开启此功能会增加cpu和内存消耗 +- 所有模式均支持数据压缩 - 在web管理或客户端配置文件中设置 ### 加密传输 如果公司内网防火墙对外网访问进行了流量识别与屏蔽,例如禁止了ssh协议等,通过设置 配置文件,将服务端与客户端之间的通信内容加密传输,将会有效防止流量被拦截。 - -- 开启此功能会增加cpu和内存消耗 -- 在server端加上参数 +- nps使用tls加密,所以一定要保留conf目录下的密钥文件,同时也可以自行生成 - 在web管理或客户端配置文件中设置 @@ -660,13 +661,13 @@ LevelInformational->6 LevelDebug->7 支持客户端级带宽限制,带宽计算方式为入口和出口总和,权重均衡 ### 负载均衡 -本代理支持域名解析模式的负载均衡,在web域名添加或者编辑中内网目标分行填写多个目标即可实现轮训级别的负载均衡 +本代理支持域名解析模式和tcp代理的负载均衡,在web域名添加或者编辑中内网目标分行填写多个目标即可实现轮训级别的负载均衡 ### 端口白名单 -为了防止服务端上的端口被滥用,可在nps.conf中配置allowPorts限制可开启的端口,忽略或者不填表示端口不受限制,格式: +为了防止服务端上的端口被滥用,可在nps.conf中配置allow_ports限制可开启的端口,忽略或者不填表示端口不受限制,格式: ```ini -allowPorts=9001-9009,10001,11000-12000 +allow_ports=9001-9009,10001,11000-12000 ``` ### 端口范围映射 @@ -674,7 +675,7 @@ allowPorts=9001-9009,10001,11000-12000 ```ini [tcp] -mode=tcpServer +mode=tcp port=9001-9009,10001,11000-12000 target=8001-8009,10002,13000-14000 ``` @@ -683,7 +684,7 @@ target=8001-8009,10002,13000-14000 ### 端口范围映射到其他机器 ```ini [tcp] -mode=tcpServer +mode=tcp port=9001-9009,10001,11000-12000 target=8001-8009,10002,13000-14000 targetAddr=10.1.50.2 @@ -699,8 +700,8 @@ targetAddr=10.1.50.2 ``` ### KCP协议支持 -KCP 是一个快速可靠协议,能以比 TCP浪费10%-20%的带宽的代价,换取平均延迟降低 30%-40%,在弱网环境下对性能能有一定的提升。可在nps.conf中修改bridgeType为kcp -,设置后本代理将开启udp端口(bridgePort) +KCP 是一个快速可靠协议,能以比 TCP浪费10%-20%的带宽的代价,换取平均延迟降低 30%-40%,在弱网环境下对性能能有一定的提升。可在nps.conf中修改`bridge_type`为kcp +,设置后本代理将开启udp端口(`bridge_port`) 注意:当服务端为kcp时,客户端连接时也需要使用相同配置,无配置文件模式加上参数type=kcp,配置文件模式在配置文件中设置tp=kcp @@ -725,7 +726,7 @@ location=/static ### 限制ip访问 如果将一些危险性高的端口例如ssh端口暴露在公网上,可能会带来一些风险,本代理支持限制ip访问。 -**使用方法:** 在配置文件nps.conf中设置ipLimit=true,设置后仅通过注册的ip方可访问。 +**使用方法:** 在配置文件nps.conf中设置`ip_limit`=true,设置后仅通过注册的ip方可访问。 **ip注册**: 在需要访问的机器上,运行客户端 @@ -739,6 +740,74 @@ time为有效小时数,例如time=2,在当前时间后的两小时内,本 ### 客户端最大连接数 为防止恶意大量长连接,影响服务端程序的稳定性,可以在web或客户端配置文件中为每个客户端设置最大连接数。该功能针对`socks5`、`http正向代理`、`域名代理`、`tcp代理`、`私密代理`生效。 +### 端口复用 +在一些严格的网络环境中,对端口的个数等限制较大,nps支持强大端口复用功能。将`bridge_port`、 `http_proxy_port`、 `https_proxy_port` 、`web_port`都设置为同一端口,也能正常使用。 + +- 使用时将需要复用的端口设置为与`bridge_port`一致即可,将自动识别。 +- 如需将web管理的端口也复用,需要配置`web_host`也就是一个二级域名以便区分 + +### 环境变量渲染 +npc支持环境变量渲染以适应在某些特殊场景下的要求。 + +**在无配置文件启动模式下:** +设置环境变量 +``` +export NPC_SERVER_ADDR=1.1.1.1:8284 +export NPC_SERVER_VKEY=xxxxx +``` +直接执行./npc即可运行 + +**在配置文件启动模式下:** +```ini +[common] +server={{.NPC_SERVER_ADDR}} +tp=tcp +vkey={{.NPC_SERVER_VKEY}} +auto_reconnection=true +[web] +host={{.NPC_WEB_HOST}} +target={{.NPC_WEB_TARGET}} +``` +在配置文件中填入相应的环境变量名称,npc将自动进行渲染配置文件替换环境变量 + +### 健康检查 + +当客户端以配置文件模式启动时,支持多节点的健康检查。配置示例如下 + +```ini +[health_check_test1] +health_check_timeout=1 +health_check_max_failed=3 +health_check_interval=1 +health_http_url=/ +health_check_type=http +health_check_target=127.0.0.1:8083,127.0.0.1:8082 + +[health_check_test2] +health_check_timeout=1 +health_check_max_failed=3 +health_check_interval=1 +health_check_type=tcp +health_check_target=127.0.0.1:8083,127.0.0.1:8082 +``` +**health关键词必须在开头存在** + +第一种是http模式,也就是以get的方式请求目标+url,返回状态码为200表示成功 + +第一种是tcp模式,也就是以tcp的方式与目标建立连接,能成功建立连接表示成功 + +如果失败次数超过`health_check_max_failed`,nps则会移除该npc下的所有该目标,如果失败后目标重新上线,nps将自动将目标重新加入。 +项 | 含义 +---|--- +health_check_timeout | 健康检查超时时间 +health_check_max_failed | 健康检查允许失败次数 +health_check_interval | 健康检查间隔 +health_check_type | 健康检查类型 +health_check_target | 健康检查目标,多个以逗号(,)分隔 +health_check_type | 健康检查类型 +health_http_url | 健康检查url,仅http模式适用 + + ## 相关说明 ### 获取用户真实ip diff --git a/bridge/bridge.go b/bridge/bridge.go index 35a8d93..b3d5bc0 100755 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -17,6 +17,7 @@ import ( "github.com/cnlh/nps/vender/github.com/xtaci/kcp" "net" "strconv" + "strings" "sync" "time" ) @@ -42,6 +43,7 @@ type Bridge struct { Client map[int]*Client tunnelType string //bridge type kcp or tcp OpenTask chan *file.Tunnel + CloseTask chan *file.Tunnel CloseClient chan int SecretChan chan *conn.Secret clientLock sync.RWMutex @@ -57,6 +59,7 @@ func NewTunnel(tunnelPort int, tunnelType string, ipVerify bool, runList map[int t.Client = make(map[int]*Client) t.tunnelType = tunnelType t.OpenTask = make(chan *file.Tunnel) + t.CloseTask = make(chan *file.Tunnel) t.CloseClient = make(chan int) t.Register = make(map[string]time.Time) t.ipVerify = ipVerify @@ -106,6 +109,62 @@ func (s *Bridge) StartTunnel() error { return nil } +//get health information form client +func (s *Bridge) GetHealthFromClient(id int, c *conn.Conn) { + for { + if info, status, err := c.GetHealthInfo(); err != nil { + logs.Error(err) + break + } else if !status { //the status is true , return target to the targetArr + for _, v := range file.GetCsvDb().Tasks { + if v.Client.Id == id && v.Mode == "tcp" && strings.Contains(v.Target, info) { + v.Lock() + if v.TargetArr == nil || (len(v.TargetArr) == 0 && len(v.HealthRemoveArr) == 0) { + v.TargetArr = common.TrimArr(strings.Split(v.Target, "\n")) + } + v.TargetArr = common.RemoveArrVal(v.TargetArr, info) + if v.HealthRemoveArr == nil { + v.HealthRemoveArr = make([]string, 0) + } + v.HealthRemoveArr = append(v.HealthRemoveArr, info) + v.Unlock() + } + } + for _, v := range file.GetCsvDb().Hosts { + if v.Client.Id == id && strings.Contains(v.Target, info) { + v.Lock() + if v.TargetArr == nil || (len(v.TargetArr) == 0 && len(v.HealthRemoveArr) == 0) { + v.TargetArr = common.TrimArr(strings.Split(v.Target, "\n")) + } + v.TargetArr = common.RemoveArrVal(v.TargetArr, info) + if v.HealthRemoveArr == nil { + v.HealthRemoveArr = make([]string, 0) + } + v.HealthRemoveArr = append(v.HealthRemoveArr, info) + v.Unlock() + } + } + } else { //the status is false,remove target from the targetArr + for _, v := range file.GetCsvDb().Tasks { + if v.Client.Id == id && v.Mode == "tcp" && common.IsArrContains(v.HealthRemoveArr, info) && !common.IsArrContains(v.TargetArr, info) { + v.Lock() + v.TargetArr = append(v.TargetArr, info) + v.HealthRemoveArr = common.RemoveArrVal(v.HealthRemoveArr, info) + v.Unlock() + } + } + for _, v := range file.GetCsvDb().Hosts { + if v.Client.Id == id && common.IsArrContains(v.HealthRemoveArr, info) && !common.IsArrContains(v.TargetArr, info) { + v.Lock() + v.TargetArr = append(v.TargetArr, info) + v.HealthRemoveArr = common.RemoveArrVal(v.HealthRemoveArr, info) + v.Unlock() + } + } + } + } +} + //验证失败,返回错误验证flag,并且关闭连接 func (s *Bridge) verifyError(c *conn.Conn) { c.Write([]byte(common.VERIFY_EER)) @@ -187,6 +246,7 @@ func (s *Bridge) typeDeal(typeVal string, c *conn.Conn, id int) { s.Client[id] = NewClient(nil, nil, c) s.clientLock.Unlock() } + go s.GetHealthFromClient(id, c) logs.Info("clientId %d connection succeeded, address:%s ", id, c.Conn.RemoteAddr()) case common.WORK_CHAN: s.clientLock.Lock() @@ -264,7 +324,7 @@ func (s *Bridge) register(c *conn.Conn) { var hour int32 if err := binary.Read(c, binary.LittleEndian, &hour); err == nil { s.registerLock.Lock() - s.Register[common.GetIpByAddr(c.Conn.RemoteAddr().String())] = time.Now().Add(time.Hour * time.Duration(hour)) + s.Register[common.GetIpByAddr(c.Conn.RemoteAddr().String())] = time.Now().Add(time.Minute * time.Duration(hour)) s.registerLock.Unlock() } } @@ -282,11 +342,11 @@ func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link, linkAddr string, t s.registerLock.Unlock() return nil, errors.New(fmt.Sprintf("The ip %s is not in the validation list", ip)) } else { + s.registerLock.Unlock() if !v.After(time.Now()) { return nil, errors.New(fmt.Sprintf("The validity of the ip %s has expired", ip)) } } - s.registerLock.Unlock() } var tunnel *mux.Mux if t != nil && t.Mode == "file" { @@ -311,7 +371,6 @@ func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link, linkAddr string, t logs.Info("new connect error ,the target %s refuse to connect", link.Host) return } - } else { s.clientLock.Unlock() err = errors.New(fmt.Sprintf("the client %d is not connect", clientId)) @@ -366,6 +425,7 @@ loop: if err != nil { break loop } + file.GetCsvDb().Lock() for _, v := range file.GetCsvDb().Hosts { if v.Client.Id == id { str += v.Remark + common.CONN_DATA_SEQ @@ -376,6 +436,7 @@ loop: str += v.Remark + common.CONN_DATA_SEQ } } + file.GetCsvDb().Unlock() binary.Write(c, binary.LittleEndian, int32(len([]byte(str)))) binary.Write(c, binary.LittleEndian, []byte(str)) } diff --git a/client/client.go b/client/client.go index 56e9a93..a3fc3da 100755 --- a/client/client.go +++ b/client/client.go @@ -51,6 +51,7 @@ retry: } func (s *TRPClient) Close() { + s.stop <- true s.signal.Close() } @@ -58,7 +59,9 @@ func (s *TRPClient) Close() { func (s *TRPClient) processor(c *conn.Conn) { s.signal = c go s.dealChan() - go heathCheck(s.cnf, c) + if s.cnf != nil && len(s.cnf.Healths) > 0 { + go heathCheck(s.cnf.Healths, s.signal) + } for { flags, err := c.ReadFlag() if err != nil { diff --git a/client/control.go b/client/control.go index d4239d6..b527cf9 100644 --- a/client/control.go +++ b/client/control.go @@ -165,7 +165,7 @@ re: } c.Close() - NewRPClient(cnf.CommonConfig.Server, vkey, cnf.CommonConfig.Tp, cnf.CommonConfig.ProxyUrl).Start() + NewRPClient(cnf.CommonConfig.Server, vkey, cnf.CommonConfig.Tp, cnf.CommonConfig.ProxyUrl, cnf).Start() CloseLocalServer() goto re } diff --git a/client/health.go b/client/health.go index 1e09376..566a2ed 100644 --- a/client/health.go +++ b/client/health.go @@ -2,60 +2,59 @@ package client import ( "container/heap" - "github.com/cnlh/nps/lib/config" + "github.com/cnlh/nps/lib/conn" "github.com/cnlh/nps/lib/file" "github.com/cnlh/nps/lib/sheap" + "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" + "github.com/pkg/errors" "net" "net/http" "strings" "time" ) -func heathCheck(cnf *config.Config, c net.Conn) { - var hosts []*file.Host - var tunnels []*file.Tunnel +var isStart bool +var serverConn *conn.Conn + +func heathCheck(healths []*file.Health, c *conn.Conn) bool { + serverConn = c + if isStart { + for _, v := range healths { + v.HealthMap = make(map[string]int) + } + return true + } + isStart = true h := &sheap.IntHeap{} - for _, v := range cnf.Hosts { + for _, v := range healths { if v.HealthMaxFail > 0 && v.HealthCheckTimeout > 0 && v.HealthCheckInterval > 0 { - v.HealthNextTime = time.Now().Add(time.Duration(v.HealthCheckInterval)) + v.HealthNextTime = time.Now().Add(time.Duration(v.HealthCheckInterval) * time.Second) heap.Push(h, v.HealthNextTime.Unix()) v.HealthMap = make(map[string]int) - hosts = append(hosts, v) } } - for _, v := range cnf.Tasks { - if v.HealthMaxFail > 0 && v.HealthCheckTimeout > 0 && v.HealthCheckInterval > 0 { - v.HealthNextTime = time.Now().Add(time.Duration(v.HealthCheckInterval)) - heap.Push(h, v.HealthNextTime.Unix()) - v.HealthMap = make(map[string]int) - tunnels = append(tunnels, v) - } - } - if len(hosts) == 0 && len(tunnels) == 0 { - return - } + go session(healths, h) + return true +} + +func session(healths []*file.Health, h *sheap.IntHeap) { for { + if h.Len() == 0 { + logs.Error("health check error") + break + } rs := heap.Pop(h).(int64) - time.Now().Unix() - if rs < 0 { + if rs <= 0 { continue } - timer := time.NewTicker(time.Duration(rs)) + timer := time.NewTimer(time.Duration(rs) * time.Second) select { case <-timer.C: - for _, v := range hosts { + for _, v := range healths { if v.HealthNextTime.Before(time.Now()) { - v.HealthNextTime = time.Now().Add(time.Duration(v.HealthCheckInterval)) + v.HealthNextTime = time.Now().Add(time.Duration(v.HealthCheckInterval) * time.Second) //check - go checkHttp(v, c) - //reset time - heap.Push(h, v.HealthNextTime.Unix()) - } - } - for _, v := range tunnels { - if v.HealthNextTime.Before(time.Now()) { - v.HealthNextTime = time.Now().Add(time.Duration(v.HealthCheckInterval)) - //check - go checkTcp(v, c) + go check(v) //reset time heap.Push(h, v.HealthNextTime.Unix()) } @@ -64,41 +63,33 @@ func heathCheck(cnf *config.Config, c net.Conn) { } } -func checkTcp(t *file.Tunnel, c net.Conn) { - arr := strings.Split(t.Target, "\n") +//只针对一个端口 面向多个目标的情况 +func check(t *file.Health) { + arr := strings.Split(t.HealthCheckTarget, ",") + var err error + var rs *http.Response for _, v := range arr { - if _, err := net.DialTimeout("tcp", v, time.Duration(t.HealthCheckTimeout)); err != nil { - t.HealthMap[v] += 1 - } - if t.HealthMap[v] > t.HealthMaxFail { - t.HealthMap[v] += 1 - if t.HealthMap[v] == t.HealthMaxFail { - //send fail remove - ch <- file.NewHealthInfo("tcp", v, true) + if t.HealthCheckType == "tcp" { + _, err = net.DialTimeout("tcp", v, time.Duration(t.HealthCheckTimeout)*time.Second); + } else { + client := &http.Client{} + client.Timeout = time.Duration(t.HealthCheckTimeout) * time.Second + rs, err = client.Get("http://" + v + t.HttpHealthUrl) + if err == nil && rs.StatusCode != 200 { + err = errors.New("status code is not match") } + } + if err != nil { + t.HealthMap[v] += 1 } else if t.HealthMap[v] >= t.HealthMaxFail { //send recovery add - ch <- file.NewHealthInfo("tcp", v, false) + serverConn.SendHealthInfo(v, "1") t.HealthMap[v] = 0 } - } -} -func checkHttp(h *file.Host, ch chan *file.HealthInfo) { - arr := strings.Split(h.Target, "\n") - client := &http.Client{} - client.Timeout = time.Duration(h.HealthCheckTimeout) * time.Second - for _, v := range arr { - if _, err := client.Get(v + h.HttpHealthUrl); err != nil { - h.HealthMap[v] += 1 - if h.HealthMap[v] == h.HealthMaxFail { - //send fail remove - ch <- file.NewHealthInfo("http", v, true) - } - } else if h.HealthMap[v] >= h.HealthMaxFail { - //send recovery add - h.HealthMap[v] = 0 - ch <- file.NewHealthInfo("http", v, false) + if t.HealthMap[v] == t.HealthMaxFail { + //send fail remove + serverConn.SendHealthInfo(v, "0") } } } diff --git a/cmd/npc/npc.go b/cmd/npc/npc.go index 7131694..8bd9a0c 100644 --- a/cmd/npc/npc.go +++ b/cmd/npc/npc.go @@ -5,6 +5,7 @@ import ( "github.com/cnlh/nps/client" "github.com/cnlh/nps/lib/common" "github.com/cnlh/nps/lib/daemon" + "github.com/cnlh/nps/lib/version" "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" "os" "strings" @@ -44,14 +45,15 @@ func main() { } env := common.GetEnvMap() if *serverAddr == "" { - *serverAddr, _ = env["NPS_SERVER_ADDR"] + *serverAddr, _ = env["NPC_SERVER_ADDR"] } if *verifyKey == "" { - *verifyKey, _ = env["NPS_SERVER_VKEY"] + *verifyKey, _ = env["NPC_SERVER_VKEY"] } + logs.Info("the version of client is %s", version.VERSION) if *verifyKey != "" && *serverAddr != "" && *configPath == "" { for { - client.NewRPClient(*serverAddr, *verifyKey, *connType, *proxyUrl).Start() + client.NewRPClient(*serverAddr, *verifyKey, *connType, *proxyUrl, nil).Start() logs.Info("It will be reconnected in five seconds") time.Sleep(time.Second * 5) } diff --git a/cmd/nps/nps.go b/cmd/nps/nps.go index ebdfb26..4c3cb2e 100644 --- a/cmd/nps/nps.go +++ b/cmd/nps/nps.go @@ -6,6 +6,7 @@ import ( "github.com/cnlh/nps/lib/daemon" "github.com/cnlh/nps/lib/file" "github.com/cnlh/nps/lib/install" + "github.com/cnlh/nps/lib/version" "github.com/cnlh/nps/server" "github.com/cnlh/nps/server/connection" "github.com/cnlh/nps/server/test" @@ -54,9 +55,10 @@ func main() { } bridgePort, err := beego.AppConfig.Int("bridge_port") if err != nil { - logs.Error("Getting bridgePort error", err) + logs.Error("Getting bridge_port error", err) os.Exit(0) } + logs.Info("the version of server is %s ,allow client version to be %s", version.VERSION, version.GetVersion()) connection.InitConnectionService() server.StartNewServer(bridgePort, task, beego.AppConfig.String("bridge_type")) } diff --git a/conf/clients.csv b/conf/clients.csv index 74aeb20..a386ff7 100644 --- a/conf/clients.csv +++ b/conf/clients.csv @@ -1,2 +1,2 @@ -2,corjmrbhr33otit1,,true,,,1,false,0,0,0 -5,2dyy78gj7b9zw09l,,true,,,0,false,0,0,0 +2,corjmrbhr33otit1,,true,,,0,false,0,0,0 +5,2dyy78gj7b9zw09l,,true,,,1,false,0,0,0 diff --git a/conf/hosts.csv b/conf/hosts.csv index c091bd7..9c24cbe 100644 --- a/conf/hosts.csv +++ b/conf/hosts.csv @@ -1,2 +1,3 @@ -a.o.com,127.0.0.1:8082,2,,,111,/,3,5290945,32285,http -a.o.com,127.0.0.1:8080,2,,,,/,4,0,0,https +c.o.com,10.1.50.196:4000,5,,,,/,2,7543392,22379,all +a.o.com,127.0.0.1:8080,2,,,,/,3,0,0,all +b.o.com,127.0.0.1:8082,5,,,,/,4,0,0,all diff --git a/conf/npc.conf b/conf/npc.conf index ca98a9c..f0309e6 100644 --- a/conf/npc.conf +++ b/conf/npc.conf @@ -1,25 +1,33 @@ [common] -server={{.NPS_SERVER_ADDR}} +server=127.0.0.1:8024 tp=tcp -vkey={{.NPS_SERVER_VKEY}} +vkey=2dyy78gj7b9zw09l auto_reconnection=true [web] host=b.o.com -target=127.0.0.1:8082 -health_check_timeout = 3 -health_check_max_failed = 3 -health_check_interval = 10 +target=10.1.50.203:80 + +[health_check_test1] +health_check_timeout=1 +health_check_max_failed=3 +health_check_interval=1 health_http_url=/ +health_check_type=http +health_check_target=127.0.0.1:8083,127.0.0.1:8082 + +[health_check_test2] +health_check_timeout=1 +health_check_max_failed=3 +health_check_interval=1 +health_check_type=tcp +health_check_target=127.0.0.1:8083,127.0.0.1:8082 + [tcp] mode=tcp -target=8006-8010,8012 -port=9006-9010,9012 +target=127.0.0.1:8083,127.0.0.1:8082 +port=9006 targetAddr=123.206.77.88 -health_check_timeout = 3 -health_check_max_failed = 3 -health_check_interval = 10 -health_http_url=/ [socks5] mode=socks5 diff --git a/conf/nps.conf b/conf/nps.conf index 290b993..b336632 100755 --- a/conf/nps.conf +++ b/conf/nps.conf @@ -6,7 +6,7 @@ runmode = dev #HTTP(S) proxy port, no startup if empty http_proxy_port=80 -https_proxy_port=443 +#https_proxy_port=8024 #certFile absolute path pem_path=conf/server.pem #KeyFile absolute path @@ -26,7 +26,7 @@ public_vkey=123 #flow_store_interval=1 # log level LevelEmergency->0 LevelAlert->1 LevelCritical->2 LevelError->3 LevelWarning->4 LevelNotice->5 LevelInformational->6 LevelDebug->7 -#log_level=7 +log_level=7 #Whether to restrict IP access, true or false or ignore #ip_limit=true @@ -36,7 +36,7 @@ public_vkey=123 #p2p_port=6000 #web -web_host=c.o.com +web_host=a.o.com web_username=admin web_password=123 web_port = 8080 @@ -45,4 +45,4 @@ web_ip=0.0.0.0 auth_key=test auth_crypt_key =1234567812345678 -#allow_ports=9001-9009,10001,11000-12000 \ No newline at end of file +#allow_ports=9001-9009,10001,11000-12000 diff --git a/conf/tasks.csv b/conf/tasks.csv index fbe8834..c573f9f 100644 --- a/conf/tasks.csv +++ b/conf/tasks.csv @@ -1,4 +1,5 @@ 8025,socks5,,1,1,2,,0,0, 8026,httpProxy,,1,2,2,,0,0, -8001,tcp,"127.0.0.1:8080 -127.0.0.1:8082",1,3,5,,0,0, +9002,tcp,127.0.0.1:8082,1,3,2,,0,0, +9003,socks5,,1,5,5,,0,0, +9009,tcp,127.0.0.1:8082,1,21,5,,8244480,2382592, diff --git a/lib/common/util.go b/lib/common/util.go index 31c94e3..ce23ea6 100755 --- a/lib/common/util.go +++ b/lib/common/util.go @@ -83,7 +83,7 @@ func GetStrByBool(b bool) string { //int func GetIntNoErrByStr(str string) int { - i, _ := strconv.Atoi(str) + i, _ := strconv.Atoi(strings.TrimSpace(str)) return i } @@ -241,7 +241,8 @@ func GetIpByAddr(addr string) string { } func CopyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - buf := pool.BufPoolCopy.Get().([]byte) + buf := pool.GetBufPoolCopy() + defer pool.PutBufPoolCopy(buf) for { nr, er := src.Read(buf) if nr > 0 { @@ -265,7 +266,6 @@ func CopyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { break } } - defer pool.PutBufPoolCopy(buf) return written, err } @@ -303,3 +303,35 @@ func GetEnvMap() map[string]string { } return m } + +func TrimArr(arr []string) []string { + newArr := make([]string, 0) + for _, v := range arr { + if v != "" { + newArr = append(newArr, v) + } + } + return newArr +} + +func IsArrContains(arr []string, val string) bool { + if arr == nil { + return false + } + for _, v := range arr { + if v == val { + return true + } + } + return false +} + +func RemoveArrVal(arr []string, val string) []string { + for k, v := range arr { + if v == val { + arr = append(arr[:k], arr[k+1:]...) + return arr + } + } + return arr +} diff --git a/lib/config/config.go b/lib/config/config.go index b75b1d5..de7ae26 100644 --- a/lib/config/config.go +++ b/lib/config/config.go @@ -29,6 +29,7 @@ type Config struct { CommonConfig *CommonConfig Hosts []*file.Host Tasks []*file.Tunnel + Healths []*file.Health LocalServer []*LocalServer } @@ -56,18 +57,24 @@ func NewConfig(path string) (c *Config, err error) { } nowContent = c.content[nowIndex:nextIndex] - if strings.Index(getTitleContent(c.title[i]), "secret") == 0 { + if strings.Index(getTitleContent(c.title[i]), "secret") == 0 && !strings.Contains(nowContent, "mode") { local := delLocalService(nowContent) local.Type = "secret" c.LocalServer = append(c.LocalServer, local) continue } - if strings.Index(getTitleContent(c.title[i]), "p2p") == 0 { + //except mode + if strings.Index(getTitleContent(c.title[i]), "p2p") == 0 && !strings.Contains(nowContent, "mode") { local := delLocalService(nowContent) local.Type = "p2p" c.LocalServer = append(c.LocalServer, local) continue } + //health set + if strings.Index(getTitleContent(c.title[i]), "health") == 0 { + c.Healths = append(c.Healths, dealHealth(nowContent)) + continue + } switch c.title[i] { case "[common]": c.CommonConfig = dealCommon(nowContent) @@ -146,15 +153,37 @@ func dealHost(s string) *file.Host { } else if len(item) == 1 { item = append(item, "") } - switch item[0] { + switch strings.TrimSpace(item[0]) { case "host": h.Host = item[1] case "target": h.Target = strings.Replace(item[1], ",", "\n", -1) case "host_change": h.HostChange = item[1] + case "schemego": + h.Scheme = item[1] case "location": h.Location = item[1] + default: + if strings.Contains(item[0], "header") { + headerChange += strings.Replace(item[0], "header_", "", -1) + ":" + item[1] + "\n" + } + h.HeaderChange = headerChange + } + } + return h +} + +func dealHealth(s string) *file.Health { + h := &file.Health{} + for _, v := range strings.Split(s, "\n") { + item := strings.Split(v, "=") + if len(item) == 0 { + continue + } else if len(item) == 1 { + item = append(item, "") + } + switch strings.TrimSpace(item[0]) { case "health_check_timeout": h.HealthCheckTimeout = common.GetIntNoErrByStr(item[1]) case "health_check_max_failed": @@ -163,11 +192,10 @@ func dealHost(s string) *file.Host { h.HealthCheckInterval = common.GetIntNoErrByStr(item[1]) case "health_http_url": h.HttpHealthUrl = item[1] - default: - if strings.Contains(item[0], "header") { - headerChange += strings.Replace(item[0], "header_", "", -1) + ":" + item[1] + "\n" - } - h.HeaderChange = headerChange + case "health_check_type": + h.HealthCheckType = item[1] + case "health_check_target": + h.HealthCheckTarget = item[1] } } return h @@ -182,7 +210,7 @@ func dealTunnel(s string) *file.Tunnel { } else if len(item) == 1 { item = append(item, "") } - switch item[0] { + switch strings.TrimSpace(item[0]) { case "port": t.Ports = item[1] case "mode": @@ -197,12 +225,6 @@ func dealTunnel(s string) *file.Tunnel { t.LocalPath = item[1] case "strip_pre": t.StripPre = item[1] - case "health_check_timeout": - t.HealthCheckTimeout = common.GetIntNoErrByStr(item[1]) - case "health_check_max_failed": - t.HealthMaxFail = common.GetIntNoErrByStr(item[1]) - case "health_check_interval": - t.HealthCheckInterval = common.GetIntNoErrByStr(item[1]) } } return t diff --git a/lib/conn/conn.go b/lib/conn/conn.go index 5f7f615..7662653 100755 --- a/lib/conn/conn.go +++ b/lib/conn/conn.go @@ -150,8 +150,6 @@ func (s *Conn) SetReadDeadline(t time.Duration, tp string) { func (s *Conn) SendLinkInfo(link *Link) (int, error) { raw := bytes.NewBuffer([]byte{}) common.BinaryWrite(raw, link.ConnType, link.Host, common.GetStrByBool(link.Compress), common.GetStrByBool(link.Crypt), link.RemoteAddr) - s.Lock() - defer s.Unlock() return s.Write(raw.Bytes()) } @@ -176,6 +174,33 @@ func (s *Conn) GetLinkInfo() (lk *Link, err error) { return } +//send info for link +func (s *Conn) SendHealthInfo(info, status string) (int, error) { + raw := bytes.NewBuffer([]byte{}) + common.BinaryWrite(raw, info, status) + s.Lock() + defer s.Unlock() + return s.Write(raw.Bytes()) +} + +//get health info from conn +func (s *Conn) GetHealthInfo() (info string, status bool, err error) { + var l int + buf := pool.BufPoolMax.Get().([]byte) + defer pool.PutBufPoolMax(buf) + if l, err = s.GetLen(); err != nil { + return + } else if _, err = s.ReadLen(l, buf); err != nil { + return + } else { + arr := strings.Split(string(buf[:l]), common.CONN_DATA_SEQ) + if len(arr) >= 2 { + return arr[0], common.GetBoolByStr(arr[1]), nil + } + } + return "", false, errors.New("receive health info error") +} + //send host info func (s *Conn) SendHostInfo(h *file.Host) (int, error) { /* @@ -188,7 +213,7 @@ func (s *Conn) SendHostInfo(h *file.Host) (int, error) { */ raw := bytes.NewBuffer([]byte{}) binary.Write(raw, binary.LittleEndian, []byte(common.NEW_HOST)) - common.BinaryWrite(raw, h.Host, h.Target, h.HeaderChange, h.HostChange, h.Remark, h.Location) + common.BinaryWrite(raw, h.Host, h.Target, h.HeaderChange, h.HostChange, h.Remark, h.Location, h.Scheme) s.Lock() defer s.Unlock() return s.Write(raw.Bytes()) @@ -228,6 +253,10 @@ func (s *Conn) GetHostInfo() (h *file.Host, err error) { h.HostChange = arr[3] h.Remark = arr[4] h.Location = arr[5] + h.Scheme = arr[6] + if h.Scheme == "" { + h.Scheme = "all" + } h.Flow = new(file.Flow) h.NoStore = true } diff --git a/lib/file/file.go b/lib/file/file.go index 9161433..113f01c 100644 --- a/lib/file/file.go +++ b/lib/file/file.go @@ -32,7 +32,7 @@ type Csv struct { ClientIncreaseId int //客户端id TaskIncreaseId int //任务自增ID HostIncreaseId int - sync.Mutex + sync.RWMutex } func (s *Csv) StoreTasksToCsv() { @@ -43,6 +43,7 @@ func (s *Csv) StoreTasksToCsv() { } defer csvFile.Close() writer := csv.NewWriter(csvFile) + s.Lock() for _, task := range s.Tasks { if task.NoStore { continue @@ -64,6 +65,7 @@ func (s *Csv) StoreTasksToCsv() { logs.Error(err.Error()) } } + s.Unlock() writer.Flush() } @@ -147,6 +149,7 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) { } func (s *Csv) NewTask(t *Tunnel) error { + s.Lock() for _, v := range s.Tasks { if (v.Mode == "secret" || v.Mode == "p2p") && v.Password == t.Password { return errors.New(fmt.Sprintf("Secret mode keys %s must be unique", t.Password)) @@ -154,33 +157,42 @@ func (s *Csv) NewTask(t *Tunnel) error { } t.Flow = new(Flow) s.Tasks = append(s.Tasks, t) + s.Unlock() s.StoreTasksToCsv() return nil } func (s *Csv) UpdateTask(t *Tunnel) error { + s.Lock() for _, v := range s.Tasks { if v.Id == t.Id { + s.Unlock() s.StoreTasksToCsv() return nil } } + s.Unlock() return errors.New("the task is not exist") } func (s *Csv) DelTask(id int) error { + s.Lock() for k, v := range s.Tasks { if v.Id == id { s.Tasks = append(s.Tasks[:k], s.Tasks[k+1:]...) + s.Unlock() s.StoreTasksToCsv() return nil } } + s.Unlock() return errors.New("不存在") } //md5 password func (s *Csv) GetTaskByMd5Password(p string) *Tunnel { + s.Lock() + defer s.Unlock() for _, v := range s.Tasks { if crypt.Md5(v.Password) == p { return v @@ -190,6 +202,8 @@ func (s *Csv) GetTaskByMd5Password(p string) *Tunnel { } func (s *Csv) GetTask(id int) (v *Tunnel, err error) { + s.Lock() + defer s.Unlock() for _, v = range s.Tasks { if v.Id == id { return @@ -210,6 +224,8 @@ func (s *Csv) StoreHostToCsv() { writer := csv.NewWriter(csvFile) // 将map中的Post转换成slice,因为csv的Write需要slice参数 // 并写入csv文件 + s.Lock() + defer s.Unlock() for _, host := range s.Hosts { if host.NoStore { continue @@ -313,17 +329,22 @@ func (s *Csv) LoadHostFromCsv() { } func (s *Csv) DelHost(id int) error { + s.Lock() for k, v := range s.Hosts { if v.Id == id { s.Hosts = append(s.Hosts[:k], s.Hosts[k+1:]...) + s.Unlock() s.StoreHostToCsv() return nil } } + s.Unlock() return errors.New("不存在") } func (s *Csv) IsHostExist(h *Host) bool { + s.Lock() + defer s.Unlock() for _, v := range s.Hosts { if v.Host == h.Host && h.Location == v.Location && (v.Scheme == "all" || v.Scheme == h.Scheme) { return true @@ -340,24 +361,31 @@ func (s *Csv) NewHost(t *Host) error { t.Location = "/" } t.Flow = new(Flow) + s.Lock() s.Hosts = append(s.Hosts, t) + s.Unlock() s.StoreHostToCsv() return nil } func (s *Csv) UpdateHost(t *Host) error { + s.Lock() for _, v := range s.Hosts { if v.Host == t.Host { + s.Unlock() s.StoreHostToCsv() return nil } } + s.Unlock() return errors.New("不存在") } func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) { list := make([]*Host, 0) var cnt int + s.Lock() + defer s.Unlock() for _, v := range s.Hosts { if id == 0 || v.Client.Id == id { cnt++ @@ -372,13 +400,16 @@ func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) { } func (s *Csv) DelClient(id int) error { + s.Lock() for k, v := range s.Clients { if v.Id == id { s.Clients = append(s.Clients[:k], s.Clients[k+1:]...) + s.Unlock() s.StoreClientsToCsv() return nil } } + s.Unlock() return errors.New("不存在") } @@ -402,13 +433,15 @@ reset: c.Flow = new(Flow) } s.Lock() - defer s.Unlock() s.Clients = append(s.Clients, c) + s.Unlock() s.StoreClientsToCsv() return nil } func (s *Csv) VerifyVkey(vkey string, id int) bool { + s.Lock() + defer s.Unlock() for _, v := range s.Clients { if v.VerifyKey == vkey && v.Id != id { return false @@ -426,7 +459,6 @@ func (s *Csv) GetClientId() int { func (s *Csv) UpdateClient(t *Client) error { s.Lock() - defer s.Unlock() for _, v := range s.Clients { if v.Id == t.Id { v.Cnf = t.Cnf @@ -435,16 +467,20 @@ func (s *Csv) UpdateClient(t *Client) error { v.RateLimit = t.RateLimit v.Flow = t.Flow v.Rate = t.Rate + s.Unlock() s.StoreClientsToCsv() return nil } } + s.Unlock() return errors.New("该客户端不存在") } func (s *Csv) GetClientList(start, length int) ([]*Client, int) { list := make([]*Client, 0) var cnt int + s.Lock() + defer s.Unlock() for _, v := range s.Clients { if v.NoDisplay { continue @@ -460,6 +496,8 @@ func (s *Csv) GetClientList(start, length int) ([]*Client, int) { } func (s *Csv) GetClient(id int) (v *Client, err error) { + s.Lock() + defer s.Unlock() for _, v = range s.Clients { if v.Id == id { return @@ -469,6 +507,8 @@ func (s *Csv) GetClient(id int) (v *Client, err error) { return } func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) { + s.Lock() + defer s.Unlock() for _, v := range s.Clients { if crypt.Md5(v.VerifyKey) == vkey { id = v.Id @@ -480,6 +520,8 @@ func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) { } func (s *Csv) GetHostById(id int) (h *Host, err error) { + s.Lock() + defer s.Unlock() for _, v := range s.Hosts { if v.Id == id { h = v @@ -495,7 +537,12 @@ func (s *Csv) GetInfoByHost(host string, r *http.Request) (h *Host, err error) { var hosts []*Host //Handling Ported Access host = common.GetIpByAddr(host) + s.Lock() + defer s.Unlock() for _, v := range s.Hosts { + if v.IsClose { + continue + } //Remove http(s) http(s)://a.proxy.com //*.proxy.com *.a.proxy.com Do some pan-parsing tmp := strings.Replace(v.Host, "*", `\w+?`, -1) @@ -533,6 +580,8 @@ func (s *Csv) StoreClientsToCsv() { } defer csvFile.Close() writer := csv.NewWriter(csvFile) + s.Lock() + defer s.Unlock() for _, client := range s.Clients { if client.NoStore { continue diff --git a/lib/file/obj.go b/lib/file/obj.go index 80a6198..d5ee365 100644 --- a/lib/file/obj.go +++ b/lib/file/obj.go @@ -2,6 +2,7 @@ package file import ( "github.com/cnlh/nps/lib/rate" + "github.com/pkg/errors" "strings" "sync" "time" @@ -78,7 +79,14 @@ func (s *Client) GetConn() bool { return false } +//modify the hosts and the tunnels by health information +func (s *Client) ModifyTarget() { + +} + func (s *Client) HasTunnel(t *Tunnel) bool { + GetCsvDb().Lock() + defer GetCsvDb().Unlock() for _, v := range GetCsvDb().Tasks { if v.Client.Id == s.Id && v.Port == t.Port { return true @@ -88,6 +96,8 @@ func (s *Client) HasTunnel(t *Tunnel) bool { } func (s *Client) HasHost(h *Host) bool { + GetCsvDb().Lock() + defer GetCsvDb().Unlock() for _, v := range GetCsvDb().Hosts { if v.Client.Id == s.Id && v.Host == h.Host && h.Location == v.Location { return true @@ -126,14 +136,19 @@ type Health struct { HealthMap map[string]int HttpHealthUrl string HealthRemoveArr []string + HealthCheckType string + HealthCheckTarget string } -func (s *Tunnel) GetRandomTarget() string { +func (s *Tunnel) GetRandomTarget() (string, error) { if s.TargetArr == nil { s.TargetArr = strings.Split(s.Target, "\n") } if len(s.TargetArr) == 1 { - return s.TargetArr[0] + return s.TargetArr[0], nil + } + if len(s.TargetArr) == 0 { + return "", errors.New("all inward-bending targets are offline") } s.Lock() defer s.Unlock() @@ -141,7 +156,7 @@ func (s *Tunnel) GetRandomTarget() string { s.NowIndex = -1 } s.NowIndex++ - return s.TargetArr[s.NowIndex] + return s.TargetArr[s.NowIndex], nil } type Config struct { @@ -165,23 +180,26 @@ type Host struct { TargetArr []string NoStore bool Scheme string //http https all + IsClose bool Health sync.RWMutex } -func (s *Host) GetRandomTarget() string { +func (s *Host) GetRandomTarget() (string, error) { if s.TargetArr == nil { s.TargetArr = strings.Split(s.Target, "\n") } if len(s.TargetArr) == 1 { - return s.TargetArr[0] + return s.TargetArr[0], nil + } + if len(s.TargetArr) == 0 { + return "", errors.New("all inward-bending targets are offline") } s.Lock() defer s.Unlock() if s.NowIndex >= len(s.TargetArr)-1 { s.NowIndex = -1 - } else { - s.NowIndex++ } - return s.TargetArr[s.NowIndex] + s.NowIndex++ + return s.TargetArr[s.NowIndex], nil } diff --git a/lib/mux/conn.go b/lib/mux/conn.go index ef34cf4..6fd81b8 100644 --- a/lib/mux/conn.go +++ b/lib/mux/conn.go @@ -5,6 +5,7 @@ import ( "github.com/cnlh/nps/lib/pool" "io" "net" + "sync" "time" ) @@ -15,78 +16,76 @@ type conn struct { connStatusFailCh chan struct{} readTimeOut time.Time writeTimeOut time.Time - sendMsgCh chan *msg //mux - sendStatusCh chan int32 //mux readBuffer []byte startRead int //now read position endRead int //now end read readFlag bool readCh chan struct{} + waitQueue *sliceEntry + stopWrite bool connId int32 isClose bool readWait bool mux *Mux } -type msg struct { - connId int32 - content []byte -} +var connPool = sync.Pool{} -func NewMsg(connId int32, content []byte) *msg { - return &msg{ - connId: connId, - content: content, - } -} - -func NewConn(connId int32, mux *Mux, sendMsgCh chan *msg, sendStatusCh chan int32) *conn { - return &conn{ +func NewConn(connId int32, mux *Mux) *conn { + c := &conn{ readCh: make(chan struct{}), - readBuffer: pool.BufPoolCopy.Get().([]byte), getStatusCh: make(chan struct{}), connStatusOkCh: make(chan struct{}), connStatusFailCh: make(chan struct{}), - readTimeOut: time.Time{}, - writeTimeOut: time.Time{}, - sendMsgCh: sendMsgCh, - sendStatusCh: sendStatusCh, + waitQueue: NewQueue(), connId: connId, - isClose: false, mux: mux, } + return c } func (s *conn) Read(buf []byte) (n int, err error) { - if s.isClose { + if s.isClose || buf == nil { return 0, errors.New("the conn has closed") } - if s.endRead-s.startRead == 0 { - s.readWait = true - if t := s.readTimeOut.Sub(time.Now()); t > 0 { - timer := time.NewTimer(t) - select { - case <-timer.C: - s.readWait = false - return 0, errors.New("read timeout") - case <-s.readCh: + if s.endRead-s.startRead == 0 { //read finish or start + if s.waitQueue.Size() == 0 { + s.readWait = true + if t := s.readTimeOut.Sub(time.Now()); t > 0 { + timer := time.NewTimer(t) + defer timer.Stop() + select { + case <-timer.C: + s.readWait = false + return 0, errors.New("read timeout") + case <-s.readCh: + } + } else { + <-s.readCh } - } else { - <-s.readCh } - } - s.readWait = false - if s.isClose { - return 0, io.EOF + if s.isClose { //If the connection is closed instead of continuing command + return 0, errors.New("the conn has closed") + } + if node, err := s.waitQueue.Pop(); err != nil { + s.Close() + return 0, io.EOF + } else { + pool.PutBufPoolCopy(s.readBuffer) + s.readBuffer = node.val + s.endRead = node.l + s.startRead = 0 + } } if len(buf) < s.endRead-s.startRead { n = copy(buf, s.readBuffer[s.startRead:s.startRead+len(buf)]) s.startRead += n } else { n = copy(buf, s.readBuffer[s.startRead:s.endRead]) - s.startRead = 0 - s.endRead = 0 - s.sendStatusCh <- s.connId + s.startRead += n + if s.waitQueue.Size() < s.mux.waitQueueSize/2 { + s.mux.sendInfo(MUX_MSG_SEND_OK, s.connId, nil) + } } return } @@ -99,6 +98,7 @@ func (s *conn) Write(buf []byte) (int, error) { go s.write(buf, ch) if t := s.writeTimeOut.Sub(time.Now()); t > 0 { timer := time.NewTimer(t) + defer timer.Stop() select { case <-timer.C: return 0, errors.New("write timeout") @@ -112,18 +112,18 @@ func (s *conn) Write(buf []byte) (int, error) { } return len(buf), nil } - func (s *conn) write(buf []byte, ch chan struct{}) { start := 0 l := len(buf) for { + if s.stopWrite { + <-s.getStatusCh + } if l-start > pool.PoolSizeCopy { - s.sendMsgCh <- NewMsg(s.connId, buf[start:start+pool.PoolSizeCopy]) + s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:start+pool.PoolSizeCopy]) start += pool.PoolSizeCopy - <-s.getStatusCh } else { - s.sendMsgCh <- NewMsg(s.connId, buf[start:l]) - <-s.getStatusCh + s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:l]) break } } @@ -131,18 +131,30 @@ func (s *conn) write(buf []byte, ch chan struct{}) { } func (s *conn) Close() error { + if s.isClose { + return errors.New("the conn has closed") + } + times := 0 +retry: + if s.waitQueue.Size() > 0 && times < 600 { + time.Sleep(time.Millisecond * 100) + times++ + goto retry + } if s.isClose { return errors.New("the conn has closed") } s.isClose = true pool.PutBufPoolCopy(s.readBuffer) - close(s.getStatusCh) - close(s.connStatusOkCh) - close(s.connStatusFailCh) - close(s.readCh) - if !s.mux.IsClose { - s.sendMsgCh <- NewMsg(s.connId, nil) + if s.readWait { + s.readCh <- struct{}{} } + s.waitQueue.Clear() + s.mux.connMap.Delete(s.connId) + if !s.mux.IsClose { + s.mux.sendInfo(MUX_CONN_CLOSE, s.connId, nil) + } + connPool.Put(s) return nil } diff --git a/lib/mux/map.go b/lib/mux/map.go index 99dfe25..0801201 100644 --- a/lib/mux/map.go +++ b/lib/mux/map.go @@ -44,6 +44,12 @@ func (s *connMap) Close() { s.closeCh <- struct{}{} } +func (s *connMap) Delete(id int32) { + s.Lock() + defer s.Unlock() + delete(s.connMap, id) +} + func (s *connMap) clean() { ticker := time.NewTimer(time.Minute * 1) for { diff --git a/lib/mux/mux.go b/lib/mux/mux.go index 6dab612..cfbac6d 100644 --- a/lib/mux/mux.go +++ b/lib/mux/mux.go @@ -22,38 +22,35 @@ const ( MUX_PING MUX_CONN_CLOSE MUX_PING_RETURN + MUX_STOP_WRITE RETRY_TIME = 2 //Heart beat allowed fault tolerance times ) type Mux struct { net.Listener - conn net.Conn - connMap *connMap - sendMsgCh chan *msg //write msg chan - sendStatusCh chan int32 //write read ok chan - newConnCh chan *conn - id int32 - closeChan chan struct{} - IsClose bool - pingOk int + conn net.Conn + connMap *connMap + newConnCh chan *conn + id int32 + closeChan chan struct{} + IsClose bool + pingOk int + waitQueueSize int sync.Mutex } func NewMux(c net.Conn) *Mux { m := &Mux{ - conn: c, - connMap: NewConnMap(), - sendMsgCh: make(chan *msg), - sendStatusCh: make(chan int32), - id: 0, - closeChan: make(chan struct{}), - newConnCh: make(chan *conn), - IsClose: false, + conn: c, + connMap: NewConnMap(), + id: 0, + closeChan: make(chan struct{}), + newConnCh: make(chan *conn), + IsClose: false, + waitQueueSize: 10, //TODO :In order to be more efficient, this value can be dynamically generated according to the delay algorithm. } //read session by flag go m.readSession() - //write session - go m.writeSession() //ping go m.ping() return m @@ -63,7 +60,7 @@ func (s *Mux) NewConn() (*conn, error) { if s.IsClose { return nil, errors.New("the mux has closed") } - conn := NewConn(s.getId(), s, s.sendMsgCh, s.sendStatusCh) + conn := NewConn(s.getId(), s) raw := bytes.NewBuffer([]byte{}) if err := binary.Write(raw, binary.LittleEndian, MUX_NEW_CONN); err != nil { return nil, err @@ -76,10 +73,14 @@ func (s *Mux) NewConn() (*conn, error) { if _, err := s.conn.Write(raw.Bytes()); err != nil { return nil, err } + //set a timer timeout 30 second + timer := time.NewTimer(time.Second * 30) + defer timer.Stop() select { case <-conn.connStatusOkCh: return conn, nil case <-conn.connStatusFailCh: + case <-timer.C: } return nil, errors.New("create connection fail,the server refused the connection") } @@ -95,10 +96,24 @@ func (s *Mux) Addr() net.Addr { return s.conn.LocalAddr() } +func (s *Mux) sendInfo(flag int32, id int32, content []byte) error { + raw := bytes.NewBuffer([]byte{}) + binary.Write(raw, binary.LittleEndian, flag) + binary.Write(raw, binary.LittleEndian, id) + if content != nil && len(content) > 0 { + binary.Write(raw, binary.LittleEndian, int32(len(content))) + binary.Write(raw, binary.LittleEndian, content) + } + if _, err := s.conn.Write(raw.Bytes()); err != nil || s.pingOk > RETRY_TIME { + s.Close() + return err + } + return nil +} + func (s *Mux) ping() { go func() { ticker := time.NewTicker(time.Second * 5) - raw := bytes.NewBuffer([]byte{}) for { select { case <-ticker.C: @@ -107,11 +122,7 @@ func (s *Mux) ping() { if (math.MaxInt32 - s.id) < 10000 { s.id = 0 } - raw.Reset() - binary.Write(raw, binary.LittleEndian, MUX_PING_FLAG) - binary.Write(raw, binary.LittleEndian, MUX_PING) - if _, err := s.conn.Write(raw.Bytes()); err != nil || s.pingOk > RETRY_TIME { - s.Close() + if err := s.sendInfo(MUX_PING_FLAG, MUX_PING, nil); err != nil || s.pingOk > RETRY_TIME { break } s.pingOk += 1 @@ -122,45 +133,9 @@ func (s *Mux) ping() { } } -func (s *Mux) writeSession() { - go func() { - raw := bytes.NewBuffer([]byte{}) - for { - raw.Reset() - select { - case msg := <-s.sendMsgCh: - if msg == nil { - break - } - if msg.content == nil { //close - binary.Write(raw, binary.LittleEndian, MUX_CONN_CLOSE) - binary.Write(raw, binary.LittleEndian, msg.connId) - break - } - binary.Write(raw, binary.LittleEndian, MUX_NEW_MSG) - binary.Write(raw, binary.LittleEndian, msg.connId) - binary.Write(raw, binary.LittleEndian, int32(len(msg.content))) - binary.Write(raw, binary.LittleEndian, msg.content) - case connId := <-s.sendStatusCh: - binary.Write(raw, binary.LittleEndian, MUX_MSG_SEND_OK) - binary.Write(raw, binary.LittleEndian, connId) - } - if _, err := s.conn.Write(raw.Bytes()); err != nil { - s.Close() - break - } - } - }() - select { - case <-s.closeChan: - } -} - func (s *Mux) readSession() { + var buf []byte go func() { - raw := bytes.NewBuffer([]byte{}) - buf := pool.BufPoolCopy.Get().([]byte) - defer pool.PutBufPoolCopy(buf) for { var flag, i int32 var n int @@ -171,24 +146,19 @@ func (s *Mux) readSession() { } switch flag { case MUX_NEW_CONN: //new conn - conn := NewConn(i, s, s.sendMsgCh, s.sendStatusCh) + conn := NewConn(i, s) s.connMap.Set(i, conn) //it has been set before send ok s.newConnCh <- conn - raw.Reset() - binary.Write(raw, binary.LittleEndian, MUX_NEW_CONN_OK) - binary.Write(raw, binary.LittleEndian, i) - s.conn.Write(raw.Bytes()) + s.sendInfo(MUX_NEW_CONN_OK, i, nil) continue case MUX_PING_FLAG: //ping - raw.Reset() - binary.Write(raw, binary.LittleEndian, MUX_PING_RETURN) - binary.Write(raw, binary.LittleEndian, MUX_PING) - s.conn.Write(raw.Bytes()) + s.sendInfo(MUX_PING_RETURN, MUX_PING, nil) continue case MUX_PING_RETURN: s.pingOk -= 1 continue case MUX_NEW_MSG: + buf = pool.GetBufPoolCopy() if n, err = ReadLenBytes(buf, s.conn); err != nil { break } @@ -196,20 +166,36 @@ func (s *Mux) readSession() { if conn, ok := s.connMap.Get(i); ok && !conn.isClose { switch flag { case MUX_NEW_MSG: //new msg from remote conn - copy(conn.readBuffer, buf[:n]) - conn.endRead = n + //insert wait queue + conn.waitQueue.Push(NewBufNode(buf, n)) + //judge len if >xxx ,send stop if conn.readWait { + conn.readWait = false conn.readCh <- struct{}{} } + if conn.waitQueue.Size() > s.waitQueueSize { + s.sendInfo(MUX_STOP_WRITE, conn.connId, nil) + } + case MUX_STOP_WRITE: + conn.stopWrite = true case MUX_MSG_SEND_OK: //the remote has read - conn.getStatusCh <- struct{}{} + if conn.stopWrite { + conn.stopWrite = false + select { + case conn.getStatusCh <- struct{}{}: + default: + } + } case MUX_NEW_CONN_OK: //conn ok conn.connStatusOkCh <- struct{}{} case MUX_NEW_CONN_Fail: conn.connStatusFailCh <- struct{}{} case MUX_CONN_CLOSE: //close the connection - conn.Close() + go conn.Close() + s.connMap.Delete(i) } + } else if flag == MUX_NEW_MSG { + pool.PutBufPoolCopy(buf) } } else { break @@ -231,9 +217,6 @@ func (s *Mux) Close() error { s.closeChan <- struct{}{} s.closeChan <- struct{}{} s.closeChan <- struct{}{} - close(s.closeChan) - close(s.sendMsgCh) - close(s.sendStatusCh) return s.conn.Close() } diff --git a/lib/mux/mux_test.go b/lib/mux/mux_test.go index aa97ae0..651224e 100644 --- a/lib/mux/mux_test.go +++ b/lib/mux/mux_test.go @@ -2,7 +2,7 @@ package mux import ( "github.com/cnlh/nps/lib/common" - conn3 "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/pool" "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" "log" "net" @@ -17,7 +17,7 @@ var conn2 net.Conn func TestNewMux(t *testing.T) { go func() { - http.ListenAndServe("0.0.0.0:8899", nil) + http.ListenAndServe("0.0.0.0:8889", nil) }() logs.EnableFuncCallDepth(true) logs.SetLogFuncCallDepth(3) @@ -32,12 +32,12 @@ func TestNewMux(t *testing.T) { log.Fatalln(err) } go func(c net.Conn) { - c2, err := net.Dial("tcp", "127.0.0.1:8080") + c2, err := net.Dial("tcp", "10.1.50.196:4000") if err != nil { log.Fatalln(err) } - go common.CopyBuffer(c2, conn3.NewCryptConn(c, true, nil)) - common.CopyBuffer(conn3.NewCryptConn(c, true, nil), c2) + go common.CopyBuffer(c2, c) + common.CopyBuffer(c, c2) c.Close() c2.Close() }(c) @@ -60,8 +60,8 @@ func TestNewMux(t *testing.T) { if err != nil { log.Fatalln(err) } - go common.CopyBuffer(conn3.NewCryptConn(tmpCpnn, true, nil), conn) - common.CopyBuffer(conn, conn3.NewCryptConn(tmpCpnn, true, nil)) + go common.CopyBuffer(tmpCpnn, conn) + common.CopyBuffer(conn, tmpCpnn) conn.Close() tmpCpnn.Close() }(conn) @@ -95,3 +95,15 @@ func client() { log.Fatalln(err) } } + +func TestNewConn(t *testing.T) { + buf := pool.GetBufPoolCopy() + logs.Warn(len(buf), cap(buf)) + //b := pool.GetBufPoolCopy() + //b[0] = 1 + //b[1] = 2 + //b[2] = 3 + b := []byte{1, 2, 3} + logs.Warn(copy(buf[:3], b), len(buf), cap(buf)) + logs.Warn(len(buf), buf[0]) +} diff --git a/lib/mux/pmux.go b/lib/mux/pmux.go index 340497e..498fd84 100644 --- a/lib/mux/pmux.go +++ b/lib/mux/pmux.go @@ -5,10 +5,12 @@ package mux import ( "bufio" "bytes" + "github.com/cnlh/nps/lib/common" "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" "github.com/pkg/errors" "io" "net" + "os" "strconv" "strings" "time" @@ -59,7 +61,8 @@ func (pMux *PortMux) Start() error { } pMux.Listener, err = net.ListenTCP("tcp", tcpAddr) if err != nil { - return err + logs.Error(err) + os.Exit(0) } go func() { for { @@ -105,7 +108,7 @@ func (pMux *PortMux) process(conn net.Conn) { str = strings.Replace(str, "host:", "", -1) str = strings.TrimSpace(str) // Determine whether it is the same as the manager domain name - if str == pMux.managerHost { + if common.GetIpByAddr(str) == pMux.managerHost { ch = pMux.managerConn } else { ch = pMux.httpConn diff --git a/lib/mux/pmux_test.go b/lib/mux/pmux_test.go index 01a9b6c..641ae2a 100644 --- a/lib/mux/pmux_test.go +++ b/lib/mux/pmux_test.go @@ -11,7 +11,7 @@ func TestPortMux_Close(t *testing.T) { logs.EnableFuncCallDepth(true) logs.SetLogFuncCallDepth(3) - pMux := NewPortMux(8888) + pMux := NewPortMux(8888,"Ds") go func() { if pMux.Start() != nil { logs.Warn("Error") @@ -19,21 +19,21 @@ func TestPortMux_Close(t *testing.T) { }() time.Sleep(time.Second * 3) go func() { - l := pMux.GetHttpsAccept() + l := pMux.GetHttpListener() conn, err := l.Accept() logs.Warn(conn, err) }() go func() { - l := pMux.GetHttpAccept() + l := pMux.GetHttpListener() conn, err := l.Accept() logs.Warn(conn, err) }() go func() { - l := pMux.GetClientAccept() + l := pMux.GetHttpListener() conn, err := l.Accept() logs.Warn(conn, err) }() - l := pMux.GetManagerAccept() + l := pMux.GetHttpListener() conn, err := l.Accept() logs.Warn(conn, err) } diff --git a/lib/mux/queue.go b/lib/mux/queue.go new file mode 100644 index 0000000..f03bafd --- /dev/null +++ b/lib/mux/queue.go @@ -0,0 +1,82 @@ +package mux + +import ( + "errors" + "github.com/cnlh/nps/lib/pool" + "sync" +) + +type Element *bufNode + +type bufNode struct { + val []byte //buf value + l int //length +} + +func NewBufNode(buf []byte, l int) *bufNode { + return &bufNode{ + val: buf, + l: l, + } +} + +type Queue interface { + Push(e Element) //向队列中添加元素 + Pop() Element //移除队列中最前面的元素 + Clear() bool //清空队列 + Size() int //获取队列的元素个数 + IsEmpty() bool //判断队列是否是空 +} + +type sliceEntry struct { + element []Element + sync.Mutex +} + +func NewQueue() *sliceEntry { + return &sliceEntry{} +} + +//向队列中添加元素 +func (entry *sliceEntry) Push(e Element) { + entry.Lock() + defer entry.Unlock() + entry.element = append(entry.element, e) +} + +//移除队列中最前面的额元素 +func (entry *sliceEntry) Pop() (Element, error) { + if entry.IsEmpty() { + return nil, errors.New("queue is empty!") + } + entry.Lock() + defer entry.Unlock() + firstElement := entry.element[0] + entry.element = entry.element[1:] + return firstElement, nil +} + +func (entry *sliceEntry) Clear() bool { + entry.Lock() + defer entry.Unlock() + if entry.IsEmpty() { + return false + } + for i := 0; i < entry.Size(); i++ { + pool.PutBufPoolCopy(entry.element[i].val) + entry.element[i] = nil + } + entry.element = nil + return true +} + +func (entry *sliceEntry) Size() int { + return len(entry.element) +} + +func (entry *sliceEntry) IsEmpty() bool { + if len(entry.element) == 0 { + return true + } + return false +} diff --git a/lib/pool/pool.go b/lib/pool/pool.go index f1e58b1..997c836 100644 --- a/lib/pool/pool.go +++ b/lib/pool/pool.go @@ -32,10 +32,10 @@ var BufPoolSmall = sync.Pool{ } var BufPoolCopy = sync.Pool{ New: func() interface{} { - return make([]byte, PoolSizeCopy) + buf := make([]byte, PoolSizeCopy) + return &buf }, } - func PutBufPoolUdp(buf []byte) { if cap(buf) == PoolSizeUdp { BufPoolUdp.Put(buf[:PoolSizeUdp]) @@ -44,10 +44,14 @@ func PutBufPoolUdp(buf []byte) { func PutBufPoolCopy(buf []byte) { if cap(buf) == PoolSizeCopy { - BufPoolCopy.Put(buf[:PoolSizeCopy]) + BufPoolCopy.Put(&buf) } } +func GetBufPoolCopy() ([]byte) { + return (*BufPoolCopy.Get().(*[]byte))[:PoolSizeCopy] +} + func PutBufPoolSmall(buf []byte) { if cap(buf) == PoolSizeSmall { BufPoolSmall.Put(buf[:PoolSizeSmall]) diff --git a/server/proxy/base.go b/server/proxy/base.go index 1eb5949..6ec40ee 100644 --- a/server/proxy/base.go +++ b/server/proxy/base.go @@ -81,9 +81,10 @@ func (s *BaseServer) DealClient(c *conn.Conn, addr string, rb []byte, tp string) return err } else { if rb != nil { - target.Write(rb) + //HTTP proxy crypt or compress + conn.GetConn(target, link.Crypt, link.Compress, s.task.Client.Rate, true).Write(rb) } - conn.CopyWaitGroup(target, c.Conn, link.Crypt, link.Compress, s.task.Client.Rate, s.task.Client.Flow, true) + conn.CopyWaitGroup(target, c.Conn, link.Crypt, link.Compress, s.task.Client.Rate, s.task.Flow, true) } s.task.Client.AddConn() diff --git a/server/proxy/http.go b/server/proxy/http.go index 54be9ca..6c78283 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -134,6 +134,9 @@ func (s *httpServer) process(c *conn.Conn, r *http.Request) { err error connClient io.ReadWriteCloser scheme = r.URL.Scheme + lk *conn.Link + targetAddr string + wg sync.WaitGroup ) if host, err = file.GetCsvDb().GetInfoByHost(r.Host, r); err != nil { logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) @@ -159,7 +162,11 @@ func (s *httpServer) process(c *conn.Conn, r *http.Request) { logs.Warn("auth error", err, r.RemoteAddr) break } - lk := conn.NewLink(common.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr) + if targetAddr, err = host.GetRandomTarget(); err != nil { + logs.Warn(err.Error()) + break + } + lk = conn.NewLink(common.CONN_TCP, targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr) if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, c.Conn.RemoteAddr().String(), nil); err != nil { logs.Notice("connect to target %s error %s", lk.Host, err) break @@ -167,10 +174,12 @@ func (s *httpServer) process(c *conn.Conn, r *http.Request) { connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) isConn = false go func() { + wg.Add(1) w, _ := common.CopyBuffer(c, connClient) host.Flow.Add(0, w) c.Close() target.Close() + wg.Done() }() } else { r, err = http.ReadRequest(bufio.NewReader(c)) @@ -197,7 +206,6 @@ func (s *httpServer) process(c *conn.Conn, r *http.Request) { host = hostTmp lastHost = host isConn = true - goto start } } @@ -208,7 +216,7 @@ func (s *httpServer) process(c *conn.Conn, r *http.Request) { break } host.Flow.Add(int64(len(b)), 0) - logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.RequestURI, r.RemoteAddr, host.Target) + logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.RequestURI, r.RemoteAddr, lk.Host) //write connClient.Write(b) } @@ -220,6 +228,7 @@ end: if target != nil { target.Close() } + wg.Wait() if host != nil { host.Client.AddConn() } diff --git a/server/proxy/socks5.go b/server/proxy/socks5.go index 3a2b0c2..524421f 100755 --- a/server/proxy/socks5.go +++ b/server/proxy/socks5.go @@ -149,7 +149,7 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) { return } else { s.sendReply(c, succeeded) - conn.CopyWaitGroup(target, c, link.Crypt, link.Compress, s.task.Client.Rate, s.task.Client.Flow, true) + conn.CopyWaitGroup(target, c, link.Crypt, link.Compress, s.task.Client.Rate, s.task.Flow, true) } s.task.Client.AddConn() diff --git a/server/proxy/tcp.go b/server/proxy/tcp.go index 8efb074..f27deef 100755 --- a/server/proxy/tcp.go +++ b/server/proxy/tcp.go @@ -51,7 +51,7 @@ func (s *TunnelModeServer) Start() error { c.Close() } if s.task.Client.GetConn() { - logs.Trace("New tcp connection,client %d,remote address %s", s.task.Client.Id, c.RemoteAddr()) + logs.Trace("New tcp connection,local port %d,client %d,remote address %s", s.task.Port, s.task.Client.Id, c.RemoteAddr()) go s.process(conn.NewConn(c), s) } else { logs.Info("Connections exceed the current client %d limit", s.task.Client.Id) @@ -109,7 +109,13 @@ type process func(c *conn.Conn, s *TunnelModeServer) error //tcp隧道模式 func ProcessTunnel(c *conn.Conn, s *TunnelModeServer) error { - return s.DealClient(c, s.task.GetRandomTarget(), nil, common.CONN_TCP) + targetAddr, err := s.task.GetRandomTarget() + if err != nil { + c.Close() + logs.Warn("tcp port %d ,client id %d,task id %d connect error %s", s.task.Port, s.task.Client.Id, s.task.Id, err.Error()) + return err + } + return s.DealClient(c, targetAddr, nil, common.CONN_TCP) } //http代理模式 diff --git a/server/proxy/udp.go b/server/proxy/udp.go index 335f35d..4995928 100755 --- a/server/proxy/udp.go +++ b/server/proxy/udp.go @@ -57,6 +57,7 @@ func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) { buf := pool.BufPoolUdp.Get().([]byte) defer pool.BufPoolUdp.Put(buf) target.Write(data) + s.task.Flow.Add(int64(len(data)), 0) if n, err := target.Read(buf); err != nil { logs.Warn(err) return diff --git a/server/server.go b/server/server.go index 74284f0..f6038d9 100644 --- a/server/server.go +++ b/server/server.go @@ -52,9 +52,13 @@ func DealBridgeTask() { select { case t := <-Bridge.OpenTask: AddTask(t) + case t := <-Bridge.CloseTask: + StopServer(t.Id) case id := <-Bridge.CloseClient: DelTunnelAndHostByClientId(id) file.GetCsvDb().DelClient(id) + case tunnel := <-Bridge.OpenTask: + StartTask(tunnel.Id) case s := <-Bridge.SecretChan: logs.Trace("New secret connection, addr", s.Conn.Conn.RemoteAddr()) if t := file.GetCsvDb().GetTaskByMd5Password(s.Password); t != nil { @@ -202,6 +206,8 @@ func DelTask(id int) error { func GetTunnel(start, length int, typeVal string, clientId int) ([]*file.Tunnel, int) { list := make([]*file.Tunnel, 0) var cnt int + file.GetCsvDb().Lock() + defer file.GetCsvDb().Unlock() for _, v := range file.GetCsvDb().Tasks { if (typeVal != "" && v.Mode != typeVal) || (typeVal == "" && clientId != v.Client.Id) { continue @@ -234,6 +240,8 @@ func GetClientList(start, length int) (list []*file.Client, cnt int) { } func dealClientData(list []*file.Client) { + file.GetCsvDb().Lock() + defer file.GetCsvDb().Unlock() for _, v := range list { if _, ok := Bridge.Client[v.Id]; ok { v.IsConnect = true @@ -261,19 +269,27 @@ func dealClientData(list []*file.Client) { //根据客户端id删除其所属的所有隧道和域名 func DelTunnelAndHostByClientId(clientId int) { var ids []int + file.GetCsvDb().Lock() for _, v := range file.GetCsvDb().Tasks { if v.Client.Id == clientId { ids = append(ids, v.Id) } } + file.GetCsvDb().Unlock() for _, id := range ids { DelTask(id) } + ids = ids[:0] + file.GetCsvDb().Lock() for _, v := range file.GetCsvDb().Hosts { if v.Client.Id == clientId { - file.GetCsvDb().DelHost(v.Id) + ids = append(ids, v.Id) } } + file.GetCsvDb().Unlock() + for _, id := range ids { + file.GetCsvDb().DelHost(id) + } } //关闭客户端连接 @@ -300,6 +316,8 @@ func GetDashboardData() map[string]interface{} { data["inletFlowCount"] = int(in) data["exportFlowCount"] = int(out) var tcp, udp, secret, socks5, p2p, http int + file.GetCsvDb().Lock() + defer file.GetCsvDb().Unlock() for _, v := range file.GetCsvDb().Tasks { switch v.Mode { case "tcp": @@ -366,7 +384,6 @@ func GetDashboardData() map[string]interface{} { data["sys"+strconv.Itoa(i+1)] = serverStatus[i*fg] } } - return data } diff --git a/web/controllers/base.go b/web/controllers/base.go index a52ae28..b6045f3 100755 --- a/web/controllers/base.go +++ b/web/controllers/base.go @@ -46,10 +46,7 @@ func (s *BaseController) display(tpl ...string) { tplname = s.controllerName + "/" + s.actionName + ".html" } ip := s.Ctx.Request.Host - if strings.LastIndex(ip, ":") > 0 { - arr := strings.Split(common.GetHostByName(ip), ":") - s.Data["ip"] = arr[0] - } + s.Data["ip"] = common.GetIpByAddr(ip) s.Data["bridgeType"] = beego.AppConfig.String("bridge_type") if common.IsWindows() { s.Data["win"] = ".exe" diff --git a/web/views/index/index.html b/web/views/index/index.html index b564a98..fac3607 100755 --- a/web/views/index/index.html +++ b/web/views/index/index.html @@ -75,7 +75,7 @@
  • - httpProxyPort + http proxy port
    {{.data.httpProxyPort}} @@ -85,7 +85,7 @@
  • - httpsProxyPort + https proxy port
    {{.data.httpsProxyPort}} @@ -95,7 +95,7 @@
  • - ipLimit + ip limit
    {{.data.ipLimit}} @@ -105,7 +105,7 @@
  • - flowStoreInterval + flow store interval
    {{.data.flowStoreInterval}} @@ -115,7 +115,7 @@
  • - logLevel + log level
    {{.data.logLevel}} @@ -125,7 +125,7 @@
  • - p2pPort + p2p port
    {{.data.p2pPort}} @@ -135,7 +135,7 @@
  • - serverIp + server ip
    {{.data.serverIp}}