mirror of https://github.com/fatedier/frp
fatedier
8 years ago
115 changed files with 5515 additions and 3491 deletions
Before Width: | Height: | Size: 4.2 KiB After Width: | Height: | Size: 4.2 KiB |
Before Width: | Height: | Size: 1.8 KiB After Width: | Height: | Size: 1.8 KiB |
File diff suppressed because one or more lines are too long
@ -0,0 +1,345 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"runtime" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/models/msg" |
||||
"github.com/fatedier/frp/utils/log" |
||||
"github.com/fatedier/frp/utils/net" |
||||
"github.com/fatedier/frp/utils/util" |
||||
"github.com/fatedier/frp/utils/version" |
||||
) |
||||
|
||||
type Control struct { |
||||
// frpc service
|
||||
svr *Service |
||||
|
||||
// login message to server
|
||||
loginMsg *msg.Login |
||||
|
||||
// proxy configures
|
||||
pxyCfgs map[string]config.ProxyConf |
||||
|
||||
// proxies
|
||||
proxies map[string]Proxy |
||||
|
||||
// control connection
|
||||
conn net.Conn |
||||
|
||||
// put a message in this channel to send it over control connection to server
|
||||
sendCh chan (msg.Message) |
||||
|
||||
// read from this channel to get the next message sent by server
|
||||
readCh chan (msg.Message) |
||||
|
||||
// run id got from server
|
||||
runId string |
||||
|
||||
// connection or other error happens , control will try to reconnect to server
|
||||
closed int32 |
||||
|
||||
// goroutines can block by reading from this channel, it will be closed only in reader() when control connection is closed
|
||||
closedCh chan int |
||||
|
||||
// last time got the Pong message
|
||||
lastPong time.Time |
||||
|
||||
mu sync.RWMutex |
||||
|
||||
log.Logger |
||||
} |
||||
|
||||
func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf) *Control { |
||||
loginMsg := &msg.Login{ |
||||
Arch: runtime.GOARCH, |
||||
Os: runtime.GOOS, |
||||
PoolCount: config.ClientCommonCfg.PoolCount, |
||||
User: config.ClientCommonCfg.User, |
||||
Version: version.Full(), |
||||
} |
||||
return &Control{ |
||||
svr: svr, |
||||
loginMsg: loginMsg, |
||||
pxyCfgs: pxyCfgs, |
||||
proxies: make(map[string]Proxy), |
||||
sendCh: make(chan msg.Message, 10), |
||||
readCh: make(chan msg.Message, 10), |
||||
closedCh: make(chan int), |
||||
Logger: log.NewPrefixLogger(""), |
||||
} |
||||
} |
||||
|
||||
// 1. login
|
||||
// 2. start reader() writer() manager()
|
||||
// 3. connection closed
|
||||
// 4. In reader(): close closedCh and exit, controler() get it
|
||||
// 5. In controler(): close readCh and sendCh, manager() and writer() will exit
|
||||
// 6. In controler(): ini readCh, sendCh, closedCh
|
||||
// 7. In controler(): start new reader(), writer(), manager()
|
||||
// controler() will keep running
|
||||
func (ctl *Control) Run() error { |
||||
err := ctl.login() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
go ctl.controler() |
||||
go ctl.manager() |
||||
go ctl.writer() |
||||
go ctl.reader() |
||||
|
||||
// send NewProxy message for all configured proxies
|
||||
for _, cfg := range ctl.pxyCfgs { |
||||
var newProxyMsg msg.NewProxy |
||||
cfg.UnMarshalToMsg(&newProxyMsg) |
||||
ctl.sendCh <- &newProxyMsg |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (ctl *Control) NewWorkConn() { |
||||
workConn, err := net.ConnectTcpServerByHttpProxy(config.ClientCommonCfg.HttpProxy, |
||||
fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) |
||||
if err != nil { |
||||
ctl.Warn("start new work connection error: %v", err) |
||||
return |
||||
} |
||||
|
||||
m := &msg.NewWorkConn{ |
||||
RunId: ctl.runId, |
||||
} |
||||
if err = msg.WriteMsg(workConn, m); err != nil { |
||||
ctl.Warn("work connection write to server error: %v", err) |
||||
workConn.Close() |
||||
return |
||||
} |
||||
|
||||
var startMsg msg.StartWorkConn |
||||
if err = msg.ReadMsgInto(workConn, &startMsg); err != nil { |
||||
ctl.Error("work connection closed and no response from server, %v", err) |
||||
workConn.Close() |
||||
return |
||||
} |
||||
workConn.AddLogPrefix(startMsg.ProxyName) |
||||
|
||||
// dispatch this work connection to related proxy
|
||||
if pxy, ok := ctl.proxies[startMsg.ProxyName]; ok { |
||||
go pxy.InWorkConn(workConn) |
||||
workConn.Info("start a new work connection") |
||||
} |
||||
} |
||||
|
||||
func (ctl *Control) init() { |
||||
ctl.sendCh = make(chan msg.Message, 10) |
||||
ctl.readCh = make(chan msg.Message, 10) |
||||
ctl.closedCh = make(chan int) |
||||
} |
||||
|
||||
// login send a login message to server and wait for a loginResp message.
|
||||
func (ctl *Control) login() (err error) { |
||||
if ctl.conn != nil { |
||||
ctl.conn.Close() |
||||
} |
||||
conn, err := net.ConnectTcpServerByHttpProxy(config.ClientCommonCfg.HttpProxy, |
||||
fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
now := time.Now().Unix() |
||||
ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now) |
||||
ctl.loginMsg.Timestamp = now |
||||
ctl.loginMsg.RunId = ctl.runId |
||||
|
||||
if err = msg.WriteMsg(conn, ctl.loginMsg); err != nil { |
||||
return err |
||||
} |
||||
|
||||
var loginRespMsg msg.LoginResp |
||||
if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if loginRespMsg.Error != "" { |
||||
err = fmt.Errorf("%s", loginRespMsg.Error) |
||||
ctl.Error("%s", loginRespMsg.Error) |
||||
return err |
||||
} |
||||
|
||||
ctl.conn = conn |
||||
// update runId got from server
|
||||
ctl.runId = loginRespMsg.RunId |
||||
ctl.ClearLogPrefix() |
||||
ctl.AddLogPrefix(loginRespMsg.RunId) |
||||
ctl.Info("login to server success, get run id [%s]", loginRespMsg.RunId) |
||||
|
||||
// login success, so we let closedCh available again
|
||||
ctl.closedCh = make(chan int) |
||||
ctl.lastPong = time.Now() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (ctl *Control) reader() { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
for { |
||||
if m, err := msg.ReadMsg(ctl.conn); err != nil { |
||||
if err == io.EOF { |
||||
ctl.Debug("read from control connection EOF") |
||||
close(ctl.closedCh) |
||||
return |
||||
} else { |
||||
ctl.Warn("read error: %v", err) |
||||
continue |
||||
} |
||||
} else { |
||||
ctl.readCh <- m |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ctl *Control) writer() { |
||||
for { |
||||
if m, ok := <-ctl.sendCh; !ok { |
||||
ctl.Info("control writer is closing") |
||||
return |
||||
} else { |
||||
if err := msg.WriteMsg(ctl.conn, m); err != nil { |
||||
ctl.Warn("write message to control connection error: %v", err) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ctl *Control) manager() { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
hbSend := time.NewTicker(time.Duration(config.ClientCommonCfg.HeartBeatInterval) * time.Second) |
||||
defer hbSend.Stop() |
||||
hbCheck := time.NewTicker(time.Second) |
||||
defer hbCheck.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-hbSend.C: |
||||
// send heartbeat to server
|
||||
ctl.sendCh <- &msg.Ping{} |
||||
case <-hbCheck.C: |
||||
if time.Since(ctl.lastPong) > time.Duration(config.ClientCommonCfg.HeartBeatTimeout)*time.Second { |
||||
ctl.Warn("heartbeat timeout") |
||||
return |
||||
} |
||||
case rawMsg, ok := <-ctl.readCh: |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
switch m := rawMsg.(type) { |
||||
case *msg.ReqWorkConn: |
||||
go ctl.NewWorkConn() |
||||
case *msg.NewProxyResp: |
||||
// Server will return NewProxyResp message to each NewProxy message.
|
||||
// Start a new proxy handler if no error got
|
||||
if m.Error != "" { |
||||
ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error) |
||||
continue |
||||
} |
||||
oldPxy, ok := ctl.proxies[m.ProxyName] |
||||
if ok { |
||||
oldPxy.Close() |
||||
} |
||||
cfg, ok := ctl.pxyCfgs[m.ProxyName] |
||||
if !ok { |
||||
// it will never go to this branch
|
||||
ctl.Warn("[%s] no proxy conf found", m.ProxyName) |
||||
continue |
||||
} |
||||
pxy := NewProxy(ctl, cfg) |
||||
pxy.Run() |
||||
ctl.proxies[m.ProxyName] = pxy |
||||
ctl.Info("[%s] start proxy success", m.ProxyName) |
||||
case *msg.Pong: |
||||
ctl.lastPong = time.Now() |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// control keep watching closedCh, start a new connection if previous control connection is closed
|
||||
func (ctl *Control) controler() { |
||||
var err error |
||||
maxDelayTime := 30 * time.Second |
||||
delayTime := time.Second |
||||
for { |
||||
// we won't get any variable from this channel
|
||||
_, ok := <-ctl.closedCh |
||||
if !ok { |
||||
// close related channels
|
||||
close(ctl.readCh) |
||||
close(ctl.sendCh) |
||||
time.Sleep(time.Second) |
||||
|
||||
// loop util reconnect to server success
|
||||
for { |
||||
ctl.Info("try to reconnect to server...") |
||||
err = ctl.login() |
||||
if err != nil { |
||||
ctl.Warn("reconnect to server error: %v", err) |
||||
time.Sleep(delayTime) |
||||
delayTime = delayTime * 2 |
||||
if delayTime > maxDelayTime { |
||||
delayTime = maxDelayTime |
||||
} |
||||
continue |
||||
} |
||||
// reconnect success, init the delayTime
|
||||
delayTime = time.Second |
||||
break |
||||
} |
||||
|
||||
// init related channels and variables
|
||||
ctl.init() |
||||
|
||||
// previous work goroutines should be closed and start them here
|
||||
go ctl.manager() |
||||
go ctl.writer() |
||||
go ctl.reader() |
||||
|
||||
// send NewProxy message for all configured proxies
|
||||
for _, cfg := range ctl.pxyCfgs { |
||||
var newProxyMsg msg.NewProxy |
||||
cfg.UnMarshalToMsg(&newProxyMsg) |
||||
ctl.sendCh <- &newProxyMsg |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,141 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/models/proto/tcp" |
||||
"github.com/fatedier/frp/utils/net" |
||||
) |
||||
|
||||
type Proxy interface { |
||||
Run() |
||||
InWorkConn(conn net.Conn) |
||||
Close() |
||||
} |
||||
|
||||
func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) { |
||||
switch cfg := pxyConf.(type) { |
||||
case *config.TcpProxyConf: |
||||
pxy = &TcpProxy{ |
||||
cfg: cfg, |
||||
ctl: ctl, |
||||
} |
||||
case *config.UdpProxyConf: |
||||
pxy = &UdpProxy{ |
||||
cfg: cfg, |
||||
ctl: ctl, |
||||
} |
||||
case *config.HttpProxyConf: |
||||
pxy = &HttpProxy{ |
||||
cfg: cfg, |
||||
ctl: ctl, |
||||
} |
||||
case *config.HttpsProxyConf: |
||||
pxy = &HttpsProxy{ |
||||
cfg: cfg, |
||||
ctl: ctl, |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// TCP
|
||||
type TcpProxy struct { |
||||
cfg *config.TcpProxyConf |
||||
ctl *Control |
||||
} |
||||
|
||||
func (pxy *TcpProxy) Run() { |
||||
} |
||||
|
||||
func (pxy *TcpProxy) Close() { |
||||
} |
||||
|
||||
func (pxy *TcpProxy) InWorkConn(conn net.Conn) { |
||||
defer conn.Close() |
||||
localConn, err := net.ConnectTcpServer(fmt.Sprintf("%s:%d", pxy.cfg.LocalIp, pxy.cfg.LocalPort)) |
||||
if err != nil { |
||||
conn.Error("connect to local service [%s:%d] error: %v", pxy.cfg.LocalIp, pxy.cfg.LocalPort, err) |
||||
return |
||||
} |
||||
|
||||
var remote io.ReadWriteCloser |
||||
remote = conn |
||||
if pxy.cfg.UseEncryption { |
||||
remote, err = tcp.WithEncryption(remote, []byte(config.ClientCommonCfg.PrivilegeToken)) |
||||
if err != nil { |
||||
conn.Error("create encryption stream error: %v", err) |
||||
return |
||||
} |
||||
} |
||||
if pxy.cfg.UseCompression { |
||||
remote = tcp.WithCompression(remote) |
||||
} |
||||
conn.Debug("join connections") |
||||
tcp.Join(localConn, remote) |
||||
conn.Debug("join connections closed") |
||||
} |
||||
|
||||
// UDP
|
||||
type UdpProxy struct { |
||||
cfg *config.UdpProxyConf |
||||
ctl *Control |
||||
} |
||||
|
||||
func (pxy *UdpProxy) Run() { |
||||
} |
||||
|
||||
func (pxy *UdpProxy) Close() { |
||||
} |
||||
|
||||
func (pxy *UdpProxy) InWorkConn(conn net.Conn) { |
||||
defer conn.Close() |
||||
} |
||||
|
||||
// HTTP
|
||||
type HttpProxy struct { |
||||
cfg *config.HttpProxyConf |
||||
ctl *Control |
||||
} |
||||
|
||||
func (pxy *HttpProxy) Run() { |
||||
} |
||||
|
||||
func (pxy *HttpProxy) Close() { |
||||
} |
||||
|
||||
func (pxy *HttpProxy) InWorkConn(conn net.Conn) { |
||||
defer conn.Close() |
||||
} |
||||
|
||||
// HTTPS
|
||||
type HttpsProxy struct { |
||||
cfg *config.HttpsProxyConf |
||||
ctl *Control |
||||
} |
||||
|
||||
func (pxy *HttpsProxy) Run() { |
||||
} |
||||
|
||||
func (pxy *HttpsProxy) Close() { |
||||
} |
||||
|
||||
func (pxy *HttpsProxy) InWorkConn(conn net.Conn) { |
||||
defer conn.Close() |
||||
} |
@ -0,0 +1,43 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import "github.com/fatedier/frp/models/config" |
||||
|
||||
type Service struct { |
||||
// manager control connection with server
|
||||
ctl *Control |
||||
|
||||
closedCh chan int |
||||
} |
||||
|
||||
func NewService(pxyCfgs map[string]config.ProxyConf) (svr *Service) { |
||||
svr = &Service{ |
||||
closedCh: make(chan int), |
||||
} |
||||
ctl := NewControl(svr, pxyCfgs) |
||||
svr.ctl = ctl |
||||
return |
||||
} |
||||
|
||||
func (svr *Service) Run() error { |
||||
err := svr.ctl.Run() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
<-svr.closedCh |
||||
return nil |
||||
} |
@ -0,0 +1,117 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
docopt "github.com/docopt/docopt-go" |
||||
ini "github.com/vaughan0/go-ini" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/server" |
||||
"github.com/fatedier/frp/utils/log" |
||||
"github.com/fatedier/frp/utils/version" |
||||
) |
||||
|
||||
var usage string = `frps is the server of frp |
||||
|
||||
Usage:
|
||||
frps [-c config_file] [-L log_file] [--log-level=<log_level>] [--addr=<bind_addr>] |
||||
frps -h | --help |
||||
frps -v | --version |
||||
|
||||
Options: |
||||
-c config_file set config file |
||||
-L log_file set output log file, including console |
||||
--log-level=<log_level> set log level: debug, info, warn, error |
||||
--addr=<bind_addr> listen addr for client, example: 0.0.0.0:7000 |
||||
-h --help show this screen |
||||
-v --version show version |
||||
` |
||||
|
||||
func main() { |
||||
var err error |
||||
confFile := "./frps.ini" |
||||
// the configures parsed from file will be replaced by those from command line if exist
|
||||
args, err := docopt.Parse(usage, nil, true, version.Full(), false) |
||||
|
||||
if args["-c"] != nil { |
||||
confFile = args["-c"].(string) |
||||
} |
||||
|
||||
conf, err := ini.LoadFile(confFile) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
os.Exit(1) |
||||
} |
||||
config.ServerCommonCfg, err = config.LoadServerCommonConf(conf) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
os.Exit(1) |
||||
} |
||||
|
||||
if args["-L"] != nil { |
||||
if args["-L"].(string) == "console" { |
||||
config.ServerCommonCfg.LogWay = "console" |
||||
} else { |
||||
config.ServerCommonCfg.LogWay = "file" |
||||
config.ServerCommonCfg.LogFile = args["-L"].(string) |
||||
} |
||||
} |
||||
|
||||
if args["--log-level"] != nil { |
||||
config.ServerCommonCfg.LogLevel = args["--log-level"].(string) |
||||
} |
||||
|
||||
if args["--addr"] != nil { |
||||
addr := strings.Split(args["--addr"].(string), ":") |
||||
if len(addr) != 2 { |
||||
fmt.Println("--addr format error: example 0.0.0.0:7000") |
||||
os.Exit(1) |
||||
} |
||||
bindPort, err := strconv.ParseInt(addr[1], 10, 64) |
||||
if err != nil { |
||||
fmt.Println("--addr format error, example 0.0.0.0:7000") |
||||
os.Exit(1) |
||||
} |
||||
config.ServerCommonCfg.BindAddr = addr[0] |
||||
config.ServerCommonCfg.BindPort = bindPort |
||||
} |
||||
|
||||
if args["-v"] != nil { |
||||
if args["-v"].(bool) { |
||||
fmt.Println(version.Full()) |
||||
os.Exit(0) |
||||
} |
||||
} |
||||
|
||||
log.InitLog(config.ServerCommonCfg.LogWay, config.ServerCommonCfg.LogFile, |
||||
config.ServerCommonCfg.LogLevel, config.ServerCommonCfg.LogMaxDays) |
||||
|
||||
svr, err := server.NewService() |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
os.Exit(1) |
||||
} |
||||
log.Info("Start frps success") |
||||
if config.ServerCommonCfg.PrivilegeMode == true { |
||||
log.Info("PrivilegeMode is enabled, you should pay more attention to security issues") |
||||
} |
||||
svr.Run() |
||||
} |
@ -0,0 +1,160 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"strconv" |
||||
|
||||
ini "github.com/vaughan0/go-ini" |
||||
) |
||||
|
||||
var ClientCommonCfg *ClientCommonConf |
||||
|
||||
// client common config
|
||||
type ClientCommonConf struct { |
||||
ConfigFile string |
||||
ServerAddr string |
||||
ServerPort int64 |
||||
HttpProxy string |
||||
LogFile string |
||||
LogWay string |
||||
LogLevel string |
||||
LogMaxDays int64 |
||||
PrivilegeToken string |
||||
PoolCount int |
||||
User string |
||||
HeartBeatInterval int64 |
||||
HeartBeatTimeout int64 |
||||
} |
||||
|
||||
func GetDeaultClientCommonConf() *ClientCommonConf { |
||||
return &ClientCommonConf{ |
||||
ConfigFile: "./frpc.ini", |
||||
ServerAddr: "0.0.0.0", |
||||
ServerPort: 7000, |
||||
HttpProxy: "", |
||||
LogFile: "console", |
||||
LogWay: "console", |
||||
LogLevel: "info", |
||||
LogMaxDays: 3, |
||||
PrivilegeToken: "", |
||||
PoolCount: 1, |
||||
User: "", |
||||
HeartBeatInterval: 10, |
||||
HeartBeatTimeout: 30, |
||||
} |
||||
} |
||||
|
||||
func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { |
||||
var ( |
||||
tmpStr string |
||||
ok bool |
||||
v int64 |
||||
) |
||||
cfg = GetDeaultClientCommonConf() |
||||
|
||||
tmpStr, ok = conf.Get("common", "server_addr") |
||||
if ok { |
||||
cfg.ServerAddr = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "server_port") |
||||
if ok { |
||||
cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "http_proxy") |
||||
if ok { |
||||
cfg.HttpProxy = tmpStr |
||||
} else { |
||||
// get http_proxy from env
|
||||
cfg.HttpProxy = os.Getenv("http_proxy") |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_file") |
||||
if ok { |
||||
cfg.LogFile = tmpStr |
||||
if cfg.LogFile == "console" { |
||||
cfg.LogWay = "console" |
||||
} else { |
||||
cfg.LogWay = "file" |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_level") |
||||
if ok { |
||||
cfg.LogLevel = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_max_days") |
||||
if ok { |
||||
cfg.LogMaxDays, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "privilege_token") |
||||
if ok { |
||||
cfg.PrivilegeToken = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "pool_count") |
||||
if ok { |
||||
v, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
cfg.PoolCount = 1 |
||||
} else { |
||||
cfg.PoolCount = int(v) |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "user") |
||||
if ok { |
||||
cfg.User = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_timeout") |
||||
if ok { |
||||
v, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") |
||||
return |
||||
} else { |
||||
cfg.HeartBeatTimeout = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_interval") |
||||
if ok { |
||||
v, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") |
||||
return |
||||
} else { |
||||
cfg.HeartBeatInterval = v |
||||
} |
||||
} |
||||
|
||||
if cfg.HeartBeatInterval <= 0 { |
||||
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") |
||||
return |
||||
} |
||||
|
||||
if cfg.HeartBeatTimeout < cfg.HeartBeatInterval { |
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval") |
||||
return |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,446 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/fatedier/frp/models/consts" |
||||
"github.com/fatedier/frp/models/msg" |
||||
|
||||
ini "github.com/vaughan0/go-ini" |
||||
) |
||||
|
||||
type ProxyConf interface { |
||||
GetName() string |
||||
GetBaseInfo() *BaseProxyConf |
||||
LoadFromMsg(pMsg *msg.NewProxy) |
||||
LoadFromFile(name string, conf ini.Section) error |
||||
UnMarshalToMsg(pMsg *msg.NewProxy) |
||||
Check() error |
||||
} |
||||
|
||||
func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) { |
||||
if pMsg.ProxyType == "" { |
||||
pMsg.ProxyType = consts.TcpProxy |
||||
} |
||||
switch pMsg.ProxyType { |
||||
case consts.TcpProxy: |
||||
cfg = &TcpProxyConf{} |
||||
case consts.UdpProxy: |
||||
cfg = &UdpProxyConf{} |
||||
case consts.HttpProxy: |
||||
cfg = &HttpProxyConf{} |
||||
case consts.HttpsProxy: |
||||
cfg = &HttpsProxyConf{} |
||||
default: |
||||
err = fmt.Errorf("proxy [%s] type [%s] error", pMsg.ProxyName, pMsg.ProxyType) |
||||
return |
||||
} |
||||
cfg.LoadFromMsg(pMsg) |
||||
err = cfg.Check() |
||||
return |
||||
} |
||||
|
||||
func NewProxyConfFromFile(name string, section ini.Section) (cfg ProxyConf, err error) { |
||||
proxyType := section["type"] |
||||
if proxyType == "" { |
||||
proxyType = consts.TcpProxy |
||||
section["type"] = consts.TcpProxy |
||||
} |
||||
switch proxyType { |
||||
case consts.TcpProxy: |
||||
cfg = &TcpProxyConf{} |
||||
case consts.UdpProxy: |
||||
cfg = &UdpProxyConf{} |
||||
case consts.HttpProxy: |
||||
cfg = &HttpProxyConf{} |
||||
case consts.HttpsProxy: |
||||
cfg = &HttpsProxyConf{} |
||||
default: |
||||
err = fmt.Errorf("proxy [%s] type [%s] error", name, proxyType) |
||||
return |
||||
} |
||||
err = cfg.LoadFromFile(name, section) |
||||
return |
||||
} |
||||
|
||||
// BaseProxy info
|
||||
type BaseProxyConf struct { |
||||
ProxyName string |
||||
ProxyType string |
||||
|
||||
UseEncryption bool |
||||
UseCompression bool |
||||
} |
||||
|
||||
func (cfg *BaseProxyConf) GetName() string { |
||||
return cfg.ProxyName |
||||
} |
||||
|
||||
func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { |
||||
return cfg |
||||
} |
||||
|
||||
func (cfg *BaseProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.ProxyName = pMsg.ProxyName |
||||
cfg.ProxyType = pMsg.ProxyType |
||||
cfg.UseEncryption = pMsg.UseEncryption |
||||
cfg.UseCompression = pMsg.UseCompression |
||||
} |
||||
|
||||
func (cfg *BaseProxyConf) LoadFromFile(name string, section ini.Section) error { |
||||
var ( |
||||
tmpStr string |
||||
ok bool |
||||
) |
||||
cfg.ProxyName = ClientCommonCfg.User + "." + name |
||||
cfg.ProxyType = section["type"] |
||||
|
||||
tmpStr, ok = section["use_encryption"] |
||||
if ok && tmpStr == "true" { |
||||
cfg.UseEncryption = true |
||||
} |
||||
|
||||
tmpStr, ok = section["use_compression"] |
||||
if ok && tmpStr == "true" { |
||||
cfg.UseCompression = true |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
pMsg.ProxyName = cfg.ProxyName |
||||
pMsg.ProxyType = cfg.ProxyType |
||||
pMsg.UseEncryption = cfg.UseEncryption |
||||
pMsg.UseCompression = cfg.UseCompression |
||||
} |
||||
|
||||
// Bind info
|
||||
type BindInfoConf struct { |
||||
BindAddr string |
||||
RemotePort int64 |
||||
} |
||||
|
||||
func (cfg *BindInfoConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.BindAddr = ServerCommonCfg.BindAddr |
||||
cfg.RemotePort = pMsg.RemotePort |
||||
} |
||||
|
||||
func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
var ( |
||||
tmpStr string |
||||
ok bool |
||||
) |
||||
if tmpStr, ok = section["remote_port"]; ok { |
||||
if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name) |
||||
} |
||||
} else { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
pMsg.RemotePort = cfg.RemotePort |
||||
} |
||||
|
||||
func (cfg *BindInfoConf) check() (err error) { |
||||
if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 { |
||||
if _, ok := ServerCommonCfg.PrivilegeAllowPorts[cfg.RemotePort]; !ok { |
||||
return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Domain info
|
||||
type DomainConf struct { |
||||
CustomDomains []string |
||||
SubDomain string |
||||
} |
||||
|
||||
func (cfg *DomainConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.CustomDomains = pMsg.CustomDomains |
||||
cfg.SubDomain = pMsg.SubDomain |
||||
} |
||||
|
||||
func (cfg *DomainConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
var ( |
||||
tmpStr string |
||||
ok bool |
||||
) |
||||
if tmpStr, ok = section["custom_domains"]; ok { |
||||
cfg.CustomDomains = strings.Split(tmpStr, ",") |
||||
for i, domain := range cfg.CustomDomains { |
||||
cfg.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) |
||||
} |
||||
} |
||||
|
||||
if tmpStr, ok = section["subdomain"]; ok { |
||||
cfg.SubDomain = tmpStr |
||||
} |
||||
|
||||
if len(cfg.CustomDomains) == 0 && cfg.SubDomain == "" { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains and subdomain should set at least one of them", name) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (cfg *DomainConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
pMsg.CustomDomains = cfg.CustomDomains |
||||
pMsg.SubDomain = cfg.SubDomain |
||||
} |
||||
|
||||
func (cfg *DomainConf) check() (err error) { |
||||
for _, domain := range cfg.CustomDomains { |
||||
if ServerCommonCfg.SubDomainHost != "" && len(strings.Split(ServerCommonCfg.SubDomainHost, ".")) < len(strings.Split(domain, ".")) { |
||||
if strings.Contains(domain, ServerCommonCfg.SubDomainHost) { |
||||
return fmt.Errorf("custom domain [%s] should not belong to subdomain_host [%s]", domain, ServerCommonCfg.SubDomainHost) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if cfg.SubDomain != "" { |
||||
if ServerCommonCfg.SubDomainHost == "" { |
||||
return fmt.Errorf("subdomain is not supported because this feature is not enabled by frps") |
||||
} |
||||
if strings.Contains(cfg.SubDomain, ".") || strings.Contains(cfg.SubDomain, "*") { |
||||
return fmt.Errorf("'.' and '*' is not supported in subdomain") |
||||
} |
||||
cfg.SubDomain += "." + ServerCommonCfg.SubDomainHost |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
type LocalSvrConf struct { |
||||
LocalIp string |
||||
LocalPort int |
||||
} |
||||
|
||||
func (cfg *LocalSvrConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" { |
||||
cfg.LocalIp = "127.0.0.1" |
||||
} |
||||
|
||||
if tmpStr, ok := section["local_port"]; ok { |
||||
if cfg.LocalPort, err = strconv.Atoi(tmpStr); err != nil { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] local_port error", name) |
||||
} |
||||
} else { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] local_port not found", name) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// TCP
|
||||
type TcpProxyConf struct { |
||||
BaseProxyConf |
||||
BindInfoConf |
||||
|
||||
LocalSvrConf |
||||
} |
||||
|
||||
func (cfg *TcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg) |
||||
cfg.BindInfoConf.LoadFromMsg(pMsg) |
||||
} |
||||
|
||||
func (cfg *TcpProxyConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.BindInfoConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (cfg *TcpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg) |
||||
cfg.BindInfoConf.UnMarshalToMsg(pMsg) |
||||
} |
||||
|
||||
func (cfg *TcpProxyConf) Check() (err error) { |
||||
err = cfg.BindInfoConf.check() |
||||
return |
||||
} |
||||
|
||||
// UDP
|
||||
type UdpProxyConf struct { |
||||
BaseProxyConf |
||||
BindInfoConf |
||||
|
||||
LocalSvrConf |
||||
} |
||||
|
||||
func (cfg *UdpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg) |
||||
cfg.BindInfoConf.LoadFromMsg(pMsg) |
||||
} |
||||
|
||||
func (cfg *UdpProxyConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.BindInfoConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (cfg *UdpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg) |
||||
cfg.BindInfoConf.UnMarshalToMsg(pMsg) |
||||
} |
||||
|
||||
func (cfg *UdpProxyConf) Check() (err error) { |
||||
err = cfg.BindInfoConf.check() |
||||
return |
||||
} |
||||
|
||||
// HTTP
|
||||
type HttpProxyConf struct { |
||||
BaseProxyConf |
||||
DomainConf |
||||
|
||||
LocalSvrConf |
||||
|
||||
Locations []string |
||||
HostHeaderRewrite string |
||||
HttpUser string |
||||
HttpPwd string |
||||
} |
||||
|
||||
func (cfg *HttpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg) |
||||
cfg.DomainConf.LoadFromMsg(pMsg) |
||||
|
||||
cfg.Locations = pMsg.Locations |
||||
cfg.HostHeaderRewrite = pMsg.HostHeaderRewrite |
||||
cfg.HttpUser = pMsg.HttpUser |
||||
cfg.HttpPwd = pMsg.HttpPwd |
||||
} |
||||
|
||||
func (cfg *HttpProxyConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.DomainConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
|
||||
var ( |
||||
tmpStr string |
||||
ok bool |
||||
) |
||||
if tmpStr, ok = section["locations"]; ok { |
||||
cfg.Locations = strings.Split(tmpStr, ",") |
||||
} else { |
||||
cfg.Locations = []string{""} |
||||
} |
||||
|
||||
cfg.HostHeaderRewrite = section["host_header_rewrite"] |
||||
cfg.HttpUser = section["http_user"] |
||||
cfg.HttpPwd = section["http_pwd"] |
||||
return |
||||
} |
||||
|
||||
func (cfg *HttpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg) |
||||
cfg.DomainConf.UnMarshalToMsg(pMsg) |
||||
|
||||
pMsg.Locations = cfg.Locations |
||||
pMsg.HostHeaderRewrite = cfg.HostHeaderRewrite |
||||
pMsg.HttpUser = cfg.HttpUser |
||||
pMsg.HttpPwd = cfg.HttpPwd |
||||
} |
||||
|
||||
func (cfg *HttpProxyConf) Check() (err error) { |
||||
if ServerCommonCfg.VhostHttpPort == 0 { |
||||
return fmt.Errorf("type [http] not support when vhost_http_port is not set") |
||||
} |
||||
err = cfg.DomainConf.check() |
||||
return |
||||
} |
||||
|
||||
// HTTPS
|
||||
type HttpsProxyConf struct { |
||||
BaseProxyConf |
||||
DomainConf |
||||
|
||||
LocalSvrConf |
||||
} |
||||
|
||||
func (cfg *HttpsProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.LoadFromMsg(pMsg) |
||||
cfg.DomainConf.LoadFromMsg(pMsg) |
||||
} |
||||
|
||||
func (cfg *HttpsProxyConf) LoadFromFile(name string, section ini.Section) (err error) { |
||||
if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.DomainConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (cfg *HttpsProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { |
||||
cfg.BaseProxyConf.UnMarshalToMsg(pMsg) |
||||
cfg.DomainConf.UnMarshalToMsg(pMsg) |
||||
} |
||||
|
||||
func (cfg *HttpsProxyConf) Check() (err error) { |
||||
if ServerCommonCfg.VhostHttpsPort == 0 { |
||||
return fmt.Errorf("type [https] not support when vhost_https_port is not set") |
||||
} |
||||
err = cfg.DomainConf.check() |
||||
return |
||||
} |
||||
|
||||
func LoadProxyConfFromFile(conf ini.File) (proxyConfs map[string]ProxyConf, err error) { |
||||
var prefix string |
||||
if ClientCommonCfg.User != "" { |
||||
prefix = ClientCommonCfg.User + "." |
||||
} |
||||
proxyConfs = make(map[string]ProxyConf) |
||||
for name, section := range conf { |
||||
if name != "common" { |
||||
cfg, err := NewProxyConfFromFile(name, section) |
||||
if err != nil { |
||||
return proxyConfs, err |
||||
} |
||||
proxyConfs[prefix+name] = cfg |
||||
} |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,279 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
ini "github.com/vaughan0/go-ini" |
||||
) |
||||
|
||||
var ServerCommonCfg *ServerCommonConf |
||||
|
||||
// common config
|
||||
type ServerCommonConf struct { |
||||
ConfigFile string |
||||
BindAddr string |
||||
BindPort int64 |
||||
|
||||
// If VhostHttpPort equals 0, don't listen a public port for http protocol.
|
||||
VhostHttpPort int64 |
||||
|
||||
// if VhostHttpsPort equals 0, don't listen a public port for https protocol
|
||||
VhostHttpsPort int64 |
||||
|
||||
// if DashboardPort equals 0, dashboard is not available
|
||||
DashboardPort int64 |
||||
DashboardUser string |
||||
DashboardPwd string |
||||
AssetsDir string |
||||
LogFile string |
||||
LogWay string // console or file
|
||||
LogLevel string |
||||
LogMaxDays int64 |
||||
PrivilegeMode bool |
||||
PrivilegeToken string |
||||
AuthTimeout int64 |
||||
SubDomainHost string |
||||
|
||||
// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected
|
||||
PrivilegeAllowPorts map[int64]struct{} |
||||
MaxPoolCount int64 |
||||
HeartBeatTimeout int64 |
||||
UserConnTimeout int64 |
||||
} |
||||
|
||||
func GetDefaultServerCommonConf() *ServerCommonConf { |
||||
return &ServerCommonConf{ |
||||
ConfigFile: "./frps.ini", |
||||
BindAddr: "0.0.0.0", |
||||
BindPort: 7000, |
||||
VhostHttpPort: 0, |
||||
VhostHttpsPort: 0, |
||||
DashboardPort: 0, |
||||
DashboardUser: "admin", |
||||
DashboardPwd: "admin", |
||||
AssetsDir: "", |
||||
LogFile: "console", |
||||
LogWay: "console", |
||||
LogLevel: "info", |
||||
LogMaxDays: 3, |
||||
PrivilegeMode: true, |
||||
PrivilegeToken: "", |
||||
AuthTimeout: 900, |
||||
SubDomainHost: "", |
||||
MaxPoolCount: 10, |
||||
HeartBeatTimeout: 30, |
||||
UserConnTimeout: 10, |
||||
} |
||||
} |
||||
|
||||
// Load server common configure.
|
||||
func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { |
||||
var ( |
||||
tmpStr string |
||||
ok bool |
||||
v int64 |
||||
) |
||||
cfg = GetDefaultServerCommonConf() |
||||
|
||||
tmpStr, ok = conf.Get("common", "bind_addr") |
||||
if ok { |
||||
cfg.BindAddr = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "bind_port") |
||||
if ok { |
||||
v, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err == nil { |
||||
cfg.BindPort = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "vhost_http_port") |
||||
if ok { |
||||
cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect") |
||||
return |
||||
} |
||||
} else { |
||||
cfg.VhostHttpPort = 0 |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "vhost_https_port") |
||||
if ok { |
||||
cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect") |
||||
return |
||||
} |
||||
} else { |
||||
cfg.VhostHttpsPort = 0 |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_port") |
||||
if ok { |
||||
cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
err = fmt.Errorf("Parse conf error: dashboard_port is incorrect") |
||||
return |
||||
} |
||||
} else { |
||||
cfg.DashboardPort = 0 |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_user") |
||||
if ok { |
||||
cfg.DashboardUser = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_pwd") |
||||
if ok { |
||||
cfg.DashboardPwd = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "assets_dir") |
||||
if ok { |
||||
cfg.AssetsDir = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_file") |
||||
if ok { |
||||
cfg.LogFile = tmpStr |
||||
if cfg.LogFile == "console" { |
||||
cfg.LogWay = "console" |
||||
} else { |
||||
cfg.LogWay = "file" |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_level") |
||||
if ok { |
||||
cfg.LogLevel = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_max_days") |
||||
if ok { |
||||
v, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err == nil { |
||||
cfg.LogMaxDays = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "privilege_mode") |
||||
if ok { |
||||
if tmpStr == "true" { |
||||
cfg.PrivilegeMode = true |
||||
} |
||||
} |
||||
|
||||
// PrivilegeMode configure
|
||||
if cfg.PrivilegeMode == true { |
||||
tmpStr, ok = conf.Get("common", "privilege_token") |
||||
if ok { |
||||
if tmpStr == "" { |
||||
err = fmt.Errorf("Parse conf error: privilege_token can not be empty") |
||||
return |
||||
} |
||||
cfg.PrivilegeToken = tmpStr |
||||
} else { |
||||
err = fmt.Errorf("Parse conf error: privilege_token must be set if privilege_mode is enabled") |
||||
return |
||||
} |
||||
|
||||
cfg.PrivilegeAllowPorts = make(map[int64]struct{}) |
||||
tmpStr, ok = conf.Get("common", "privilege_allow_ports") |
||||
if ok { |
||||
// e.g. 1000-2000,2001,2002,3000-4000
|
||||
portRanges := strings.Split(tmpStr, ",") |
||||
for _, portRangeStr := range portRanges { |
||||
// 1000-2000 or 2001
|
||||
portArray := strings.Split(portRangeStr, "-") |
||||
// length: only 1 or 2 is correct
|
||||
rangeType := len(portArray) |
||||
if rangeType == 1 { |
||||
// single port
|
||||
singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64) |
||||
if errRet != nil { |
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) |
||||
return |
||||
} |
||||
ServerCommonCfg.PrivilegeAllowPorts[singlePort] = struct{}{} |
||||
} else if rangeType == 2 { |
||||
// range ports
|
||||
min, errRet := strconv.ParseInt(portArray[0], 10, 64) |
||||
if errRet != nil { |
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) |
||||
return |
||||
} |
||||
max, errRet := strconv.ParseInt(portArray[1], 10, 64) |
||||
if errRet != nil { |
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) |
||||
return |
||||
} |
||||
if max < min { |
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") |
||||
return |
||||
} |
||||
for i := min; i <= max; i++ { |
||||
cfg.PrivilegeAllowPorts[i] = struct{}{} |
||||
} |
||||
} else { |
||||
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") |
||||
return |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "max_pool_count") |
||||
if ok { |
||||
v, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err == nil && v >= 0 { |
||||
cfg.MaxPoolCount = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "authentication_timeout") |
||||
if ok { |
||||
v, errRet := strconv.ParseInt(tmpStr, 10, 64) |
||||
if errRet != nil { |
||||
err = fmt.Errorf("Parse conf error: authentication_timeout is incorrect") |
||||
return |
||||
} else { |
||||
cfg.AuthTimeout = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "subdomain_host") |
||||
if ok { |
||||
cfg.SubDomainHost = strings.ToLower(strings.TrimSpace(tmpStr)) |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_timeout") |
||||
if ok { |
||||
v, errRet := strconv.ParseInt(tmpStr, 10, 64) |
||||
if errRet != nil { |
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") |
||||
return |
||||
} else { |
||||
cfg.HeartBeatTimeout = v |
||||
} |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,21 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package errors |
||||
|
||||
import "errors" |
||||
|
||||
var ( |
||||
ErrMsgType = errors.New("message type error") |
||||
) |
@ -0,0 +1,122 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
import "reflect" |
||||
|
||||
const ( |
||||
TypeLogin = 'o' |
||||
TypeLoginResp = '1' |
||||
TypeNewProxy = 'p' |
||||
TypeNewProxyResp = '2' |
||||
TypeNewWorkConn = 'w' |
||||
TypeReqWorkConn = 'r' |
||||
TypeStartWorkConn = 's' |
||||
TypePing = 'h' |
||||
TypePong = '4' |
||||
) |
||||
|
||||
var ( |
||||
TypeMap map[byte]reflect.Type |
||||
TypeStringMap map[reflect.Type]byte |
||||
) |
||||
|
||||
func init() { |
||||
TypeMap = make(map[byte]reflect.Type) |
||||
TypeStringMap = make(map[reflect.Type]byte) |
||||
|
||||
TypeMap[TypeLogin] = getTypeFn((*Login)(nil)) |
||||
TypeMap[TypeLoginResp] = getTypeFn((*LoginResp)(nil)) |
||||
TypeMap[TypeNewProxy] = getTypeFn((*NewProxy)(nil)) |
||||
TypeMap[TypeNewProxyResp] = getTypeFn((*NewProxyResp)(nil)) |
||||
TypeMap[TypeNewWorkConn] = getTypeFn((*NewWorkConn)(nil)) |
||||
TypeMap[TypeReqWorkConn] = getTypeFn((*ReqWorkConn)(nil)) |
||||
TypeMap[TypeStartWorkConn] = getTypeFn((*StartWorkConn)(nil)) |
||||
TypeMap[TypePing] = getTypeFn((*Ping)(nil)) |
||||
TypeMap[TypePong] = getTypeFn((*Pong)(nil)) |
||||
|
||||
for k, v := range TypeMap { |
||||
TypeStringMap[v] = k |
||||
} |
||||
} |
||||
|
||||
func getTypeFn(obj interface{}) reflect.Type { |
||||
return reflect.TypeOf(obj).Elem() |
||||
} |
||||
|
||||
// Message wraps socket packages for communicating between frpc and frps.
|
||||
type Message interface{} |
||||
|
||||
// When frpc start, client send this message to login to server.
|
||||
type Login struct { |
||||
Version string `json:"version"` |
||||
Hostname string `json:"hostname"` |
||||
Os string `json:"os"` |
||||
Arch string `json:"arch"` |
||||
User string `json:"user"` |
||||
PrivilegeKey string `json:"privilege_key"` |
||||
Timestamp int64 `json:"timestamp"` |
||||
RunId string `json:"run_id"` |
||||
|
||||
// Some global configures.
|
||||
PoolCount int `json:"pool_count"` |
||||
} |
||||
|
||||
type LoginResp struct { |
||||
Version string `json:"version"` |
||||
RunId string `json:"run_id"` |
||||
Error string `json:"error"` |
||||
} |
||||
|
||||
// When frpc login success, send this message to frps for running a new proxy.
|
||||
type NewProxy struct { |
||||
ProxyName string `json:"proxy_name"` |
||||
ProxyType string `json:"proxy_type"` |
||||
UseEncryption bool `json:"use_encryption"` |
||||
UseCompression bool `json:"use_compression"` |
||||
|
||||
// tcp and udp only
|
||||
RemotePort int64 `json:"remote_port"` |
||||
|
||||
// http and https only
|
||||
CustomDomains []string `json:"custom_domains"` |
||||
SubDomain string `json:"subdomain"` |
||||
Locations []string `json:"locations"` |
||||
HostHeaderRewrite string `json:"host_header_rewrite"` |
||||
HttpUser string `json:"http_user"` |
||||
HttpPwd string `json:"http_pwd"` |
||||
} |
||||
|
||||
type NewProxyResp struct { |
||||
ProxyName string `json:"proxy_name"` |
||||
Error string `json:"error"` |
||||
} |
||||
|
||||
type NewWorkConn struct { |
||||
RunId string `json:"run_id"` |
||||
} |
||||
|
||||
type ReqWorkConn struct { |
||||
} |
||||
|
||||
type StartWorkConn struct { |
||||
ProxyName string `json:"proxy_name"` |
||||
} |
||||
|
||||
type Ping struct { |
||||
} |
||||
|
||||
type Pong struct { |
||||
} |
@ -0,0 +1,69 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"encoding/json" |
||||
"fmt" |
||||
"reflect" |
||||
|
||||
"github.com/fatedier/frp/utils/errors" |
||||
) |
||||
|
||||
func unpack(typeByte byte, buffer []byte, msgIn Message) (msg Message, err error) { |
||||
if msgIn == nil { |
||||
t, ok := TypeMap[typeByte] |
||||
if !ok { |
||||
err = fmt.Errorf("Unsupported message type %b", typeByte) |
||||
return |
||||
} |
||||
|
||||
msg = reflect.New(t).Interface().(Message) |
||||
} else { |
||||
msg = msgIn |
||||
} |
||||
|
||||
err = json.Unmarshal(buffer, &msg) |
||||
return |
||||
} |
||||
|
||||
func UnPackInto(buffer []byte, msg Message) (err error) { |
||||
_, err = unpack(' ', buffer, msg) |
||||
return |
||||
} |
||||
|
||||
func UnPack(typeByte byte, buffer []byte) (msg Message, err error) { |
||||
return unpack(typeByte, buffer, nil) |
||||
} |
||||
|
||||
func Pack(msg Message) ([]byte, error) { |
||||
typeByte, ok := TypeStringMap[reflect.TypeOf(msg).Elem()] |
||||
if !ok { |
||||
return nil, errors.ErrMsgType |
||||
} |
||||
|
||||
content, err := json.Marshal(msg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
buffer := bytes.NewBuffer(nil) |
||||
buffer.WriteByte(typeByte) |
||||
binary.Write(buffer, binary.BigEndian, int64(len(content))) |
||||
buffer.Write(content) |
||||
return buffer.Bytes(), nil |
||||
} |
@ -0,0 +1,86 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
|
||||
"github.com/fatedier/frp/utils/errors" |
||||
) |
||||
|
||||
type TestStruct struct{} |
||||
|
||||
func TestPack(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
var ( |
||||
msg Message |
||||
buffer []byte |
||||
err error |
||||
) |
||||
|
||||
// error type
|
||||
msg = &TestStruct{} |
||||
buffer, err = Pack(msg) |
||||
assert.Error(err, errors.ErrMsgType.Error()) |
||||
|
||||
// correct
|
||||
msg = &Ping{} |
||||
buffer, err = Pack(msg) |
||||
assert.NoError(err) |
||||
b := bytes.NewBuffer(nil) |
||||
b.WriteByte(TypePing) |
||||
binary.Write(b, binary.BigEndian, int64(2)) |
||||
b.WriteString("{}") |
||||
assert.True(bytes.Equal(b.Bytes(), buffer)) |
||||
} |
||||
|
||||
func TestUnPack(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
var ( |
||||
msg Message |
||||
err error |
||||
) |
||||
|
||||
// error message type
|
||||
msg, err = UnPack('-', []byte("{}")) |
||||
assert.Error(err) |
||||
|
||||
// correct
|
||||
msg, err = UnPack(TypePong, []byte("{}")) |
||||
assert.NoError(err) |
||||
assert.Equal(getTypeFn(msg), getTypeFn((*Pong)(nil))) |
||||
} |
||||
|
||||
func TestUnPackInto(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
var err error |
||||
|
||||
// correct type
|
||||
pongMsg := &Pong{} |
||||
err = UnPackInto([]byte("{}"), pongMsg) |
||||
assert.NoError(err) |
||||
|
||||
// wrong type
|
||||
loginMsg := &Login{} |
||||
err = UnPackInto([]byte(`{"version": 123}`), loginMsg) |
||||
assert.Error(err) |
||||
} |
@ -0,0 +1,88 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
var ( |
||||
MaxMsgLength int64 = 10240 |
||||
) |
||||
|
||||
func readMsg(c io.Reader) (typeByte byte, buffer []byte, err error) { |
||||
buffer = make([]byte, 1) |
||||
_, err = c.Read(buffer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
typeByte = buffer[0] |
||||
if _, ok := TypeMap[typeByte]; !ok { |
||||
err = fmt.Errorf("Message type error") |
||||
return |
||||
} |
||||
|
||||
var length int64 |
||||
err = binary.Read(c, binary.BigEndian, &length) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if length > MaxMsgLength { |
||||
err = fmt.Errorf("Message length exceed the limit") |
||||
return |
||||
} |
||||
|
||||
buffer = make([]byte, length) |
||||
n, err := io.ReadFull(c, buffer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
if int64(n) != length { |
||||
err = fmt.Errorf("Message format error") |
||||
} |
||||
return |
||||
} |
||||
|
||||
func ReadMsg(c io.Reader) (msg Message, err error) { |
||||
typeByte, buffer, err := readMsg(c) |
||||
if err != nil { |
||||
return |
||||
} |
||||
return UnPack(typeByte, buffer) |
||||
} |
||||
|
||||
func ReadMsgInto(c io.Reader, msg Message) (err error) { |
||||
_, buffer, err := readMsg(c) |
||||
if err != nil { |
||||
return |
||||
} |
||||
return UnPackInto(buffer, msg) |
||||
} |
||||
|
||||
func WriteMsg(c io.Writer, msg interface{}) (err error) { |
||||
buffer, err := Pack(msg) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
if _, err = c.Write(buffer); err != nil { |
||||
return |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,97 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"reflect" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestProcess(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
var ( |
||||
msg Message |
||||
resMsg Message |
||||
err error |
||||
) |
||||
// empty struct
|
||||
msg = &Ping{} |
||||
buffer := bytes.NewBuffer(nil) |
||||
err = WriteMsg(buffer, msg) |
||||
assert.NoError(err) |
||||
|
||||
resMsg, err = ReadMsg(buffer) |
||||
assert.NoError(err) |
||||
assert.Equal(reflect.TypeOf(resMsg).Elem(), TypeMap[TypePing]) |
||||
|
||||
// normal message
|
||||
msg = &StartWorkConn{ |
||||
ProxyName: "test", |
||||
} |
||||
buffer = bytes.NewBuffer(nil) |
||||
err = WriteMsg(buffer, msg) |
||||
assert.NoError(err) |
||||
|
||||
resMsg, err = ReadMsg(buffer) |
||||
assert.NoError(err) |
||||
assert.Equal(reflect.TypeOf(resMsg).Elem(), TypeMap[TypeStartWorkConn]) |
||||
|
||||
startWorkConnMsg, ok := resMsg.(*StartWorkConn) |
||||
assert.True(ok) |
||||
assert.Equal("test", startWorkConnMsg.ProxyName) |
||||
|
||||
// ReadMsgInto correct
|
||||
msg = &Pong{} |
||||
buffer = bytes.NewBuffer(nil) |
||||
err = WriteMsg(buffer, msg) |
||||
assert.NoError(err) |
||||
|
||||
err = ReadMsgInto(buffer, msg) |
||||
assert.NoError(err) |
||||
|
||||
// ReadMsgInto error type
|
||||
content := []byte(`{"run_id": 123}`) |
||||
buffer = bytes.NewBuffer(nil) |
||||
buffer.WriteByte(TypeNewWorkConn) |
||||
binary.Write(buffer, binary.BigEndian, int64(len(content))) |
||||
buffer.Write(content) |
||||
|
||||
resMsg = &NewWorkConn{} |
||||
err = ReadMsgInto(buffer, resMsg) |
||||
assert.Error(err) |
||||
|
||||
// message format error
|
||||
buffer = bytes.NewBuffer([]byte("1234")) |
||||
|
||||
resMsg = &NewProxyResp{} |
||||
err = ReadMsgInto(buffer, resMsg) |
||||
assert.Error(err) |
||||
|
||||
// MaxLength, real message length is 2
|
||||
MaxMsgLength = 1 |
||||
msg = &Ping{} |
||||
buffer = bytes.NewBuffer(nil) |
||||
err = WriteMsg(buffer, msg) |
||||
assert.NoError(err) |
||||
|
||||
_, err = ReadMsg(buffer) |
||||
assert.Error(err) |
||||
return |
||||
} |
@ -0,0 +1,38 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
) |
||||
|
||||
// Join two io.ReadWriteCloser and do some operations.
|
||||
func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) { |
||||
var wait sync.WaitGroup |
||||
pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) { |
||||
defer to.Close() |
||||
defer from.Close() |
||||
defer wait.Done() |
||||
|
||||
*count, _ = io.Copy(to, from) |
||||
} |
||||
|
||||
wait.Add(2) |
||||
go pipe(c1, c2, &inCount) |
||||
go pipe(c2, c1, &outCount) |
||||
wait.Wait() |
||||
return |
||||
} |
@ -0,0 +1,129 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp |
||||
|
||||
import ( |
||||
"io" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
|
||||
"github.com/fatedier/frp/utils/crypto" |
||||
) |
||||
|
||||
func TestJoin(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
var ( |
||||
n int |
||||
err error |
||||
) |
||||
text1 := "A document that gives tips for writing clear, idiomatic Go code. A must read for any new Go programmer. It augments the tour and the language specification, both of which should be read first." |
||||
text2 := "A document that specifies the conditions under which reads of a variable in one goroutine can be guaranteed to observe values produced by writes to the same variable in a different goroutine." |
||||
|
||||
// Forward bytes directly.
|
||||
pr, pw := io.Pipe() |
||||
pr2, pw2 := io.Pipe() |
||||
pr3, pw3 := io.Pipe() |
||||
pr4, pw4 := io.Pipe() |
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2) |
||||
conn2 := WrapReadWriteCloser(pr2, pw) |
||||
conn3 := WrapReadWriteCloser(pr3, pw4) |
||||
conn4 := WrapReadWriteCloser(pr4, pw3) |
||||
|
||||
go func() { |
||||
Join(conn2, conn3) |
||||
}() |
||||
|
||||
buf1 := make([]byte, 1024) |
||||
buf2 := make([]byte, 1024) |
||||
|
||||
conn1.Write([]byte(text1)) |
||||
conn4.Write([]byte(text2)) |
||||
|
||||
n, err = conn4.Read(buf1) |
||||
assert.NoError(err) |
||||
assert.Equal(text1, string(buf1[:n])) |
||||
|
||||
n, err = conn1.Read(buf2) |
||||
assert.NoError(err) |
||||
assert.Equal(text2, string(buf2[:n])) |
||||
|
||||
conn1.Close() |
||||
conn2.Close() |
||||
conn3.Close() |
||||
conn4.Close() |
||||
} |
||||
|
||||
func TestJoinEncrypt(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
var ( |
||||
n int |
||||
err error |
||||
) |
||||
text1 := "1234567890" |
||||
text2 := "abcdefghij" |
||||
key := "authkey" |
||||
|
||||
// Forward enrypted bytes.
|
||||
pr, pw := io.Pipe() |
||||
pr2, pw2 := io.Pipe() |
||||
pr3, pw3 := io.Pipe() |
||||
pr4, pw4 := io.Pipe() |
||||
pr5, pw5 := io.Pipe() |
||||
pr6, pw6 := io.Pipe() |
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2) |
||||
conn2 := WrapReadWriteCloser(pr2, pw) |
||||
conn3 := WrapReadWriteCloser(pr3, pw4) |
||||
conn4 := WrapReadWriteCloser(pr4, pw3) |
||||
conn5 := WrapReadWriteCloser(pr5, pw6) |
||||
conn6 := WrapReadWriteCloser(pr6, pw5) |
||||
|
||||
r1, err := crypto.NewReader(conn3, []byte(key)) |
||||
assert.NoError(err) |
||||
w1, err := crypto.NewWriter(conn3, []byte(key)) |
||||
assert.NoError(err) |
||||
|
||||
r2, err := crypto.NewReader(conn4, []byte(key)) |
||||
assert.NoError(err) |
||||
w2, err := crypto.NewWriter(conn4, []byte(key)) |
||||
assert.NoError(err) |
||||
|
||||
go Join(conn2, WrapReadWriteCloser(r1, w1)) |
||||
go Join(WrapReadWriteCloser(r2, w2), conn5) |
||||
|
||||
buf := make([]byte, 128) |
||||
|
||||
conn1.Write([]byte(text1)) |
||||
conn6.Write([]byte(text2)) |
||||
|
||||
n, err = conn6.Read(buf) |
||||
assert.NoError(err) |
||||
assert.Equal(text1, string(buf[:n])) |
||||
|
||||
n, err = conn1.Read(buf) |
||||
assert.NoError(err) |
||||
assert.Equal(text2, string(buf[:n])) |
||||
|
||||
conn1.Close() |
||||
conn2.Close() |
||||
conn3.Close() |
||||
conn4.Close() |
||||
conn5.Close() |
||||
conn6.Close() |
||||
} |
@ -0,0 +1,89 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp |
||||
|
||||
import ( |
||||
"io" |
||||
|
||||
"github.com/golang/snappy" |
||||
|
||||
"github.com/fatedier/frp/utils/crypto" |
||||
) |
||||
|
||||
func WithEncryption(rwc io.ReadWriteCloser, key []byte) (res io.ReadWriteCloser, err error) { |
||||
var ( |
||||
r io.Reader |
||||
w io.Writer |
||||
) |
||||
r, err = crypto.NewReader(rwc, key) |
||||
if err != nil { |
||||
return |
||||
} |
||||
w, err = crypto.NewWriter(rwc, key) |
||||
if err != nil { |
||||
return |
||||
} |
||||
res = WrapReadWriteCloser(r, w) |
||||
return |
||||
} |
||||
|
||||
func WithCompression(rwc io.ReadWriteCloser) (res io.ReadWriteCloser) { |
||||
var ( |
||||
r io.Reader |
||||
w io.Writer |
||||
) |
||||
r = snappy.NewReader(rwc) |
||||
w = snappy.NewWriter(rwc) |
||||
res = WrapReadWriteCloser(r, w) |
||||
return |
||||
} |
||||
|
||||
func WrapReadWriteCloser(r io.Reader, w io.Writer) io.ReadWriteCloser { |
||||
return &ReadWriteCloser{ |
||||
r: r, |
||||
w: w, |
||||
} |
||||
} |
||||
|
||||
type ReadWriteCloser struct { |
||||
r io.Reader |
||||
w io.Writer |
||||
} |
||||
|
||||
func (rwc *ReadWriteCloser) Read(p []byte) (n int, err error) { |
||||
return rwc.r.Read(p) |
||||
} |
||||
|
||||
func (rwc *ReadWriteCloser) Write(p []byte) (n int, err error) { |
||||
return rwc.w.Write(p) |
||||
} |
||||
|
||||
func (rwc *ReadWriteCloser) Close() (errRet error) { |
||||
var err error |
||||
if rc, ok := rwc.r.(io.Closer); ok { |
||||
err = rc.Close() |
||||
if err != nil { |
||||
errRet = err |
||||
} |
||||
} |
||||
|
||||
if wc, ok := rwc.w.(io.Closer); ok { |
||||
err = wc.Close() |
||||
if err != nil { |
||||
errRet = err |
||||
} |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,100 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcp |
||||
|
||||
import ( |
||||
"io" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestWithCompression(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
// Forward compression bytes.
|
||||
pr, pw := io.Pipe() |
||||
pr2, pw2 := io.Pipe() |
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2) |
||||
conn2 := WrapReadWriteCloser(pr2, pw) |
||||
|
||||
compressionStream1 := WithCompression(conn1) |
||||
compressionStream2 := WithCompression(conn2) |
||||
|
||||
var ( |
||||
n int |
||||
err error |
||||
) |
||||
|
||||
text := "1234567812345678" |
||||
buf := make([]byte, 256) |
||||
|
||||
go compressionStream1.Write([]byte(text)) |
||||
n, err = compressionStream2.Read(buf) |
||||
assert.NoError(err) |
||||
assert.Equal(text, string(buf[:n])) |
||||
|
||||
go compressionStream2.Write([]byte(text)) |
||||
n, err = compressionStream1.Read(buf) |
||||
assert.NoError(err) |
||||
assert.Equal(text, string(buf[:n])) |
||||
} |
||||
|
||||
func TestWithEncryption(t *testing.T) { |
||||
assert := assert.New(t) |
||||
var ( |
||||
n int |
||||
err error |
||||
) |
||||
text1 := "Go is expressive, concise, clean, and efficient. Its concurrency mechanisms make it easy to write programs that get the most out of multicore and networked machines, while its novel type system enables flexible and modular program construction. Go compiles quickly to machine code yet has the convenience of garbage collection and the power of run-time reflection. It's a fast, statically typed, compiled language that feels like a dynamically typed, interpreted language." |
||||
text2 := "An interactive introduction to Go in three sections. The first section covers basic syntax and data structures; the second discusses methods and interfaces; and the third introduces Go's concurrency primitives. Each section concludes with a few exercises so you can practice what you've learned. You can take the tour online or install it locally with" |
||||
key := "authkey" |
||||
|
||||
// Forward enrypted bytes.
|
||||
pr, pw := io.Pipe() |
||||
pr2, pw2 := io.Pipe() |
||||
pr3, pw3 := io.Pipe() |
||||
pr4, pw4 := io.Pipe() |
||||
pr5, pw5 := io.Pipe() |
||||
pr6, pw6 := io.Pipe() |
||||
|
||||
conn1 := WrapReadWriteCloser(pr, pw2) |
||||
conn2 := WrapReadWriteCloser(pr2, pw) |
||||
conn3 := WrapReadWriteCloser(pr3, pw4) |
||||
conn4 := WrapReadWriteCloser(pr4, pw3) |
||||
conn5 := WrapReadWriteCloser(pr5, pw6) |
||||
conn6 := WrapReadWriteCloser(pr6, pw5) |
||||
|
||||
encryptStream1, err := WithEncryption(conn3, []byte(key)) |
||||
assert.NoError(err) |
||||
encryptStream2, err := WithEncryption(conn4, []byte(key)) |
||||
assert.NoError(err) |
||||
|
||||
go Join(conn2, encryptStream1) |
||||
go Join(encryptStream2, conn5) |
||||
|
||||
buf := make([]byte, 1024) |
||||
|
||||
conn1.Write([]byte(text1)) |
||||
conn6.Write([]byte(text2)) |
||||
|
||||
n, err = conn6.Read(buf) |
||||
assert.NoError(err) |
||||
assert.Equal(text1, string(buf[:n])) |
||||
|
||||
n, err = conn1.Read(buf) |
||||
assert.NoError(err) |
||||
} |
@ -0,0 +1,332 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/models/consts" |
||||
"github.com/fatedier/frp/models/msg" |
||||
"github.com/fatedier/frp/utils/errors" |
||||
"github.com/fatedier/frp/utils/net" |
||||
"github.com/fatedier/frp/utils/shutdown" |
||||
"github.com/fatedier/frp/utils/version" |
||||
) |
||||
|
||||
type Control struct { |
||||
// frps service
|
||||
svr *Service |
||||
|
||||
// login message
|
||||
loginMsg *msg.Login |
||||
|
||||
// control connection
|
||||
conn net.Conn |
||||
|
||||
// put a message in this channel to send it over control connection to client
|
||||
sendCh chan (msg.Message) |
||||
|
||||
// read from this channel to get the next message sent by client
|
||||
readCh chan (msg.Message) |
||||
|
||||
// work connections
|
||||
workConnCh chan net.Conn |
||||
|
||||
// proxies in one client
|
||||
proxies []Proxy |
||||
|
||||
// pool count
|
||||
poolCount int |
||||
|
||||
// last time got the Ping message
|
||||
lastPing time.Time |
||||
|
||||
// A new run id will be generated when a new client login.
|
||||
// If run id got from login message has same run id, it means it's the same client, so we can
|
||||
// replace old controller instantly.
|
||||
runId string |
||||
|
||||
// control status
|
||||
status string |
||||
|
||||
readerShutdown *shutdown.Shutdown |
||||
writerShutdown *shutdown.Shutdown |
||||
managerShutdown *shutdown.Shutdown |
||||
allShutdown *shutdown.Shutdown |
||||
|
||||
mu sync.RWMutex |
||||
} |
||||
|
||||
func NewControl(svr *Service, ctlConn net.Conn, loginMsg *msg.Login) *Control { |
||||
return &Control{ |
||||
svr: svr, |
||||
conn: ctlConn, |
||||
loginMsg: loginMsg, |
||||
sendCh: make(chan msg.Message, 10), |
||||
readCh: make(chan msg.Message, 10), |
||||
workConnCh: make(chan net.Conn, loginMsg.PoolCount+10), |
||||
proxies: make([]Proxy, 0), |
||||
poolCount: loginMsg.PoolCount, |
||||
lastPing: time.Now(), |
||||
runId: loginMsg.RunId, |
||||
status: consts.Working, |
||||
readerShutdown: shutdown.New(), |
||||
writerShutdown: shutdown.New(), |
||||
managerShutdown: shutdown.New(), |
||||
allShutdown: shutdown.New(), |
||||
} |
||||
} |
||||
|
||||
// Start send a login success message to client and start working.
|
||||
func (ctl *Control) Start() { |
||||
go ctl.writer() |
||||
|
||||
ctl.sendCh <- &msg.LoginResp{ |
||||
Version: version.Full(), |
||||
RunId: ctl.runId, |
||||
Error: "", |
||||
} |
||||
|
||||
for i := 0; i < ctl.poolCount; i++ { |
||||
ctl.sendCh <- &msg.ReqWorkConn{} |
||||
} |
||||
|
||||
go ctl.manager() |
||||
go ctl.reader() |
||||
go ctl.stoper() |
||||
} |
||||
|
||||
func (ctl *Control) RegisterWorkConn(conn net.Conn) { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.conn.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
select { |
||||
case ctl.workConnCh <- conn: |
||||
ctl.conn.Debug("new work connection registered.") |
||||
default: |
||||
ctl.conn.Debug("work connection pool is full, discarding.") |
||||
conn.Close() |
||||
} |
||||
} |
||||
|
||||
// When frps get one user connection, we get one work connection from the pool and return it.
|
||||
// If no workConn available in the pool, send message to frpc to get one or more
|
||||
// and wait until it is available.
|
||||
// return an error if wait timeout
|
||||
func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.conn.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
var ok bool |
||||
// get a work connection from the pool
|
||||
select { |
||||
case workConn, ok = <-ctl.workConnCh: |
||||
if !ok { |
||||
err = fmt.Errorf("no work connections available, control is closing") |
||||
return |
||||
} |
||||
ctl.conn.Debug("get work connection from pool") |
||||
default: |
||||
// no work connections available in the poll, send message to frpc to get more
|
||||
err = errors.PanicToError(func() { |
||||
ctl.sendCh <- &msg.ReqWorkConn{} |
||||
}) |
||||
if err != nil { |
||||
ctl.conn.Error("%v", err) |
||||
return |
||||
} |
||||
|
||||
select { |
||||
case workConn, ok = <-ctl.workConnCh: |
||||
if !ok { |
||||
err = fmt.Errorf("no work connections available, control is closing") |
||||
ctl.conn.Warn("%v", err) |
||||
return |
||||
} |
||||
|
||||
case <-time.After(time.Duration(config.ServerCommonCfg.UserConnTimeout) * time.Second): |
||||
err = fmt.Errorf("timeout trying to get work connection") |
||||
ctl.conn.Warn("%v", err) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// When we get a work connection from pool, replace it with a new one.
|
||||
errors.PanicToError(func() { |
||||
ctl.sendCh <- &msg.ReqWorkConn{} |
||||
}) |
||||
return |
||||
} |
||||
|
||||
func (ctl *Control) Replaced(newCtl *Control) { |
||||
ctl.conn.Info("Replaced by client [%s]", newCtl.runId) |
||||
ctl.runId = "" |
||||
ctl.allShutdown.Start() |
||||
} |
||||
|
||||
func (ctl *Control) writer() { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.conn.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
defer ctl.allShutdown.Start() |
||||
defer ctl.writerShutdown.Done() |
||||
|
||||
for { |
||||
if m, ok := <-ctl.sendCh; !ok { |
||||
ctl.conn.Info("control writer is closing") |
||||
return |
||||
} else { |
||||
if err := msg.WriteMsg(ctl.conn, m); err != nil { |
||||
ctl.conn.Warn("write message to control connection error: %v", err) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ctl *Control) reader() { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.conn.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
defer ctl.allShutdown.Start() |
||||
defer ctl.readerShutdown.Done() |
||||
|
||||
for { |
||||
if m, err := msg.ReadMsg(ctl.conn); err != nil { |
||||
if err == io.EOF { |
||||
ctl.conn.Debug("control connection closed") |
||||
return |
||||
} else { |
||||
ctl.conn.Warn("read error: %v", err) |
||||
return |
||||
} |
||||
} else { |
||||
ctl.readCh <- m |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ctl *Control) stoper() { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.conn.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
ctl.allShutdown.WaitStart() |
||||
|
||||
close(ctl.readCh) |
||||
ctl.managerShutdown.WaitDown() |
||||
|
||||
close(ctl.sendCh) |
||||
ctl.writerShutdown.WaitDown() |
||||
|
||||
ctl.conn.Close() |
||||
|
||||
close(ctl.workConnCh) |
||||
for workConn := range ctl.workConnCh { |
||||
workConn.Close() |
||||
} |
||||
|
||||
for _, pxy := range ctl.proxies { |
||||
ctl.svr.DelProxy(pxy.GetName()) |
||||
pxy.Close() |
||||
} |
||||
|
||||
ctl.allShutdown.Done() |
||||
ctl.conn.Info("all shutdown success") |
||||
} |
||||
|
||||
func (ctl *Control) manager() { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
ctl.conn.Error("panic error: %v", err) |
||||
} |
||||
}() |
||||
|
||||
defer ctl.allShutdown.Start() |
||||
defer ctl.managerShutdown.Done() |
||||
|
||||
heartbeat := time.NewTicker(time.Second) |
||||
defer heartbeat.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-heartbeat.C: |
||||
if time.Since(ctl.lastPing) > time.Duration(config.ServerCommonCfg.HeartBeatTimeout)*time.Second { |
||||
ctl.conn.Warn("heartbeat timeout") |
||||
ctl.allShutdown.Start() |
||||
} |
||||
case rawMsg, ok := <-ctl.readCh: |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
switch m := rawMsg.(type) { |
||||
case *msg.NewProxy: |
||||
// register proxy in this control
|
||||
err := ctl.RegisterProxy(m) |
||||
resp := &msg.NewProxyResp{ |
||||
ProxyName: m.ProxyName, |
||||
} |
||||
if err != nil { |
||||
resp.Error = err.Error() |
||||
ctl.conn.Warn("new proxy [%s] error: %v", m.ProxyName, err) |
||||
} else { |
||||
ctl.conn.Info("new proxy [%s] success", m.ProxyName) |
||||
} |
||||
ctl.sendCh <- resp |
||||
case *msg.Ping: |
||||
ctl.lastPing = time.Now() |
||||
ctl.sendCh <- &msg.Pong{} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (err error) { |
||||
var pxyConf config.ProxyConf |
||||
// Load configures from NewProxy message and check.
|
||||
pxyConf, err = config.NewProxyConf(pxyMsg) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// NewProxy will return a interface Proxy.
|
||||
// In fact it create different proxies by different proxy type, we just call run() here.
|
||||
pxy, err := NewProxy(ctl, pxyConf) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
err = pxy.Run() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer func() { |
||||
if err != nil { |
||||
pxy.Close() |
||||
} |
||||
}() |
||||
|
||||
err = ctl.svr.RegisterProxy(pxyMsg.ProxyName, pxy) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
ctl.proxies = append(ctl.proxies, pxy) |
||||
return nil |
||||
} |
@ -0,0 +1,75 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
) |
||||
|
||||
type ControlManager struct { |
||||
// controls indexed by run id
|
||||
ctlsByRunId map[string]*Control |
||||
|
||||
mu sync.RWMutex |
||||
} |
||||
|
||||
func NewControlManager() *ControlManager { |
||||
return &ControlManager{ |
||||
ctlsByRunId: make(map[string]*Control), |
||||
} |
||||
} |
||||
|
||||
func (cm *ControlManager) Add(runId string, ctl *Control) (oldCtl *Control) { |
||||
cm.mu.Lock() |
||||
defer cm.mu.Unlock() |
||||
|
||||
oldCtl, ok := cm.ctlsByRunId[runId] |
||||
if ok { |
||||
oldCtl.Replaced(ctl) |
||||
} |
||||
cm.ctlsByRunId[runId] = ctl |
||||
return |
||||
} |
||||
|
||||
func (cm *ControlManager) GetById(runId string) (ctl *Control, ok bool) { |
||||
cm.mu.RLock() |
||||
defer cm.mu.RUnlock() |
||||
ctl, ok = cm.ctlsByRunId[runId] |
||||
return |
||||
} |
||||
|
||||
type ProxyManager struct { |
||||
// proxies indexed by proxy name
|
||||
pxys map[string]Proxy |
||||
|
||||
mu sync.RWMutex |
||||
} |
||||
|
||||
func NewProxyManager() *ProxyManager { |
||||
return &ProxyManager{ |
||||
pxys: make(map[string]Proxy), |
||||
} |
||||
} |
||||
|
||||
func (pm *ProxyManager) Add(name string, pxy Proxy) error { |
||||
pm.mu.Lock() |
||||
defer pm.mu.Unlock() |
||||
if _, ok := pm.pxys[name]; ok { |
||||
return fmt.Errorf("proxy name [%s] is already in use", name) |
||||
} |
||||
|
||||
pm.pxys[name] = pxy |
||||
return nil |
||||
} |
||||
|
||||
func (pm *ProxyManager) Del(name string) { |
||||
pm.mu.Lock() |
||||
defer pm.mu.Unlock() |
||||
delete(pm.pxys, name) |
||||
} |
||||
|
||||
func (pm *ProxyManager) GetByName(name string) (pxy Proxy, ok bool) { |
||||
pm.mu.RLock() |
||||
defer pm.mu.RUnlock() |
||||
pxy, ok = pm.pxys[name] |
||||
return |
||||
} |
@ -0,0 +1,219 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/models/msg" |
||||
"github.com/fatedier/frp/models/proto/tcp" |
||||
"github.com/fatedier/frp/utils/log" |
||||
"github.com/fatedier/frp/utils/net" |
||||
) |
||||
|
||||
type Proxy interface { |
||||
Run() error |
||||
GetControl() *Control |
||||
GetName() string |
||||
GetConf() config.ProxyConf |
||||
Close() |
||||
log.Logger |
||||
} |
||||
|
||||
type BaseProxy struct { |
||||
name string |
||||
ctl *Control |
||||
listeners []net.Listener |
||||
log.Logger |
||||
} |
||||
|
||||
func (pxy *BaseProxy) GetName() string { |
||||
return pxy.name |
||||
} |
||||
|
||||
func (pxy *BaseProxy) GetControl() *Control { |
||||
return pxy.ctl |
||||
} |
||||
|
||||
func (pxy *BaseProxy) Close() { |
||||
pxy.Info("proxy closing") |
||||
for _, l := range pxy.listeners { |
||||
l.Close() |
||||
} |
||||
} |
||||
|
||||
func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) { |
||||
basePxy := BaseProxy{ |
||||
name: pxyConf.GetName(), |
||||
ctl: ctl, |
||||
listeners: make([]net.Listener, 0), |
||||
Logger: log.NewPrefixLogger(ctl.runId), |
||||
} |
||||
switch cfg := pxyConf.(type) { |
||||
case *config.TcpProxyConf: |
||||
pxy = &TcpProxy{ |
||||
BaseProxy: basePxy, |
||||
cfg: cfg, |
||||
} |
||||
case *config.HttpProxyConf: |
||||
pxy = &HttpProxy{ |
||||
BaseProxy: basePxy, |
||||
cfg: cfg, |
||||
} |
||||
case *config.HttpsProxyConf: |
||||
pxy = &HttpsProxy{ |
||||
BaseProxy: basePxy, |
||||
cfg: cfg, |
||||
} |
||||
case *config.UdpProxyConf: |
||||
pxy = &UdpProxy{ |
||||
BaseProxy: basePxy, |
||||
cfg: cfg, |
||||
} |
||||
default: |
||||
return pxy, fmt.Errorf("proxy type not support") |
||||
} |
||||
pxy.AddLogPrefix(pxy.GetName()) |
||||
return |
||||
} |
||||
|
||||
type TcpProxy struct { |
||||
BaseProxy |
||||
cfg *config.TcpProxyConf |
||||
} |
||||
|
||||
func (pxy *TcpProxy) Run() error { |
||||
listener, err := net.ListenTcp(config.ServerCommonCfg.BindAddr, int64(pxy.cfg.RemotePort)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
pxy.listeners = append(pxy.listeners, listener) |
||||
|
||||
go func(l net.Listener) { |
||||
for { |
||||
// block
|
||||
// if listener is closed, err returned
|
||||
c, err := l.Accept() |
||||
if err != nil { |
||||
pxy.Info("listener is closed") |
||||
return |
||||
} |
||||
pxy.Debug("got one user connection [%s]", c.RemoteAddr().String()) |
||||
go HandleUserTcpConnection(pxy, c) |
||||
} |
||||
}(listener) |
||||
return nil |
||||
} |
||||
|
||||
func (pxy *TcpProxy) GetConf() config.ProxyConf { |
||||
return pxy.cfg |
||||
} |
||||
|
||||
func (pxy *TcpProxy) Close() { |
||||
pxy.BaseProxy.Close() |
||||
} |
||||
|
||||
type HttpProxy struct { |
||||
BaseProxy |
||||
cfg *config.HttpProxyConf |
||||
} |
||||
|
||||
func (pxy *HttpProxy) Run() (err error) { |
||||
return |
||||
} |
||||
|
||||
func (pxy *HttpProxy) GetConf() config.ProxyConf { |
||||
return pxy.cfg |
||||
} |
||||
|
||||
func (pxy *HttpProxy) Close() { |
||||
pxy.BaseProxy.Close() |
||||
} |
||||
|
||||
type HttpsProxy struct { |
||||
BaseProxy |
||||
cfg *config.HttpsProxyConf |
||||
} |
||||
|
||||
func (pxy *HttpsProxy) Run() (err error) { |
||||
return |
||||
} |
||||
|
||||
func (pxy *HttpsProxy) GetConf() config.ProxyConf { |
||||
return pxy.cfg |
||||
} |
||||
|
||||
func (pxy *HttpsProxy) Close() { |
||||
pxy.BaseProxy.Close() |
||||
} |
||||
|
||||
type UdpProxy struct { |
||||
BaseProxy |
||||
cfg *config.UdpProxyConf |
||||
} |
||||
|
||||
func (pxy *UdpProxy) Run() (err error) { |
||||
return |
||||
} |
||||
|
||||
func (pxy *UdpProxy) GetConf() config.ProxyConf { |
||||
return pxy.cfg |
||||
} |
||||
|
||||
func (pxy *UdpProxy) Close() { |
||||
pxy.BaseProxy.Close() |
||||
} |
||||
|
||||
// HandleUserTcpConnection is used for incoming tcp user connections.
|
||||
func HandleUserTcpConnection(pxy Proxy, userConn net.Conn) { |
||||
defer userConn.Close() |
||||
ctl := pxy.GetControl() |
||||
var ( |
||||
workConn net.Conn |
||||
err error |
||||
) |
||||
// try all connections from the pool
|
||||
for i := 0; i < ctl.poolCount+1; i++ { |
||||
if workConn, err = ctl.GetWorkConn(); err != nil { |
||||
pxy.Warn("failed to get work connection: %v", err) |
||||
return |
||||
} |
||||
defer workConn.Close() |
||||
pxy.Info("get one new work connection: %s", workConn.RemoteAddr().String()) |
||||
workConn.AddLogPrefix(pxy.GetName()) |
||||
|
||||
err := msg.WriteMsg(workConn, &msg.StartWorkConn{ |
||||
ProxyName: pxy.GetName(), |
||||
}) |
||||
if err != nil { |
||||
workConn.Warn("failed to send message to work connection from pool: %v, times: %d", err, i) |
||||
workConn.Close() |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
|
||||
if err != nil { |
||||
pxy.Error("try to get work connection failed in the end") |
||||
return |
||||
} |
||||
|
||||
var ( |
||||
local io.ReadWriteCloser |
||||
remote io.ReadWriteCloser |
||||
) |
||||
local = workConn |
||||
remote = userConn |
||||
cfg := pxy.GetConf().GetBaseInfo() |
||||
if cfg.UseEncryption { |
||||
local, err = tcp.WithEncryption(local, []byte(config.ServerCommonCfg.PrivilegeToken)) |
||||
if err != nil { |
||||
pxy.Error("create encryption stream error: %v", err) |
||||
return |
||||
} |
||||
} |
||||
if cfg.UseCompression { |
||||
local = tcp.WithCompression(local) |
||||
} |
||||
tcp.Join(local, remote) |
||||
} |
@ -0,0 +1,196 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"fmt" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/assets" |
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/models/msg" |
||||
"github.com/fatedier/frp/utils/log" |
||||
"github.com/fatedier/frp/utils/net" |
||||
"github.com/fatedier/frp/utils/util" |
||||
"github.com/fatedier/frp/utils/version" |
||||
"github.com/fatedier/frp/utils/vhost" |
||||
) |
||||
|
||||
// Server service.
|
||||
type Service struct { |
||||
// Accept connections from client.
|
||||
listener net.Listener |
||||
|
||||
// For http proxies, route requests to different clients by hostname and other infomation.
|
||||
VhostHttpMuxer *vhost.HttpMuxer |
||||
|
||||
// For https proxies, route requests to different clients by hostname and other infomation.
|
||||
VhostHttpsMuxer *vhost.HttpsMuxer |
||||
|
||||
// Manage all controllers.
|
||||
ctlManager *ControlManager |
||||
|
||||
// Manage all proxies.
|
||||
pxyManager *ProxyManager |
||||
} |
||||
|
||||
func NewService() (svr *Service, err error) { |
||||
svr = &Service{ |
||||
ctlManager: NewControlManager(), |
||||
pxyManager: NewProxyManager(), |
||||
} |
||||
|
||||
// Init assets.
|
||||
err = assets.Load(config.ServerCommonCfg.AssetsDir) |
||||
if err != nil { |
||||
err = fmt.Errorf("Load assets error: %v", err) |
||||
return |
||||
} |
||||
|
||||
// Listen for accepting connections from client.
|
||||
svr.listener, err = net.ListenTcp(config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.BindPort) |
||||
if err != nil { |
||||
err = fmt.Errorf("Create server listener error, %v", err) |
||||
return |
||||
} |
||||
|
||||
// Create http vhost muxer.
|
||||
if config.ServerCommonCfg.VhostHttpPort != 0 { |
||||
var l net.Listener |
||||
l, err = net.ListenTcp(config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.VhostHttpPort) |
||||
if err != nil { |
||||
err = fmt.Errorf("Create vhost http listener error, %v", err) |
||||
return |
||||
} |
||||
svr.VhostHttpMuxer, err = vhost.NewHttpMuxer(l, 30*time.Second) |
||||
if err != nil { |
||||
err = fmt.Errorf("Create vhost httpMuxer error, %v", err) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// Create https vhost muxer.
|
||||
if config.ServerCommonCfg.VhostHttpsPort != 0 { |
||||
var l net.Listener |
||||
l, err = net.ListenTcp(config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.VhostHttpsPort) |
||||
if err != nil { |
||||
err = fmt.Errorf("Create vhost https listener error, %v", err) |
||||
return |
||||
} |
||||
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(l, 30*time.Second) |
||||
if err != nil { |
||||
err = fmt.Errorf("Create vhost httpsMuxer error, %v", err) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// Create dashboard web server.
|
||||
if config.ServerCommonCfg.DashboardPort != 0 { |
||||
err = RunDashboardServer(config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.DashboardPort) |
||||
if err != nil { |
||||
err = fmt.Errorf("Create dashboard web server error, %v", err) |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (svr *Service) Run() { |
||||
// Listen for incoming connections from client.
|
||||
for { |
||||
c, err := svr.listener.Accept() |
||||
if err != nil { |
||||
log.Warn("Listener for incoming connections from client closed") |
||||
return |
||||
} |
||||
|
||||
// Start a new goroutine for dealing connections.
|
||||
go func(frpConn net.Conn) { |
||||
var rawMsg msg.Message |
||||
if rawMsg, err = msg.ReadMsg(frpConn); err != nil { |
||||
log.Warn("Failed to read message: %v", err) |
||||
frpConn.Close() |
||||
return |
||||
} |
||||
|
||||
switch m := rawMsg.(type) { |
||||
case *msg.Login: |
||||
err = svr.RegisterControl(frpConn, m) |
||||
// If login failed, send error message there.
|
||||
// Otherwise send success message in control's work goroutine.
|
||||
if err != nil { |
||||
frpConn.Warn("%v", err) |
||||
msg.WriteMsg(frpConn, &msg.LoginResp{ |
||||
Version: version.Full(), |
||||
Error: err.Error(), |
||||
}) |
||||
frpConn.Close() |
||||
} |
||||
case *msg.NewWorkConn: |
||||
svr.RegisterWorkConn(frpConn, m) |
||||
default: |
||||
log.Warn("Error message type for the new connection [%s]", frpConn.RemoteAddr().String()) |
||||
frpConn.Close() |
||||
} |
||||
}(c) |
||||
} |
||||
} |
||||
|
||||
func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err error) { |
||||
ctlConn.Info("client login info: ip [%s] version [%s] hostname [%s] os [%s] arch [%s]", |
||||
ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch) |
||||
|
||||
// Check client version.
|
||||
if ok, msg := version.Compat(loginMsg.Version); !ok { |
||||
err = fmt.Errorf("%s", msg) |
||||
return |
||||
} |
||||
|
||||
// Check auth.
|
||||
nowTime := time.Now().Unix() |
||||
if config.ServerCommonCfg.AuthTimeout != 0 && nowTime-loginMsg.Timestamp > config.ServerCommonCfg.AuthTimeout { |
||||
err = fmt.Errorf("authorization timeout") |
||||
return |
||||
} |
||||
if util.GetAuthKey(config.ServerCommonCfg.PrivilegeToken, loginMsg.Timestamp) != loginMsg.PrivilegeKey { |
||||
err = fmt.Errorf("authorization failed") |
||||
return |
||||
} |
||||
|
||||
// If client's RunId is empty, it's a new client, we just create a new controller.
|
||||
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
|
||||
if loginMsg.RunId == "" { |
||||
loginMsg.RunId, err = util.RandId() |
||||
if err != nil { |
||||
return |
||||
} |
||||
} |
||||
|
||||
ctl := NewControl(svr, ctlConn, loginMsg) |
||||
|
||||
if oldCtl := svr.ctlManager.Add(loginMsg.RunId, ctl); oldCtl != nil { |
||||
oldCtl.allShutdown.WaitDown() |
||||
} |
||||
|
||||
ctlConn.AddLogPrefix(loginMsg.RunId) |
||||
ctl.Start() |
||||
return |
||||
} |
||||
|
||||
// RegisterWorkConn register a new work connection to control and proxies need it.
|
||||
func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) { |
||||
ctl, exist := svr.ctlManager.GetById(newMsg.RunId) |
||||
if !exist { |
||||
workConn.Warn("No client control found for run id [%s]", newMsg.RunId) |
||||
return |
||||
} |
||||
ctl.RegisterWorkConn(workConn) |
||||
return |
||||
} |
||||
|
||||
func (svr *Service) RegisterProxy(name string, pxy Proxy) error { |
||||
err := svr.pxyManager.Add(name, pxy) |
||||
return err |
||||
} |
||||
|
||||
func (svr *Service) DelProxy(name string) { |
||||
svr.pxyManager.Del(name) |
||||
} |
File diff suppressed because one or more lines are too long
@ -1,231 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"io" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/src/models/client" |
||||
"github.com/fatedier/frp/src/models/consts" |
||||
"github.com/fatedier/frp/src/models/msg" |
||||
"github.com/fatedier/frp/src/utils/conn" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/pcrypto" |
||||
) |
||||
|
||||
func ControlProcess(cli *client.ProxyClient, wait *sync.WaitGroup) { |
||||
defer wait.Done() |
||||
|
||||
msgSendChan := make(chan interface{}, 1024) |
||||
|
||||
c, err := loginToServer(cli) |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], connect to server failed!", cli.Name) |
||||
return |
||||
} |
||||
defer c.Close() |
||||
|
||||
go heartbeatSender(c, msgSendChan) |
||||
|
||||
go msgSender(cli, c, msgSendChan) |
||||
msgReader(cli, c, msgSendChan) |
||||
|
||||
close(msgSendChan) |
||||
} |
||||
|
||||
// loop for reading messages from frpc after control connection is established
|
||||
func msgReader(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface{}) error { |
||||
// for heartbeat
|
||||
var heartbeatTimeout bool = false |
||||
timer := time.AfterFunc(time.Duration(client.HeartBeatTimeout)*time.Second, func() { |
||||
heartbeatTimeout = true |
||||
if c != nil { |
||||
c.Close() |
||||
} |
||||
if cli != nil { |
||||
// if it's not udp type, nothing will happen
|
||||
cli.CloseUdpTunnel() |
||||
cli.SetCloseFlag(true) |
||||
} |
||||
log.Error("ProxyName [%s], heartbeatRes from frps timeout", cli.Name) |
||||
}) |
||||
defer timer.Stop() |
||||
|
||||
for { |
||||
buf, err := c.ReadLine() |
||||
if err == io.EOF || c.IsClosed() { |
||||
timer.Stop() |
||||
c.Close() |
||||
cli.SetCloseFlag(true) |
||||
log.Warn("ProxyName [%s], frps close this control conn!", cli.Name) |
||||
var delayTime time.Duration = 1 |
||||
|
||||
// loop until reconnect to frps
|
||||
for { |
||||
log.Info("ProxyName [%s], try to reconnect to frps [%s:%d]...", cli.Name, client.ServerAddr, client.ServerPort) |
||||
c, err = loginToServer(cli) |
||||
if err == nil { |
||||
close(msgSendChan) |
||||
msgSendChan = make(chan interface{}, 1024) |
||||
go heartbeatSender(c, msgSendChan) |
||||
go msgSender(cli, c, msgSendChan) |
||||
cli.SetCloseFlag(false) |
||||
break |
||||
} |
||||
|
||||
if delayTime < 30 { |
||||
delayTime = delayTime * 2 |
||||
} else { |
||||
delayTime = 30 |
||||
} |
||||
time.Sleep(delayTime * time.Second) |
||||
} |
||||
continue |
||||
} else if err != nil { |
||||
log.Warn("ProxyName [%s], read from frps error: %v", cli.Name, err) |
||||
continue |
||||
} |
||||
|
||||
ctlRes := &msg.ControlRes{} |
||||
if err := json.Unmarshal([]byte(buf), &ctlRes); err != nil { |
||||
log.Warn("ProxyName [%s], parse msg from frps error: %v : %s", cli.Name, err, buf) |
||||
continue |
||||
} |
||||
|
||||
switch ctlRes.Type { |
||||
case consts.HeartbeatRes: |
||||
log.Debug("ProxyName [%s], receive heartbeat response", cli.Name) |
||||
timer.Reset(time.Duration(client.HeartBeatTimeout) * time.Second) |
||||
case consts.NoticeUserConn: |
||||
log.Debug("ProxyName [%s], new user connection", cli.Name) |
||||
// join local and remote connections, async
|
||||
go cli.StartTunnel(client.ServerAddr, client.ServerPort) |
||||
default: |
||||
log.Warn("ProxyName [%s}, unsupport msgType [%d]", cli.Name, ctlRes.Type) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// loop for sending messages from channel to frps
|
||||
func msgSender(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface{}) { |
||||
for { |
||||
msg, ok := <-msgSendChan |
||||
if !ok { |
||||
break |
||||
} |
||||
|
||||
buf, _ := json.Marshal(msg) |
||||
err := c.WriteString(string(buf) + "\n") |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], write to server error, proxy exit", cli.Name) |
||||
c.Close() |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) { |
||||
if client.HttpProxy == "" { |
||||
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", client.ServerAddr, client.ServerPort)) |
||||
} else { |
||||
c, err = conn.ConnectServerByHttpProxy(client.HttpProxy, fmt.Sprintf("%s:%d", client.ServerAddr, client.ServerPort)) |
||||
} |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, client.ServerAddr, client.ServerPort, err) |
||||
return |
||||
} |
||||
|
||||
nowTime := time.Now().Unix() |
||||
req := &msg.ControlReq{ |
||||
Type: consts.NewCtlConn, |
||||
ProxyName: cli.Name, |
||||
UseEncryption: cli.UseEncryption, |
||||
UseGzip: cli.UseGzip, |
||||
PrivilegeMode: cli.PrivilegeMode, |
||||
ProxyType: cli.Type, |
||||
PoolCount: cli.PoolCount, |
||||
HostHeaderRewrite: cli.HostHeaderRewrite, |
||||
HttpUserName: cli.HttpUserName, |
||||
HttpPassWord: cli.HttpPassWord, |
||||
SubDomain: cli.SubDomain, |
||||
Timestamp: nowTime, |
||||
} |
||||
if cli.PrivilegeMode { |
||||
privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime)) |
||||
req.RemotePort = cli.RemotePort |
||||
req.CustomDomains = cli.CustomDomains |
||||
req.Locations = cli.Locations |
||||
req.PrivilegeKey = privilegeKey |
||||
} else { |
||||
authKey := pcrypto.GetAuthKey(cli.Name + cli.AuthToken + fmt.Sprintf("%d", nowTime)) |
||||
req.AuthKey = authKey |
||||
} |
||||
|
||||
buf, _ := json.Marshal(req) |
||||
err = c.WriteString(string(buf) + "\n") |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], write to server error, %v", cli.Name, err) |
||||
return |
||||
} |
||||
|
||||
res, err := c.ReadLine() |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], read from server error, %v", cli.Name, err) |
||||
return |
||||
} |
||||
log.Debug("ProxyName [%s], read [%s]", cli.Name, res) |
||||
|
||||
ctlRes := &msg.ControlRes{} |
||||
if err = json.Unmarshal([]byte(res), &ctlRes); err != nil { |
||||
log.Error("ProxyName [%s], format server response error, %v", cli.Name, err) |
||||
return |
||||
} |
||||
|
||||
if ctlRes.Code != 0 { |
||||
log.Error("ProxyName [%s], start proxy error, %s", cli.Name, ctlRes.Msg) |
||||
return c, fmt.Errorf("%s", ctlRes.Msg) |
||||
} |
||||
|
||||
log.Info("ProxyName [%s], connect to server [%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort) |
||||
|
||||
if cli.Type == "udp" { |
||||
// we only need one udp work connection
|
||||
// all udp messages will be forwarded throngh this connection
|
||||
go cli.StartUdpTunnelOnce(client.ServerAddr, client.ServerPort) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func heartbeatSender(c *conn.Conn, msgSendChan chan interface{}) { |
||||
heartbeatReq := &msg.ControlReq{ |
||||
Type: consts.HeartbeatReq, |
||||
} |
||||
log.Info("Start to send heartbeat to frps") |
||||
for { |
||||
time.Sleep(time.Duration(client.HeartBeatInterval) * time.Second) |
||||
if c != nil && !c.IsClosed() { |
||||
log.Debug("Send heartbeat to server") |
||||
msgSendChan <- heartbeatReq |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
log.Info("Heartbeat goroutine exit") |
||||
} |
@ -1,365 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"io" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/src/models/consts" |
||||
"github.com/fatedier/frp/src/models/metric" |
||||
"github.com/fatedier/frp/src/models/msg" |
||||
"github.com/fatedier/frp/src/models/server" |
||||
"github.com/fatedier/frp/src/utils/conn" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/pcrypto" |
||||
) |
||||
|
||||
func ProcessControlConn(l *conn.Listener) { |
||||
for { |
||||
c, err := l.Accept() |
||||
if err != nil { |
||||
return |
||||
} |
||||
log.Debug("Get new connection, %v", c.GetRemoteAddr()) |
||||
go controlWorker(c) |
||||
} |
||||
} |
||||
|
||||
// connection from every client and server
|
||||
func controlWorker(c *conn.Conn) { |
||||
// if login message type is NewWorkConn, don't close this connection
|
||||
var closeFlag bool = true |
||||
var s *server.ProxyServer |
||||
defer func() { |
||||
if closeFlag { |
||||
c.Close() |
||||
if s != nil { |
||||
s.Close() |
||||
} |
||||
} |
||||
}() |
||||
|
||||
// get login message
|
||||
buf, err := c.ReadLine() |
||||
if err != nil { |
||||
log.Warn("Read error, %v", err) |
||||
return |
||||
} |
||||
log.Debug("Get msg from frpc: %s", buf) |
||||
|
||||
cliReq := &msg.ControlReq{} |
||||
if err := json.Unmarshal([]byte(buf), &cliReq); err != nil { |
||||
log.Warn("Parse msg from frpc error: %v : %s", err, buf) |
||||
return |
||||
} |
||||
|
||||
// login when type is NewCtlConn or NewWorkConn
|
||||
ret, info, s := doLogin(cliReq, c) |
||||
// if login type is NewWorkConn, nothing will be send to frpc
|
||||
if cliReq.Type == consts.NewCtlConn { |
||||
cliRes := &msg.ControlRes{ |
||||
Type: consts.NewCtlConnRes, |
||||
Code: ret, |
||||
Msg: info, |
||||
} |
||||
byteBuf, _ := json.Marshal(cliRes) |
||||
err = c.WriteString(string(byteBuf) + "\n") |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], write to client error, proxy exit", cliReq.ProxyName) |
||||
return |
||||
} |
||||
} else { |
||||
if ret == 0 { |
||||
closeFlag = false |
||||
} |
||||
return |
||||
} |
||||
|
||||
// if login failed, just return
|
||||
if ret > 0 { |
||||
return |
||||
} |
||||
|
||||
// create a channel for sending messages
|
||||
msgSendChan := make(chan interface{}, 1024) |
||||
go msgSender(s, c, msgSendChan) |
||||
go noticeUserConn(s, msgSendChan) |
||||
|
||||
// loop for reading control messages from frpc and deal with different types
|
||||
msgReader(s, c, msgSendChan) |
||||
|
||||
close(msgSendChan) |
||||
log.Info("ProxyName [%s], I'm dead!", s.Name) |
||||
return |
||||
} |
||||
|
||||
// when frps get one new user connection, send NoticeUserConn message to frpc and accept one new WorkConn later
|
||||
func noticeUserConn(s *server.ProxyServer, msgSendChan chan interface{}) { |
||||
for { |
||||
closeFlag := s.WaitUserConn() |
||||
if closeFlag { |
||||
log.Debug("ProxyName [%s], goroutine for noticing user conn is closed", s.Name) |
||||
break |
||||
} |
||||
notice := &msg.ControlRes{ |
||||
Type: consts.NoticeUserConn, |
||||
} |
||||
msgSendChan <- notice |
||||
log.Debug("ProxyName [%s], notice client to add work conn", s.Name) |
||||
} |
||||
} |
||||
|
||||
// loop for reading messages from frpc after control connection is established
|
||||
func msgReader(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}) error { |
||||
// for heartbeat
|
||||
var heartbeatTimeout bool = false |
||||
timer := time.AfterFunc(time.Duration(server.HeartBeatTimeout)*time.Second, func() { |
||||
heartbeatTimeout = true |
||||
s.Close() |
||||
log.Error("ProxyName [%s], client heartbeat timeout", s.Name) |
||||
}) |
||||
defer timer.Stop() |
||||
|
||||
for { |
||||
buf, err := c.ReadLine() |
||||
if err != nil { |
||||
if err == io.EOF { |
||||
log.Warn("ProxyName [%s], client is dead!", s.Name) |
||||
s.Close() |
||||
return err |
||||
} else if c == nil || c.IsClosed() { |
||||
log.Warn("ProxyName [%s], client connection is closed", s.Name) |
||||
s.Close() |
||||
return err |
||||
} |
||||
log.Warn("ProxyName [%s], read error: %v", s.Name, err) |
||||
continue |
||||
} |
||||
|
||||
cliReq := &msg.ControlReq{} |
||||
if err := json.Unmarshal([]byte(buf), &cliReq); err != nil { |
||||
log.Warn("ProxyName [%s], parse msg from frpc error: %v : %s", s.Name, err, buf) |
||||
continue |
||||
} |
||||
|
||||
switch cliReq.Type { |
||||
case consts.HeartbeatReq: |
||||
log.Debug("ProxyName [%s], get heartbeat", s.Name) |
||||
timer.Reset(time.Duration(server.HeartBeatTimeout) * time.Second) |
||||
heartbeatRes := &msg.ControlRes{ |
||||
Type: consts.HeartbeatRes, |
||||
} |
||||
msgSendChan <- heartbeatRes |
||||
default: |
||||
log.Warn("ProxyName [%s}, unsupport msgType [%d]", s.Name, cliReq.Type) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// loop for sending messages from channel to frpc
|
||||
func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}) { |
||||
for { |
||||
msg, ok := <-msgSendChan |
||||
if !ok { |
||||
break |
||||
} |
||||
|
||||
buf, _ := json.Marshal(msg) |
||||
err := c.WriteString(string(buf) + "\n") |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], write to client error, proxy exit", s.Name) |
||||
s.Close() |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
// if success, ret equals 0, otherwise greater than 0
|
||||
// NewCtlConn
|
||||
// NewWorkConn
|
||||
// NewWorkConnUdp
|
||||
func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string, s *server.ProxyServer) { |
||||
ret = 1 |
||||
// check if PrivilegeMode is enabled
|
||||
if req.PrivilegeMode && !server.PrivilegeMode { |
||||
info = fmt.Sprintf("ProxyName [%s], PrivilegeMode is disabled in frps", req.ProxyName) |
||||
log.Warn("info") |
||||
return |
||||
} |
||||
|
||||
var ok bool |
||||
s, ok = server.GetProxyServer(req.ProxyName) |
||||
if req.PrivilegeMode && req.Type == consts.NewCtlConn { |
||||
log.Debug("ProxyName [%s], doLogin and privilege mode is enabled", req.ProxyName) |
||||
} else { |
||||
if !ok { |
||||
info = fmt.Sprintf("ProxyName [%s] is not exist", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// check authKey or privilegeKey
|
||||
nowTime := time.Now().Unix() |
||||
if req.PrivilegeMode { |
||||
privilegeKey := pcrypto.GetAuthKey(req.ProxyName + server.PrivilegeToken + fmt.Sprintf("%d", req.Timestamp)) |
||||
// privilegeKey unavaiable after server.AuthTimeout minutes
|
||||
if server.AuthTimeout != 0 && nowTime-req.Timestamp > server.AuthTimeout { |
||||
info = fmt.Sprintf("ProxyName [%s], privilege mode authorization timeout", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} else if req.PrivilegeKey != privilegeKey { |
||||
info = fmt.Sprintf("ProxyName [%s], privilege mode authorization failed", req.ProxyName) |
||||
log.Warn(info) |
||||
log.Debug("PrivilegeKey [%s] and get [%s]", privilegeKey, req.PrivilegeKey) |
||||
return |
||||
} |
||||
} else { |
||||
authKey := pcrypto.GetAuthKey(req.ProxyName + s.AuthToken + fmt.Sprintf("%d", req.Timestamp)) |
||||
if server.AuthTimeout != 0 && nowTime-req.Timestamp > server.AuthTimeout { |
||||
info = fmt.Sprintf("ProxyName [%s], authorization timeout", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} else if req.AuthKey != authKey { |
||||
info = fmt.Sprintf("ProxyName [%s], authorization failed", req.ProxyName) |
||||
log.Warn(info) |
||||
log.Debug("AuthKey [%s] and get [%s]", authKey, req.AuthKey) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// control conn
|
||||
if req.Type == consts.NewCtlConn { |
||||
if req.PrivilegeMode { |
||||
s = server.NewProxyServerFromCtlMsg(req) |
||||
// we check listen_port if privilege_allow_ports are set
|
||||
// and PrivilegeMode is enabled
|
||||
if s.Type == "tcp" { |
||||
if len(server.PrivilegeAllowPorts) != 0 { |
||||
_, ok := server.PrivilegeAllowPorts[s.ListenPort] |
||||
if !ok { |
||||
info = fmt.Sprintf("ProxyName [%s], remote_port [%d] isn't allowed", req.ProxyName, s.ListenPort) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
} |
||||
} else if s.Type == "http" || s.Type == "https" { |
||||
for _, domain := range s.CustomDomains { |
||||
if server.SubDomainHost != "" && strings.Contains(domain, server.SubDomainHost) { |
||||
info = fmt.Sprintf("ProxyName [%s], custom domain [%s] should not belong to subdomain_host [%s]", req.ProxyName, domain, server.SubDomainHost) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
} |
||||
|
||||
if s.SubDomain != "" { |
||||
if strings.Contains(s.SubDomain, ".") || strings.Contains(s.SubDomain, "*") { |
||||
info = fmt.Sprintf("ProxyName [%s], '.' and '*' is not supported in subdomain", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
if server.SubDomainHost == "" { |
||||
info = fmt.Sprintf("ProxyName [%s], subdomain is not supported because this feature is not enabled by remote server", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
s.SubDomain = s.SubDomain + "." + server.SubDomainHost |
||||
} |
||||
} |
||||
err := server.CreateProxy(s) |
||||
if err != nil { |
||||
info = fmt.Sprintf("ProxyName [%s], %v", req.ProxyName, err) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// check if vhost_port is set
|
||||
if s.Type == "http" && server.VhostHttpMuxer == nil { |
||||
info = fmt.Sprintf("ProxyName [%s], type [http] not support when vhost_http_port is not set", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
if s.Type == "https" && server.VhostHttpsMuxer == nil { |
||||
info = fmt.Sprintf("ProxyName [%s], type [https] not support when vhost_https_port is not set", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
|
||||
// set infomations from frpc
|
||||
s.BindAddr = server.BindAddr |
||||
s.UseEncryption = req.UseEncryption |
||||
s.UseGzip = req.UseGzip |
||||
s.HostHeaderRewrite = req.HostHeaderRewrite |
||||
s.HttpUserName = req.HttpUserName |
||||
s.HttpPassWord = req.HttpPassWord |
||||
|
||||
if req.PoolCount > server.MaxPoolCount { |
||||
s.PoolCount = server.MaxPoolCount |
||||
} else if req.PoolCount < 0 { |
||||
s.PoolCount = 0 |
||||
} else { |
||||
s.PoolCount = req.PoolCount |
||||
} |
||||
|
||||
if s.Status == consts.Working { |
||||
info = fmt.Sprintf("ProxyName [%s], already in use", req.ProxyName) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
|
||||
// update metric's proxy status
|
||||
metric.SetProxyInfo(s.Name, s.Type, s.BindAddr, s.UseEncryption, s.UseGzip, s.PrivilegeMode, s.CustomDomains, s.Locations, s.ListenPort) |
||||
|
||||
// start proxy and listen for user connections, no block
|
||||
err := s.Start(c) |
||||
if err != nil { |
||||
info = fmt.Sprintf("ProxyName [%s], start proxy error: %v", req.ProxyName, err) |
||||
log.Warn(info) |
||||
return |
||||
} |
||||
log.Info("ProxyName [%s], start proxy success", req.ProxyName) |
||||
if req.PrivilegeMode { |
||||
log.Info("ProxyName [%s], created by PrivilegeMode", req.ProxyName) |
||||
} |
||||
} else if req.Type == consts.NewWorkConn { |
||||
// work conn
|
||||
if s.Status != consts.Working { |
||||
log.Warn("ProxyName [%s], is not working when it gets one new work connnection", req.ProxyName) |
||||
return |
||||
} |
||||
// the connection will close after join over
|
||||
s.RegisterNewWorkConn(c) |
||||
} else if req.Type == consts.NewWorkConnUdp { |
||||
// work conn for udp
|
||||
if s.Status != consts.Working { |
||||
log.Warn("ProxyName [%s], is not working when it gets one new work connnection for udp", req.ProxyName) |
||||
return |
||||
} |
||||
s.RegisterNewWorkConnUdp(c) |
||||
} else { |
||||
info = fmt.Sprintf("Unsupport login message type [%d]", req.Type) |
||||
log.Warn("Unsupport login message type [%d]", req.Type) |
||||
return |
||||
} |
||||
|
||||
ret = 0 |
||||
return |
||||
} |
@ -1,199 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
docopt "github.com/docopt/docopt-go" |
||||
|
||||
"github.com/fatedier/frp/src/assets" |
||||
"github.com/fatedier/frp/src/models/server" |
||||
"github.com/fatedier/frp/src/utils/conn" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/version" |
||||
"github.com/fatedier/frp/src/utils/vhost" |
||||
) |
||||
|
||||
var usage string = `frps is the server of frp |
||||
|
||||
Usage:
|
||||
frps [-c config_file] [-L log_file] [--log-level=<log_level>] [--addr=<bind_addr>] |
||||
frps [-c config_file] --reload |
||||
frps -h | --help |
||||
frps -v | --version |
||||
|
||||
Options: |
||||
-c config_file set config file |
||||
-L log_file set output log file, including console |
||||
--log-level=<log_level> set log level: debug, info, warn, error |
||||
--addr=<bind_addr> listen addr for client, example: 0.0.0.0:7000 |
||||
--reload reload ini file and configures in common section won't be changed |
||||
-h --help show this screen |
||||
-v --version show version |
||||
` |
||||
|
||||
func main() { |
||||
// the configures parsed from file will be replaced by those from command line if exist
|
||||
args, err := docopt.Parse(usage, nil, true, version.Full(), false) |
||||
|
||||
if args["-c"] != nil { |
||||
server.ConfigFile = args["-c"].(string) |
||||
} |
||||
err = server.LoadConf(server.ConfigFile) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
os.Exit(-1) |
||||
} |
||||
|
||||
// reload check
|
||||
if args["--reload"] != nil { |
||||
if args["--reload"].(bool) { |
||||
req, err := http.NewRequest("GET", "http://"+server.BindAddr+":"+fmt.Sprintf("%d", server.DashboardPort)+"/api/reload", nil) |
||||
if err != nil { |
||||
fmt.Printf("frps reload error: %v\n", err) |
||||
os.Exit(1) |
||||
} |
||||
|
||||
authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(server.DashboardUsername+":"+server.DashboardPassword)) |
||||
|
||||
req.Header.Add("Authorization", authStr) |
||||
defaultClient := &http.Client{} |
||||
resp, err := defaultClient.Do(req) |
||||
|
||||
if err != nil { |
||||
fmt.Printf("frps reload error: %v\n", err) |
||||
os.Exit(1) |
||||
} else { |
||||
defer resp.Body.Close() |
||||
body, err := ioutil.ReadAll(resp.Body) |
||||
if err != nil { |
||||
fmt.Printf("frps reload error: %v\n", err) |
||||
os.Exit(1) |
||||
} |
||||
res := &server.GeneralResponse{} |
||||
err = json.Unmarshal(body, &res) |
||||
if err != nil { |
||||
fmt.Printf("http response error: %s\n", strings.TrimSpace(string(body))) |
||||
os.Exit(1) |
||||
} else if res.Code != 0 { |
||||
fmt.Printf("reload error: %s\n", res.Msg) |
||||
os.Exit(1) |
||||
} |
||||
fmt.Printf("reload success\n") |
||||
os.Exit(0) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if args["-L"] != nil { |
||||
if args["-L"].(string) == "console" { |
||||
server.LogWay = "console" |
||||
} else { |
||||
server.LogWay = "file" |
||||
server.LogFile = args["-L"].(string) |
||||
} |
||||
} |
||||
|
||||
if args["--log-level"] != nil { |
||||
server.LogLevel = args["--log-level"].(string) |
||||
} |
||||
|
||||
if args["--addr"] != nil { |
||||
addr := strings.Split(args["--addr"].(string), ":") |
||||
if len(addr) != 2 { |
||||
fmt.Println("--addr format error: example 0.0.0.0:7000") |
||||
os.Exit(1) |
||||
} |
||||
bindPort, err := strconv.ParseInt(addr[1], 10, 64) |
||||
if err != nil { |
||||
fmt.Println("--addr format error, example 0.0.0.0:7000") |
||||
os.Exit(1) |
||||
} |
||||
server.BindAddr = addr[0] |
||||
server.BindPort = bindPort |
||||
} |
||||
|
||||
if args["-v"] != nil { |
||||
if args["-v"].(bool) { |
||||
fmt.Println(version.Full()) |
||||
os.Exit(0) |
||||
} |
||||
} |
||||
|
||||
log.InitLog(server.LogWay, server.LogFile, server.LogLevel, server.LogMaxDays) |
||||
|
||||
// init assets
|
||||
err = assets.Load(server.AssetsDir) |
||||
if err != nil { |
||||
log.Error("Load assets error: %v", err) |
||||
os.Exit(1) |
||||
} |
||||
|
||||
l, err := conn.Listen(server.BindAddr, server.BindPort) |
||||
if err != nil { |
||||
log.Error("Create server listener error, %v", err) |
||||
os.Exit(1) |
||||
} |
||||
|
||||
// create vhost if VhostHttpPort != 0
|
||||
if server.VhostHttpPort != 0 { |
||||
vhostListener, err := conn.Listen(server.BindAddr, server.VhostHttpPort) |
||||
if err != nil { |
||||
log.Error("Create vhost http listener error, %v", err) |
||||
os.Exit(1) |
||||
} |
||||
server.VhostHttpMuxer, err = vhost.NewHttpMuxer(vhostListener, 30*time.Second) |
||||
if err != nil { |
||||
log.Error("Create vhost httpMuxer error, %v", err) |
||||
} |
||||
} |
||||
|
||||
// create vhost if VhostHttpPort != 0
|
||||
if server.VhostHttpsPort != 0 { |
||||
vhostListener, err := conn.Listen(server.BindAddr, server.VhostHttpsPort) |
||||
if err != nil { |
||||
log.Error("Create vhost https listener error, %v", err) |
||||
os.Exit(1) |
||||
} |
||||
server.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(vhostListener, 30*time.Second) |
||||
if err != nil { |
||||
log.Error("Create vhost httpsMuxer error, %v", err) |
||||
} |
||||
} |
||||
|
||||
// create dashboard web server if DashboardPort is set, so it won't be 0
|
||||
if server.DashboardPort != 0 { |
||||
err := server.RunDashboardServer(server.BindAddr, server.DashboardPort) |
||||
if err != nil { |
||||
log.Error("Create dashboard web server error, %v", err) |
||||
os.Exit(1) |
||||
} |
||||
} |
||||
|
||||
log.Info("Start frps success") |
||||
if server.PrivilegeMode == true { |
||||
log.Info("PrivilegeMode is enabled, you should pay more attention to security issues") |
||||
} |
||||
ProcessControlConn(l) |
||||
} |
@ -1,186 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/src/models/config" |
||||
"github.com/fatedier/frp/src/models/consts" |
||||
"github.com/fatedier/frp/src/models/msg" |
||||
"github.com/fatedier/frp/src/utils/conn" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/pcrypto" |
||||
) |
||||
|
||||
type ProxyClient struct { |
||||
config.BaseConf |
||||
LocalIp string |
||||
LocalPort int64 |
||||
|
||||
RemotePort int64 |
||||
CustomDomains []string |
||||
Locations []string |
||||
|
||||
udpTunnel *conn.Conn |
||||
once sync.Once |
||||
closeFlag bool |
||||
|
||||
mutex sync.RWMutex |
||||
} |
||||
|
||||
// if proxy type is udp, keep a tcp connection for transferring udp packages
|
||||
func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) { |
||||
pc.once.Do(func() { |
||||
var err error |
||||
var c *conn.Conn |
||||
udpProcessor := NewUdpProcesser(nil, pc.LocalIp, pc.LocalPort) |
||||
for { |
||||
if !pc.IsClosed() && (pc.udpTunnel == nil || pc.udpTunnel.IsClosed()) { |
||||
if HttpProxy == "" { |
||||
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", addr, port)) |
||||
} else { |
||||
c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port)) |
||||
} |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], udp tunnel connect to server [%s:%d] error, %v", pc.Name, addr, port, err) |
||||
time.Sleep(10 * time.Second) |
||||
continue |
||||
} |
||||
log.Info("ProxyName [%s], udp tunnel connect to server [%s:%d] success", pc.Name, addr, port) |
||||
|
||||
nowTime := time.Now().Unix() |
||||
req := &msg.ControlReq{ |
||||
Type: consts.NewWorkConnUdp, |
||||
ProxyName: pc.Name, |
||||
PrivilegeMode: pc.PrivilegeMode, |
||||
Timestamp: nowTime, |
||||
} |
||||
if pc.PrivilegeMode == true { |
||||
req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime)) |
||||
} else { |
||||
req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime)) |
||||
} |
||||
|
||||
buf, _ := json.Marshal(req) |
||||
err = c.WriteString(string(buf) + "\n") |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], udp tunnel write to server error, %v", pc.Name, err) |
||||
c.Close() |
||||
time.Sleep(1 * time.Second) |
||||
continue |
||||
} |
||||
pc.mutex.Lock() |
||||
pc.udpTunnel = c |
||||
udpProcessor.UpdateTcpConn(pc.udpTunnel) |
||||
pc.mutex.Unlock() |
||||
|
||||
udpProcessor.Run() |
||||
} |
||||
time.Sleep(1 * time.Second) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
func (pc *ProxyClient) CloseUdpTunnel() { |
||||
pc.mutex.RLock() |
||||
defer pc.mutex.RUnlock() |
||||
if pc.udpTunnel != nil { |
||||
pc.udpTunnel.Close() |
||||
} |
||||
} |
||||
|
||||
func (pc *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { |
||||
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", pc.LocalIp, pc.LocalPort)) |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], connect to local port error, %v", pc.Name, err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (pc *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) { |
||||
defer func() { |
||||
if err != nil && c != nil { |
||||
c.Close() |
||||
} |
||||
}() |
||||
|
||||
if HttpProxy == "" { |
||||
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", addr, port)) |
||||
} else { |
||||
c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port)) |
||||
} |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", pc.Name, addr, port, err) |
||||
return |
||||
} |
||||
|
||||
nowTime := time.Now().Unix() |
||||
req := &msg.ControlReq{ |
||||
Type: consts.NewWorkConn, |
||||
ProxyName: pc.Name, |
||||
PrivilegeMode: pc.PrivilegeMode, |
||||
Timestamp: nowTime, |
||||
} |
||||
if pc.PrivilegeMode == true { |
||||
req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime)) |
||||
} else { |
||||
req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime)) |
||||
} |
||||
|
||||
buf, _ := json.Marshal(req) |
||||
err = c.WriteString(string(buf) + "\n") |
||||
if err != nil { |
||||
log.Error("ProxyName [%s], write to server error, %v", pc.Name, err) |
||||
return |
||||
} |
||||
|
||||
err = nil |
||||
return |
||||
} |
||||
|
||||
func (pc *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) { |
||||
localConn, err := pc.GetLocalConn() |
||||
if err != nil { |
||||
return |
||||
} |
||||
remoteConn, err := pc.GetRemoteConn(serverAddr, serverPort) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
// l means local, r means remote
|
||||
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(), |
||||
remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) |
||||
needRecord := false |
||||
go msg.JoinMore(localConn, remoteConn, pc.BaseConf, needRecord) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (pc *ProxyClient) SetCloseFlag(closeFlag bool) { |
||||
pc.mutex.Lock() |
||||
defer pc.mutex.Unlock() |
||||
pc.closeFlag = closeFlag |
||||
} |
||||
|
||||
func (pc *ProxyClient) IsClosed() bool { |
||||
pc.mutex.RLock() |
||||
defer pc.mutex.RUnlock() |
||||
return pc.closeFlag |
||||
} |
@ -1,302 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
ini "github.com/vaughan0/go-ini" |
||||
) |
||||
|
||||
// common config
|
||||
var ( |
||||
ServerAddr string = "0.0.0.0" |
||||
ServerPort int64 = 7000 |
||||
HttpProxy string = "" |
||||
LogFile string = "console" |
||||
LogWay string = "console" |
||||
LogLevel string = "info" |
||||
LogMaxDays int64 = 3 |
||||
PrivilegeToken string = "" |
||||
HeartBeatInterval int64 = 10 |
||||
HeartBeatTimeout int64 = 30 |
||||
) |
||||
|
||||
var ProxyClients map[string]*ProxyClient = make(map[string]*ProxyClient) |
||||
|
||||
func LoadConf(confFile string) (err error) { |
||||
var tmpStr string |
||||
var ok bool |
||||
|
||||
conf, err := ini.LoadFile(confFile) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// common
|
||||
tmpStr, ok = conf.Get("common", "server_addr") |
||||
if ok { |
||||
ServerAddr = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "server_port") |
||||
if ok { |
||||
ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "http_proxy") |
||||
if ok { |
||||
HttpProxy = tmpStr |
||||
} else { |
||||
// get http_proxy from env
|
||||
HttpProxy = os.Getenv("http_proxy") |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_file") |
||||
if ok { |
||||
LogFile = tmpStr |
||||
if LogFile == "console" { |
||||
LogWay = "console" |
||||
} else { |
||||
LogWay = "file" |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_level") |
||||
if ok { |
||||
LogLevel = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_max_days") |
||||
if ok { |
||||
LogMaxDays, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "privilege_token") |
||||
if ok { |
||||
PrivilegeToken = tmpStr |
||||
} |
||||
|
||||
var authToken string |
||||
tmpStr, ok = conf.Get("common", "auth_token") |
||||
if ok { |
||||
authToken = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_timeout") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") |
||||
} else { |
||||
HeartBeatTimeout = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_interval") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") |
||||
} else { |
||||
HeartBeatInterval = v |
||||
} |
||||
} |
||||
|
||||
if HeartBeatInterval <= 0 { |
||||
return fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") |
||||
} |
||||
|
||||
if HeartBeatTimeout < HeartBeatInterval { |
||||
return fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval") |
||||
} |
||||
|
||||
// proxies
|
||||
for name, section := range conf { |
||||
if name != "common" { |
||||
proxyClient := &ProxyClient{} |
||||
// name
|
||||
proxyClient.Name = name |
||||
|
||||
// auth_token
|
||||
proxyClient.AuthToken = authToken |
||||
|
||||
// local_ip
|
||||
proxyClient.LocalIp, ok = section["local_ip"] |
||||
if !ok { |
||||
// use 127.0.0.1 as default
|
||||
proxyClient.LocalIp = "127.0.0.1" |
||||
} |
||||
|
||||
// local_port
|
||||
tmpStr, ok = section["local_port"] |
||||
if ok { |
||||
proxyClient.LocalPort, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] local_port error", proxyClient.Name) |
||||
} |
||||
} else { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] local_port not found", proxyClient.Name) |
||||
} |
||||
|
||||
// type
|
||||
proxyClient.Type = "tcp" |
||||
tmpStr, ok = section["type"] |
||||
if ok { |
||||
if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" && tmpStr != "udp" { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] type error", proxyClient.Name) |
||||
} |
||||
proxyClient.Type = tmpStr |
||||
} |
||||
|
||||
// use_encryption
|
||||
proxyClient.UseEncryption = false |
||||
tmpStr, ok = section["use_encryption"] |
||||
if ok && tmpStr == "true" { |
||||
proxyClient.UseEncryption = true |
||||
} |
||||
|
||||
// use_gzip
|
||||
proxyClient.UseGzip = false |
||||
tmpStr, ok = section["use_gzip"] |
||||
if ok && tmpStr == "true" { |
||||
proxyClient.UseGzip = true |
||||
} |
||||
|
||||
if proxyClient.Type == "http" { |
||||
// host_header_rewrite
|
||||
tmpStr, ok = section["host_header_rewrite"] |
||||
if ok { |
||||
proxyClient.HostHeaderRewrite = tmpStr |
||||
} |
||||
// http_user
|
||||
tmpStr, ok = section["http_user"] |
||||
if ok { |
||||
proxyClient.HttpUserName = tmpStr |
||||
} |
||||
// http_pwd
|
||||
tmpStr, ok = section["http_pwd"] |
||||
if ok { |
||||
proxyClient.HttpPassWord = tmpStr |
||||
} |
||||
|
||||
} |
||||
if proxyClient.Type == "http" || proxyClient.Type == "https" { |
||||
// subdomain
|
||||
tmpStr, ok = section["subdomain"] |
||||
if ok { |
||||
proxyClient.SubDomain = tmpStr |
||||
} |
||||
} |
||||
|
||||
// privilege_mode
|
||||
proxyClient.PrivilegeMode = false |
||||
tmpStr, ok = section["privilege_mode"] |
||||
if ok && tmpStr == "true" { |
||||
proxyClient.PrivilegeMode = true |
||||
} |
||||
|
||||
// pool_count
|
||||
proxyClient.PoolCount = 0 |
||||
tmpStr, ok = section["pool_count"] |
||||
if ok { |
||||
tmpInt, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil || tmpInt < 0 { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] pool_count error", proxyClient.Name) |
||||
} |
||||
proxyClient.PoolCount = tmpInt |
||||
} |
||||
|
||||
// configures used in privilege mode
|
||||
if proxyClient.PrivilegeMode == true { |
||||
if PrivilegeToken == "" { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] privilege_token must be set when privilege_mode = true", proxyClient.Name) |
||||
} else { |
||||
proxyClient.PrivilegeToken = PrivilegeToken |
||||
} |
||||
|
||||
if proxyClient.Type == "tcp" || proxyClient.Type == "udp" { |
||||
// remote_port
|
||||
tmpStr, ok = section["remote_port"] |
||||
if ok { |
||||
proxyClient.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", proxyClient.Name) |
||||
} |
||||
} else { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name) |
||||
} |
||||
} else if proxyClient.Type == "http" { |
||||
// custom_domains
|
||||
tmpStr, ok = section["custom_domains"] |
||||
if ok { |
||||
proxyClient.CustomDomains = strings.Split(tmpStr, ",") |
||||
for i, domain := range proxyClient.CustomDomains { |
||||
proxyClient.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) |
||||
} |
||||
} |
||||
|
||||
// subdomain
|
||||
tmpStr, ok = section["subdomain"] |
||||
if ok { |
||||
proxyClient.SubDomain = tmpStr |
||||
} |
||||
|
||||
if len(proxyClient.CustomDomains) == 0 && proxyClient.SubDomain == "" { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains and subdomain should set at least one of them when type is http", proxyClient.Name) |
||||
} |
||||
|
||||
// locations
|
||||
tmpStr, ok = section["locations"] |
||||
if ok { |
||||
proxyClient.Locations = strings.Split(tmpStr, ",") |
||||
} else { |
||||
proxyClient.Locations = []string{""} |
||||
} |
||||
} else if proxyClient.Type == "https" { |
||||
// custom_domains
|
||||
tmpStr, ok = section["custom_domains"] |
||||
if ok { |
||||
proxyClient.CustomDomains = strings.Split(tmpStr, ",") |
||||
for i, domain := range proxyClient.CustomDomains { |
||||
proxyClient.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) |
||||
} |
||||
} |
||||
|
||||
// subdomain
|
||||
tmpStr, ok = section["subdomain"] |
||||
if ok { |
||||
proxyClient.SubDomain = tmpStr |
||||
} |
||||
|
||||
if len(proxyClient.CustomDomains) == 0 && proxyClient.SubDomain == "" { |
||||
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains and subdomain should set at least one of them when type is https", proxyClient.Name) |
||||
} |
||||
} |
||||
} |
||||
|
||||
ProxyClients[proxyClient.Name] = proxyClient |
||||
} |
||||
} |
||||
|
||||
if len(ProxyClients) == 0 { |
||||
return fmt.Errorf("Parse conf error: no proxy config found") |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -1,49 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
type GeneralRes struct { |
||||
Code int64 `json:"code"` |
||||
Msg string `json:"msg"` |
||||
} |
||||
|
||||
// messages between control connections of frpc and frps
|
||||
type ControlReq struct { |
||||
Type int64 `json:"type"` |
||||
ProxyName string `json:"proxy_name"` |
||||
AuthKey string `json:"auth_key"` |
||||
UseEncryption bool `json:"use_encryption"` |
||||
UseGzip bool `json:"use_gzip"` |
||||
PoolCount int64 `json:"pool_count"` |
||||
|
||||
// configures used if privilege_mode is enabled
|
||||
PrivilegeMode bool `json:"privilege_mode"` |
||||
PrivilegeKey string `json:"privilege_key"` |
||||
ProxyType string `json:"proxy_type"` |
||||
RemotePort int64 `json:"remote_port"` |
||||
CustomDomains []string `json:"custom_domains, omitempty"` |
||||
Locations []string `json:"locations"` |
||||
HostHeaderRewrite string `json:"host_header_rewrite"` |
||||
HttpUserName string `json:"http_username"` |
||||
HttpPassWord string `json:"http_password"` |
||||
SubDomain string `json:"subdomain"` |
||||
Timestamp int64 `json:"timestamp"` |
||||
} |
||||
|
||||
type ControlRes struct { |
||||
Type int64 `json:"type"` |
||||
Code int64 `json:"code"` |
||||
Msg string `json:"msg"` |
||||
} |
@ -1,257 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"sync" |
||||
|
||||
"github.com/fatedier/frp/src/models/config" |
||||
"github.com/fatedier/frp/src/models/metric" |
||||
"github.com/fatedier/frp/src/utils/conn" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/pcrypto" |
||||
"github.com/fatedier/frp/src/utils/pool" |
||||
) |
||||
|
||||
// deprecated
|
||||
// will block until connection close
|
||||
func Join(c1 *conn.Conn, c2 *conn.Conn) { |
||||
var wait sync.WaitGroup |
||||
pipe := func(to *conn.Conn, from *conn.Conn) { |
||||
defer to.Close() |
||||
defer from.Close() |
||||
defer wait.Done() |
||||
|
||||
var err error |
||||
_, err = io.Copy(to.TcpConn, from.TcpConn) |
||||
if err != nil { |
||||
log.Warn("join connections error, %v", err) |
||||
} |
||||
} |
||||
|
||||
wait.Add(2) |
||||
go pipe(c1, c2) |
||||
go pipe(c2, c1) |
||||
wait.Wait() |
||||
return |
||||
} |
||||
|
||||
// join two connections and do some operations
|
||||
func JoinMore(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, conf config.BaseConf, needRecord bool) { |
||||
var wait sync.WaitGroup |
||||
encryptPipe := func(from io.ReadCloser, to io.WriteCloser) { |
||||
defer from.Close() |
||||
defer to.Close() |
||||
defer wait.Done() |
||||
|
||||
// we don't care about errors here
|
||||
pipeEncrypt(from, to, conf, needRecord) |
||||
} |
||||
|
||||
decryptPipe := func(to io.ReadCloser, from io.WriteCloser) { |
||||
defer from.Close() |
||||
defer to.Close() |
||||
defer wait.Done() |
||||
|
||||
// we don't care about errors here
|
||||
pipeDecrypt(to, from, conf, needRecord) |
||||
} |
||||
|
||||
if needRecord { |
||||
metric.OpenConnection(conf.Name) |
||||
} |
||||
wait.Add(2) |
||||
go encryptPipe(c1, c2) |
||||
go decryptPipe(c2, c1) |
||||
wait.Wait() |
||||
if needRecord { |
||||
metric.CloseConnection(conf.Name) |
||||
} |
||||
log.Debug("ProxyName [%s], One tunnel stopped", conf.Name) |
||||
return |
||||
} |
||||
|
||||
func pkgMsg(data []byte) []byte { |
||||
llen := uint32(len(data)) |
||||
buf := new(bytes.Buffer) |
||||
binary.Write(buf, binary.BigEndian, llen) |
||||
buf.Write(data) |
||||
return buf.Bytes() |
||||
} |
||||
|
||||
func unpkgMsg(data []byte) (int, []byte, []byte) { |
||||
if len(data) < 4 { |
||||
return -1, nil, data |
||||
} |
||||
llen := int(binary.BigEndian.Uint32(data[0:4])) |
||||
// no complete
|
||||
if len(data) < llen+4 { |
||||
return -1, nil, data |
||||
} |
||||
|
||||
return 0, data[4 : llen+4], data[llen+4:] |
||||
} |
||||
|
||||
// decrypt msg from reader, then write into writer
|
||||
func pipeDecrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) { |
||||
laes := new(pcrypto.Pcrypto) |
||||
key := conf.AuthToken |
||||
if conf.PrivilegeMode { |
||||
key = conf.PrivilegeToken |
||||
} |
||||
if err := laes.Init([]byte(key)); err != nil { |
||||
log.Warn("ProxyName [%s], Pcrypto Init error: %v", conf.Name, err) |
||||
return fmt.Errorf("Pcrypto Init error: %v", err) |
||||
} |
||||
|
||||
// get []byte from buffer pool
|
||||
buf := pool.GetBuf(5*1024 + 4) |
||||
defer pool.PutBuf(buf) |
||||
|
||||
var left, res []byte |
||||
var cnt int = -1 |
||||
|
||||
// record
|
||||
var flowBytes int64 = 0 |
||||
if needRecord { |
||||
defer func() { |
||||
metric.AddFlowOut(conf.Name, flowBytes) |
||||
}() |
||||
} |
||||
|
||||
for { |
||||
// there may be more than 1 package in variable
|
||||
// and we read more bytes if unpkgMsg returns an error
|
||||
var newBuf []byte |
||||
if cnt < 0 { |
||||
n, err := r.Read(buf) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
newBuf = append(left, buf[0:n]...) |
||||
} else { |
||||
newBuf = left |
||||
} |
||||
cnt, res, left = unpkgMsg(newBuf) |
||||
if cnt < 0 { |
||||
// limit one package length, maximum is 1MB
|
||||
if len(res) > 1024*1024 { |
||||
log.Warn("ProxyName [%s], package length exceeds the limit") |
||||
return fmt.Errorf("package length error") |
||||
} |
||||
continue |
||||
} |
||||
|
||||
// aes
|
||||
if conf.UseEncryption { |
||||
res, err = laes.Decrypt(res) |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], decrypt error, %v", conf.Name, err) |
||||
return fmt.Errorf("Decrypt error: %v", err) |
||||
} |
||||
} |
||||
// gzip
|
||||
if conf.UseGzip { |
||||
res, err = laes.Decompression(res) |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], decompression error, %v", conf.Name, err) |
||||
return fmt.Errorf("Decompression error: %v", err) |
||||
} |
||||
} |
||||
|
||||
_, err = w.Write(res) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if needRecord { |
||||
flowBytes += int64(len(res)) |
||||
if flowBytes >= 1024*1024 { |
||||
metric.AddFlowOut(conf.Name, flowBytes) |
||||
flowBytes = 0 |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// recvive msg from reader, then encrypt msg into writer
|
||||
func pipeEncrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) { |
||||
laes := new(pcrypto.Pcrypto) |
||||
key := conf.AuthToken |
||||
if conf.PrivilegeMode { |
||||
key = conf.PrivilegeToken |
||||
} |
||||
if err := laes.Init([]byte(key)); err != nil { |
||||
log.Warn("ProxyName [%s], Pcrypto Init error: %v", conf.Name, err) |
||||
return fmt.Errorf("Pcrypto Init error: %v", err) |
||||
} |
||||
|
||||
// record
|
||||
var flowBytes int64 = 0 |
||||
if needRecord { |
||||
defer func() { |
||||
metric.AddFlowIn(conf.Name, flowBytes) |
||||
}() |
||||
} |
||||
|
||||
// get []byte from buffer pool
|
||||
buf := pool.GetBuf(5*1024 + 4) |
||||
defer pool.PutBuf(buf) |
||||
|
||||
for { |
||||
n, err := r.Read(buf) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if needRecord { |
||||
flowBytes += int64(n) |
||||
if flowBytes >= 1024*1024 { |
||||
metric.AddFlowIn(conf.Name, flowBytes) |
||||
flowBytes = 0 |
||||
} |
||||
} |
||||
|
||||
res := buf[0:n] |
||||
// gzip
|
||||
if conf.UseGzip { |
||||
res, err = laes.Compression(res) |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], compression error: %v", conf.Name, err) |
||||
return fmt.Errorf("Compression error: %v", err) |
||||
} |
||||
} |
||||
// aes
|
||||
if conf.UseEncryption { |
||||
res, err = laes.Encrypt(res) |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], encrypt error: %v", conf.Name, err) |
||||
return fmt.Errorf("Encrypt error: %v", err) |
||||
} |
||||
} |
||||
|
||||
res = pkgMsg(res) |
||||
_, err = w.Write(res) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -1,456 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
|
||||
ini "github.com/vaughan0/go-ini" |
||||
|
||||
"github.com/fatedier/frp/src/models/consts" |
||||
"github.com/fatedier/frp/src/models/metric" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/vhost" |
||||
) |
||||
|
||||
// common config
|
||||
var ( |
||||
ConfigFile string = "./frps.ini" |
||||
BindAddr string = "0.0.0.0" |
||||
BindPort int64 = 7000 |
||||
VhostHttpPort int64 = 0 // if VhostHttpPort equals 0, don't listen a public port for http protocol
|
||||
VhostHttpsPort int64 = 0 // if VhostHttpsPort equals 0, don't listen a public port for https protocol
|
||||
DashboardPort int64 = 0 // if DashboardPort equals 0, dashboard is not available
|
||||
DashboardUsername string = "admin" |
||||
DashboardPassword string = "admin" |
||||
AssetsDir string = "" |
||||
LogFile string = "console" |
||||
LogWay string = "console" // console or file
|
||||
LogLevel string = "info" |
||||
LogMaxDays int64 = 3 |
||||
PrivilegeMode bool = false |
||||
PrivilegeToken string = "" |
||||
AuthTimeout int64 = 900 |
||||
SubDomainHost string = "" |
||||
|
||||
// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected
|
||||
PrivilegeAllowPorts map[int64]struct{} |
||||
MaxPoolCount int64 = 100 |
||||
HeartBeatTimeout int64 = 30 |
||||
UserConnTimeout int64 = 10 |
||||
|
||||
VhostHttpMuxer *vhost.HttpMuxer |
||||
VhostHttpsMuxer *vhost.HttpsMuxer |
||||
ProxyServers map[string]*ProxyServer = make(map[string]*ProxyServer) // all proxy servers info and resources
|
||||
ProxyServersMutex sync.RWMutex |
||||
) |
||||
|
||||
func LoadConf(confFile string) (err error) { |
||||
err = loadCommonConf(confFile) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// load all proxy server's configure and initialize
|
||||
// and set ProxyServers map
|
||||
newProxyServers, err := loadProxyConf(confFile) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
for _, proxyServer := range newProxyServers { |
||||
proxyServer.Init() |
||||
} |
||||
ProxyServersMutex.Lock() |
||||
ProxyServers = newProxyServers |
||||
ProxyServersMutex.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
func loadCommonConf(confFile string) error { |
||||
var tmpStr string |
||||
var ok bool |
||||
conf, err := ini.LoadFile(confFile) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
// common
|
||||
tmpStr, ok = conf.Get("common", "bind_addr") |
||||
if ok { |
||||
BindAddr = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "bind_port") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err == nil { |
||||
BindPort = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "vhost_http_port") |
||||
if ok { |
||||
VhostHttpPort, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} else { |
||||
VhostHttpPort = 0 |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "vhost_https_port") |
||||
if ok { |
||||
VhostHttpsPort, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} else { |
||||
VhostHttpsPort = 0 |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_port") |
||||
if ok { |
||||
DashboardPort, _ = strconv.ParseInt(tmpStr, 10, 64) |
||||
} else { |
||||
DashboardPort = 0 |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_user") |
||||
if ok { |
||||
DashboardUsername = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "dashboard_pwd") |
||||
if ok { |
||||
DashboardPassword = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "assets_dir") |
||||
if ok { |
||||
AssetsDir = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_file") |
||||
if ok { |
||||
LogFile = tmpStr |
||||
if LogFile == "console" { |
||||
LogWay = "console" |
||||
} else { |
||||
LogWay = "file" |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_level") |
||||
if ok { |
||||
LogLevel = tmpStr |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "log_max_days") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err == nil { |
||||
LogMaxDays = v |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "privilege_mode") |
||||
if ok { |
||||
if tmpStr == "true" { |
||||
PrivilegeMode = true |
||||
} |
||||
} |
||||
|
||||
if PrivilegeMode == true { |
||||
tmpStr, ok = conf.Get("common", "privilege_token") |
||||
if ok { |
||||
if tmpStr == "" { |
||||
return fmt.Errorf("Parse conf error: privilege_token can not be null") |
||||
} |
||||
PrivilegeToken = tmpStr |
||||
} else { |
||||
return fmt.Errorf("Parse conf error: privilege_token must be set if privilege_mode is enabled") |
||||
} |
||||
|
||||
PrivilegeAllowPorts = make(map[int64]struct{}) |
||||
tmpStr, ok = conf.Get("common", "privilege_allow_ports") |
||||
if ok { |
||||
// for example: 1000-2000,2001,2002,3000-4000
|
||||
portRanges := strings.Split(tmpStr, ",") |
||||
for _, portRangeStr := range portRanges { |
||||
// 1000-2000 or 2001
|
||||
portArray := strings.Split(portRangeStr, "-") |
||||
// lenght: only 1 or 2 is correct
|
||||
rangeType := len(portArray) |
||||
if rangeType == 1 { |
||||
singlePort, err := strconv.ParseInt(portArray[0], 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) |
||||
} |
||||
PrivilegeAllowPorts[singlePort] = struct{}{} |
||||
} else if rangeType == 2 { |
||||
min, err := strconv.ParseInt(portArray[0], 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) |
||||
} |
||||
max, err := strconv.ParseInt(portArray[1], 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) |
||||
} |
||||
if max < min { |
||||
return fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") |
||||
} |
||||
for i := min; i <= max; i++ { |
||||
PrivilegeAllowPorts[i] = struct{}{} |
||||
} |
||||
} else { |
||||
return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "max_pool_count") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err == nil && v >= 0 { |
||||
MaxPoolCount = v |
||||
} |
||||
} |
||||
tmpStr, ok = conf.Get("common", "authentication_timeout") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: authentication_timeout is incorrect") |
||||
} else { |
||||
AuthTimeout = v |
||||
} |
||||
} |
||||
SubDomainHost, ok = conf.Get("common", "subdomain_host") |
||||
if ok { |
||||
SubDomainHost = strings.ToLower(strings.TrimSpace(SubDomainHost)) |
||||
} |
||||
|
||||
tmpStr, ok = conf.Get("common", "heartbeat_timeout") |
||||
if ok { |
||||
v, err := strconv.ParseInt(tmpStr, 10, 64) |
||||
if err != nil { |
||||
return fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") |
||||
} else { |
||||
HeartBeatTimeout = v |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err error) { |
||||
var ok bool |
||||
proxyServers = make(map[string]*ProxyServer) |
||||
conf, err := ini.LoadFile(confFile) |
||||
if err != nil { |
||||
return proxyServers, err |
||||
} |
||||
// servers
|
||||
for name, section := range conf { |
||||
if name != "common" { |
||||
proxyServer := NewProxyServer() |
||||
proxyServer.Name = name |
||||
|
||||
proxyServer.Type, ok = section["type"] |
||||
if ok { |
||||
if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" && proxyServer.Type != "udp" { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name) |
||||
} |
||||
} else { |
||||
proxyServer.Type = "tcp" |
||||
} |
||||
|
||||
proxyServer.AuthToken, ok = section["auth_token"] |
||||
if !ok { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] no auth_token found", proxyServer.Name) |
||||
} |
||||
|
||||
// for tcp and udp
|
||||
if proxyServer.Type == "tcp" || proxyServer.Type == "udp" { |
||||
proxyServer.BindAddr, ok = section["bind_addr"] |
||||
if !ok { |
||||
proxyServer.BindAddr = "0.0.0.0" |
||||
} |
||||
|
||||
portStr, ok := section["listen_port"] |
||||
if ok { |
||||
proxyServer.ListenPort, err = strconv.ParseInt(portStr, 10, 64) |
||||
if err != nil { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] listen_port error", proxyServer.Name) |
||||
} |
||||
} else { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] listen_port not found", proxyServer.Name) |
||||
} |
||||
} else if proxyServer.Type == "http" { |
||||
// for http
|
||||
proxyServer.ListenPort = VhostHttpPort |
||||
|
||||
domainStr, ok := section["custom_domains"] |
||||
if ok { |
||||
proxyServer.CustomDomains = strings.Split(domainStr, ",") |
||||
for i, domain := range proxyServer.CustomDomains { |
||||
domain = strings.ToLower(strings.TrimSpace(domain)) |
||||
// custom domain should not belong to subdomain_host
|
||||
if SubDomainHost != "" && strings.Contains(domain, SubDomainHost) { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom domain should not belong to subdomain_host", proxyServer.Name) |
||||
} |
||||
proxyServer.CustomDomains[i] = domain |
||||
} |
||||
} |
||||
|
||||
// subdomain
|
||||
subdomainStr, ok := section["subdomain"] |
||||
if ok { |
||||
if strings.Contains(subdomainStr, ".") || strings.Contains(subdomainStr, "*") { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] '.' and '*' is not supported in subdomain", proxyServer.Name) |
||||
} |
||||
|
||||
if SubDomainHost == "" { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] subdomain is not supported because subdomain_host is empty", proxyServer.Name) |
||||
} |
||||
proxyServer.SubDomain = subdomainStr + "." + SubDomainHost |
||||
} |
||||
|
||||
if len(proxyServer.CustomDomains) == 0 && proxyServer.SubDomain == "" { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains and subdomain should set at least one of them when type is http", proxyServer.Name) |
||||
} |
||||
|
||||
// locations
|
||||
locations, ok := section["locations"] |
||||
if ok { |
||||
proxyServer.Locations = strings.Split(locations, ",") |
||||
} else { |
||||
proxyServer.Locations = []string{""} |
||||
} |
||||
} else if proxyServer.Type == "https" { |
||||
// for https
|
||||
proxyServer.ListenPort = VhostHttpsPort |
||||
|
||||
domainStr, ok := section["custom_domains"] |
||||
if ok { |
||||
proxyServer.CustomDomains = strings.Split(domainStr, ",") |
||||
for i, domain := range proxyServer.CustomDomains { |
||||
domain = strings.ToLower(strings.TrimSpace(domain)) |
||||
if SubDomainHost != "" && strings.Contains(domain, SubDomainHost) { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom domain should not belong to subdomain_host", proxyServer.Name) |
||||
} |
||||
proxyServer.CustomDomains[i] = domain |
||||
} |
||||
} |
||||
|
||||
// subdomain
|
||||
subdomainStr, ok := section["subdomain"] |
||||
if ok { |
||||
if strings.Contains(subdomainStr, ".") || strings.Contains(subdomainStr, "*") { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] '.' and '*' is not supported in subdomain", proxyServer.Name) |
||||
} |
||||
|
||||
if SubDomainHost == "" { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] subdomain is not supported because subdomain_host is empty", proxyServer.Name) |
||||
} |
||||
proxyServer.SubDomain = subdomainStr + "." + SubDomainHost |
||||
} |
||||
|
||||
if len(proxyServer.CustomDomains) == 0 && proxyServer.SubDomain == "" { |
||||
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains and subdomain should set at least one of them when type is https", proxyServer.Name) |
||||
} |
||||
} |
||||
proxyServers[proxyServer.Name] = proxyServer |
||||
} |
||||
} |
||||
|
||||
// set metric statistics of all proxies
|
||||
for name, p := range proxyServers { |
||||
metric.SetProxyInfo(name, p.Type, p.BindAddr, p.UseEncryption, p.UseGzip, |
||||
p.PrivilegeMode, p.CustomDomains, p.Locations, p.ListenPort) |
||||
} |
||||
return proxyServers, nil |
||||
} |
||||
|
||||
// the function can only reload proxy configures
|
||||
// common section won't be changed
|
||||
func ReloadConf(confFile string) (err error) { |
||||
loadProxyServers, err := loadProxyConf(confFile) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
ProxyServersMutex.Lock() |
||||
for name, proxyServer := range loadProxyServers { |
||||
oldProxyServer, ok := ProxyServers[name] |
||||
if ok { |
||||
if !oldProxyServer.Compare(proxyServer) { |
||||
oldProxyServer.Close() |
||||
proxyServer.Init() |
||||
ProxyServers[name] = proxyServer |
||||
log.Info("ProxyName [%s] configure change, restart", name) |
||||
} |
||||
} else { |
||||
proxyServer.Init() |
||||
ProxyServers[name] = proxyServer |
||||
log.Info("ProxyName [%s] is new, init it", name) |
||||
} |
||||
} |
||||
|
||||
// proxies created by PrivilegeMode won't be deleted
|
||||
for name, oldProxyServer := range ProxyServers { |
||||
_, ok := loadProxyServers[name] |
||||
if !ok { |
||||
if !oldProxyServer.PrivilegeMode { |
||||
oldProxyServer.Close() |
||||
delete(ProxyServers, name) |
||||
log.Info("ProxyName [%s] deleted, close it", name) |
||||
} else { |
||||
log.Info("ProxyName [%s] created by PrivilegeMode, won't be closed", name) |
||||
} |
||||
} |
||||
} |
||||
ProxyServersMutex.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
func CreateProxy(s *ProxyServer) error { |
||||
ProxyServersMutex.Lock() |
||||
defer ProxyServersMutex.Unlock() |
||||
oldServer, ok := ProxyServers[s.Name] |
||||
if ok { |
||||
if oldServer.Status == consts.Working { |
||||
return fmt.Errorf("this proxy is already working now") |
||||
} |
||||
oldServer.Lock() |
||||
oldServer.Release() |
||||
oldServer.Unlock() |
||||
if oldServer.PrivilegeMode { |
||||
delete(ProxyServers, s.Name) |
||||
} |
||||
} |
||||
ProxyServers[s.Name] = s |
||||
metric.SetProxyInfo(s.Name, s.Type, s.BindAddr, s.UseEncryption, s.UseGzip, |
||||
s.PrivilegeMode, s.CustomDomains, s.Locations, s.ListenPort) |
||||
return nil |
||||
} |
||||
|
||||
func DeleteProxy(proxyName string) { |
||||
ProxyServersMutex.Lock() |
||||
defer ProxyServersMutex.Unlock() |
||||
delete(ProxyServers, proxyName) |
||||
} |
||||
|
||||
func GetProxyServer(proxyName string) (p *ProxyServer, ok bool) { |
||||
ProxyServersMutex.RLock() |
||||
defer ProxyServersMutex.RUnlock() |
||||
p, ok = ProxyServers[proxyName] |
||||
return |
||||
} |
@ -1,484 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/src/models/config" |
||||
"github.com/fatedier/frp/src/models/consts" |
||||
"github.com/fatedier/frp/src/models/metric" |
||||
"github.com/fatedier/frp/src/models/msg" |
||||
"github.com/fatedier/frp/src/utils/conn" |
||||
"github.com/fatedier/frp/src/utils/log" |
||||
"github.com/fatedier/frp/src/utils/pool" |
||||
) |
||||
|
||||
type Listener interface { |
||||
Accept() (*conn.Conn, error) |
||||
Close() error |
||||
} |
||||
|
||||
type ProxyServer struct { |
||||
config.BaseConf |
||||
BindAddr string |
||||
ListenPort int64 |
||||
CustomDomains []string |
||||
Locations []string |
||||
|
||||
Status int64 |
||||
CtlConn *conn.Conn // control connection with frpc
|
||||
WorkConnUdp *conn.Conn // work connection for udp
|
||||
|
||||
udpConn *net.UDPConn |
||||
listeners []Listener // accept new connection from remote users
|
||||
ctlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel
|
||||
workConnChan chan *conn.Conn // get new work conns from control goroutine
|
||||
udpSenderChan chan *msg.UdpPacket |
||||
mutex sync.RWMutex |
||||
closeChan chan struct{} // close this channel for notifying other goroutines that the proxy is closed
|
||||
} |
||||
|
||||
func NewProxyServer() (p *ProxyServer) { |
||||
p = &ProxyServer{ |
||||
CustomDomains: make([]string, 0), |
||||
Locations: make([]string, 0), |
||||
} |
||||
return p |
||||
} |
||||
|
||||
func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) { |
||||
p = &ProxyServer{} |
||||
p.Name = req.ProxyName |
||||
p.Type = req.ProxyType |
||||
p.UseEncryption = req.UseEncryption |
||||
p.UseGzip = req.UseGzip |
||||
p.PrivilegeMode = req.PrivilegeMode |
||||
p.PrivilegeToken = PrivilegeToken |
||||
p.BindAddr = BindAddr |
||||
if p.Type == "tcp" || p.Type == "udp" { |
||||
p.ListenPort = req.RemotePort |
||||
} else if p.Type == "http" { |
||||
p.ListenPort = VhostHttpPort |
||||
} else if p.Type == "https" { |
||||
p.ListenPort = VhostHttpsPort |
||||
} |
||||
p.CustomDomains = req.CustomDomains |
||||
p.SubDomain = req.SubDomain |
||||
p.Locations = req.Locations |
||||
p.HostHeaderRewrite = req.HostHeaderRewrite |
||||
p.HttpUserName = req.HttpUserName |
||||
p.HttpPassWord = req.HttpPassWord |
||||
|
||||
p.Init() |
||||
return |
||||
} |
||||
|
||||
func (p *ProxyServer) Init() { |
||||
p.Lock() |
||||
p.Status = consts.Idle |
||||
metric.SetStatus(p.Name, p.Status) |
||||
p.workConnChan = make(chan *conn.Conn, p.PoolCount+10) |
||||
p.ctlMsgChan = make(chan int64, p.PoolCount+10) |
||||
p.udpSenderChan = make(chan *msg.UdpPacket, 1024) |
||||
p.listeners = make([]Listener, 0) |
||||
p.closeChan = make(chan struct{}) |
||||
p.Unlock() |
||||
} |
||||
|
||||
func (p *ProxyServer) Compare(p2 *ProxyServer) bool { |
||||
if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type || |
||||
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite { |
||||
return false |
||||
} |
||||
if len(p.CustomDomains) != len(p2.CustomDomains) { |
||||
return false |
||||
} |
||||
for i, _ := range p.CustomDomains { |
||||
if p.CustomDomains[i] != p2.CustomDomains[i] { |
||||
return false |
||||
} |
||||
} |
||||
|
||||
if len(p.Locations) != len(p2.Locations) { |
||||
return false |
||||
} |
||||
for i, _ := range p.Locations { |
||||
if p.Locations[i] != p2.Locations[i] { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
|
||||
func (p *ProxyServer) Lock() { |
||||
p.mutex.Lock() |
||||
} |
||||
|
||||
func (p *ProxyServer) Unlock() { |
||||
p.mutex.Unlock() |
||||
} |
||||
|
||||
// start listening for user conns
|
||||
func (p *ProxyServer) Start(c *conn.Conn) (err error) { |
||||
p.CtlConn = c |
||||
p.Init() |
||||
if p.Type == "tcp" { |
||||
l, err := conn.Listen(p.BindAddr, p.ListenPort) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
p.listeners = append(p.listeners, l) |
||||
} else if p.Type == "http" { |
||||
for _, domain := range p.CustomDomains { |
||||
if len(p.Locations) == 0 { |
||||
l, err := VhostHttpMuxer.Listen(domain, "", p.HostHeaderRewrite, p.HttpUserName, p.HttpPassWord) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Info("ProxyName [%s], type http listen for host [%s] location [%s]", p.Name, domain, "") |
||||
p.listeners = append(p.listeners, l) |
||||
} else { |
||||
for _, location := range p.Locations { |
||||
l, err := VhostHttpMuxer.Listen(domain, location, p.HostHeaderRewrite, p.HttpUserName, p.HttpPassWord) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Info("ProxyName [%s], type http listen for host [%s] location [%s]", p.Name, domain, location) |
||||
p.listeners = append(p.listeners, l) |
||||
} |
||||
} |
||||
} |
||||
if p.SubDomain != "" { |
||||
if len(p.Locations) == 0 { |
||||
l, err := VhostHttpMuxer.Listen(p.SubDomain, "", p.HostHeaderRewrite, p.HttpUserName, p.HttpPassWord) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Info("ProxyName [%s], type http listen for host [%s] location [%s]", p.Name, p.SubDomain, "") |
||||
p.listeners = append(p.listeners, l) |
||||
} else { |
||||
for _, location := range p.Locations { |
||||
l, err := VhostHttpMuxer.Listen(p.SubDomain, location, p.HostHeaderRewrite, p.HttpUserName, p.HttpPassWord) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Info("ProxyName [%s], type http listen for host [%s] location [%s]", p.Name, p.SubDomain, location) |
||||
p.listeners = append(p.listeners, l) |
||||
} |
||||
} |
||||
} |
||||
} else if p.Type == "https" { |
||||
for _, domain := range p.CustomDomains { |
||||
l, err := VhostHttpsMuxer.Listen(domain, "", p.HostHeaderRewrite, p.HttpUserName, p.HttpPassWord) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Info("ProxyName [%s], type https listen for host [%s]", p.Name, domain) |
||||
p.listeners = append(p.listeners, l) |
||||
} |
||||
if p.SubDomain != "" { |
||||
l, err := VhostHttpsMuxer.Listen(p.SubDomain, "", p.HostHeaderRewrite, p.HttpUserName, p.HttpPassWord) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Info("ProxyName [%s], type https listen for host [%s]", p.Name, p.SubDomain) |
||||
p.listeners = append(p.listeners, l) |
||||
} |
||||
} |
||||
|
||||
p.Lock() |
||||
p.Status = consts.Working |
||||
p.Unlock() |
||||
metric.SetStatus(p.Name, p.Status) |
||||
|
||||
if p.Type == "udp" { |
||||
// udp is special
|
||||
p.udpConn, err = conn.ListenUDP(p.BindAddr, p.ListenPort) |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], listen udp port error: %v", p.Name, err) |
||||
return err |
||||
} |
||||
go func() { |
||||
for { |
||||
buf := pool.GetBuf(2048) |
||||
n, remoteAddr, err := p.udpConn.ReadFromUDP(buf) |
||||
if err != nil { |
||||
log.Info("ProxyName [%s], udp listener is closed", p.Name) |
||||
return |
||||
} |
||||
localAddr, _ := net.ResolveUDPAddr("udp", p.udpConn.LocalAddr().String()) |
||||
udpPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, localAddr) |
||||
select { |
||||
case p.udpSenderChan <- udpPacket: |
||||
default: |
||||
log.Warn("ProxyName [%s], udp sender channel is full", p.Name) |
||||
} |
||||
pool.PutBuf(buf) |
||||
} |
||||
}() |
||||
} else { |
||||
// create connection pool if needed
|
||||
if p.PoolCount > 0 { |
||||
go p.connectionPoolManager(p.closeChan) |
||||
} |
||||
|
||||
// start a goroutine for every listener to accept user connection
|
||||
for _, listener := range p.listeners { |
||||
go func(l Listener) { |
||||
for { |
||||
// block
|
||||
// if listener is closed, err returned
|
||||
c, err := l.Accept() |
||||
if err != nil { |
||||
log.Info("ProxyName [%s], listener is closed", p.Name) |
||||
return |
||||
} |
||||
log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr()) |
||||
|
||||
if p.Status != consts.Working { |
||||
log.Debug("ProxyName [%s] is not working, new user conn close", p.Name) |
||||
c.Close() |
||||
return |
||||
} |
||||
|
||||
go func(userConn *conn.Conn) { |
||||
workConn, err := p.getWorkConn() |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
// message will be transferred to another without modifying
|
||||
// l means local, r means remote
|
||||
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), |
||||
userConn.GetLocalAddr(), userConn.GetRemoteAddr()) |
||||
|
||||
needRecord := true |
||||
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) |
||||
}(c) |
||||
} |
||||
}(listener) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (p *ProxyServer) Close() { |
||||
p.Lock() |
||||
defer p.Unlock() |
||||
|
||||
oldStatus := p.Status |
||||
p.Release() |
||||
|
||||
// if the proxy created by PrivilegeMode, delete it when closed
|
||||
if p.PrivilegeMode && oldStatus == consts.Working { |
||||
// NOTE: this will take the global ProxyServerMap's lock
|
||||
// if we only want to release resources, use Release() instead
|
||||
DeleteProxy(p.Name) |
||||
} |
||||
} |
||||
|
||||
func (p *ProxyServer) Release() { |
||||
if p.Status != consts.Closed { |
||||
p.Status = consts.Closed |
||||
for _, l := range p.listeners { |
||||
if l != nil { |
||||
l.Close() |
||||
} |
||||
} |
||||
if p.ctlMsgChan != nil { |
||||
close(p.ctlMsgChan) |
||||
p.ctlMsgChan = nil |
||||
} |
||||
if p.workConnChan != nil { |
||||
close(p.workConnChan) |
||||
p.workConnChan = nil |
||||
} |
||||
if p.udpSenderChan != nil { |
||||
close(p.udpSenderChan) |
||||
p.udpSenderChan = nil |
||||
} |
||||
if p.closeChan != nil { |
||||
close(p.closeChan) |
||||
p.closeChan = nil |
||||
} |
||||
if p.CtlConn != nil { |
||||
p.CtlConn.Close() |
||||
} |
||||
if p.WorkConnUdp != nil { |
||||
p.WorkConnUdp.Close() |
||||
} |
||||
if p.udpConn != nil { |
||||
p.udpConn.Close() |
||||
p.udpConn = nil |
||||
} |
||||
} |
||||
metric.SetStatus(p.Name, p.Status) |
||||
} |
||||
|
||||
func (p *ProxyServer) WaitUserConn() (closeFlag bool) { |
||||
closeFlag = false |
||||
|
||||
_, ok := <-p.ctlMsgChan |
||||
if !ok { |
||||
closeFlag = true |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (p *ProxyServer) RegisterNewWorkConn(c *conn.Conn) { |
||||
select { |
||||
case p.workConnChan <- c: |
||||
default: |
||||
log.Debug("ProxyName [%s], workConnChan is full, so close this work connection", p.Name) |
||||
c.Close() |
||||
} |
||||
} |
||||
|
||||
// create a tcp connection for forwarding udp packages
|
||||
func (p *ProxyServer) RegisterNewWorkConnUdp(c *conn.Conn) { |
||||
if p.WorkConnUdp != nil && !p.WorkConnUdp.IsClosed() { |
||||
p.WorkConnUdp.Close() |
||||
} |
||||
p.WorkConnUdp = c |
||||
|
||||
// read
|
||||
go func() { |
||||
var ( |
||||
buf string |
||||
err error |
||||
) |
||||
for { |
||||
buf, err = c.ReadLine() |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], work connection for udp closed", p.Name) |
||||
return |
||||
} |
||||
udpPacket := &msg.UdpPacket{} |
||||
err = udpPacket.UnPack([]byte(buf)) |
||||
if err != nil { |
||||
log.Warn("ProxyName [%s], unpack udp packet error: %v", p.Name, err) |
||||
continue |
||||
} |
||||
|
||||
// send to user
|
||||
_, err = p.udpConn.WriteToUDP(udpPacket.Content, udpPacket.Dst) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
} |
||||
}() |
||||
|
||||
// write
|
||||
go func() { |
||||
for { |
||||
udpPacket, ok := <-p.udpSenderChan |
||||
if !ok { |
||||
return |
||||
} |
||||
err := c.WriteString(string(udpPacket.Pack()) + "\n") |
||||
if err != nil { |
||||
log.Debug("ProxyName [%s], write to work connection for udp error: %v", p.Name, err) |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
// When frps get one user connection, we get one work connection from the pool and return it.
|
||||
// If no workConn available in the pool, send message to frpc to get one or more
|
||||
// and wait until it is available.
|
||||
// return an error if wait timeout
|
||||
func (p *ProxyServer) getWorkConn() (workConn *conn.Conn, err error) { |
||||
var ok bool |
||||
// get a work connection from the pool
|
||||
for { |
||||
select { |
||||
case workConn, ok = <-p.workConnChan: |
||||
if !ok { |
||||
err = fmt.Errorf("ProxyName [%s], no work connections available, control is closing", p.Name) |
||||
return |
||||
} |
||||
log.Debug("ProxyName [%s], get work connection from pool", p.Name) |
||||
default: |
||||
// no work connections available in the poll, send message to frpc to get more
|
||||
p.ctlMsgChan <- 1 |
||||
|
||||
select { |
||||
case workConn, ok = <-p.workConnChan: |
||||
if !ok { |
||||
err = fmt.Errorf("ProxyName [%s], no work connections available, control is closing", p.Name) |
||||
return |
||||
} |
||||
|
||||
case <-time.After(time.Duration(UserConnTimeout) * time.Second): |
||||
log.Warn("ProxyName [%s], timeout trying to get work connection", p.Name) |
||||
err = fmt.Errorf("ProxyName [%s], timeout trying to get work connection", p.Name) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// if connection pool is not used, we don't check the status
|
||||
// function CheckClosed will consume at least 1 millisecond if the connection isn't closed
|
||||
if p.PoolCount == 0 || !workConn.CheckClosed() { |
||||
break |
||||
} else { |
||||
log.Warn("ProxyName [%s], connection got from pool, but it's already closed", p.Name) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (p *ProxyServer) connectionPoolManager(closeCh <-chan struct{}) { |
||||
defer func() { |
||||
if r := recover(); r != nil { |
||||
log.Warn("ProxyName [%s], connectionPoolManager panic %v", p.Name, r) |
||||
} |
||||
}() |
||||
|
||||
for { |
||||
// check if we need more work connections and send messages to frpc to get more
|
||||
time.Sleep(time.Duration(2) * time.Second) |
||||
select { |
||||
// if the channel closed, it means the proxy is closed, so just return
|
||||
case <-closeCh: |
||||
log.Info("ProxyName [%s], connectionPoolManager exit", p.Name) |
||||
return |
||||
default: |
||||
curWorkConnNum := int64(len(p.workConnChan)) |
||||
diff := p.PoolCount - curWorkConnNum |
||||
if diff > 0 { |
||||
if diff < p.PoolCount/5 { |
||||
diff = p.PoolCount*4/5 + 1 |
||||
} else if diff < p.PoolCount/2 { |
||||
diff = p.PoolCount/4 + 1 |
||||
} else if diff < p.PoolCount*4/5 { |
||||
diff = p.PoolCount/5 + 1 |
||||
} else { |
||||
diff = p.PoolCount/10 + 1 |
||||
} |
||||
if diff+curWorkConnNum > p.PoolCount { |
||||
diff = p.PoolCount - curWorkConnNum |
||||
} |
||||
for i := 0; i < int(diff); i++ { |
||||
p.ctlMsgChan <- 1 |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
@ -1,87 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package broadcast |
||||
|
||||
type Broadcast struct { |
||||
listeners []chan interface{} |
||||
reg chan (chan interface{}) |
||||
unreg chan (chan interface{}) |
||||
in chan interface{} |
||||
stop chan int64 |
||||
stopStatus bool |
||||
} |
||||
|
||||
func NewBroadcast() *Broadcast { |
||||
b := &Broadcast{ |
||||
listeners: make([]chan interface{}, 0), |
||||
reg: make(chan (chan interface{})), |
||||
unreg: make(chan (chan interface{})), |
||||
in: make(chan interface{}), |
||||
stop: make(chan int64), |
||||
stopStatus: false, |
||||
} |
||||
|
||||
go func() { |
||||
for { |
||||
select { |
||||
case l := <-b.unreg: |
||||
// remove L from b.listeners
|
||||
// this operation is slow: O(n) but not used frequently
|
||||
// unlike iterating over listeners
|
||||
oldListeners := b.listeners |
||||
b.listeners = make([]chan interface{}, 0, len(oldListeners)) |
||||
for _, oldL := range oldListeners { |
||||
if l != oldL { |
||||
b.listeners = append(b.listeners, oldL) |
||||
} |
||||
} |
||||
|
||||
case l := <-b.reg: |
||||
b.listeners = append(b.listeners, l) |
||||
|
||||
case item := <-b.in: |
||||
for _, l := range b.listeners { |
||||
l <- item |
||||
} |
||||
|
||||
case _ = <-b.stop: |
||||
b.stopStatus = true |
||||
break |
||||
} |
||||
} |
||||
}() |
||||
|
||||
return b |
||||
} |
||||
|
||||
func (b *Broadcast) In() chan interface{} { |
||||
return b.in |
||||
} |
||||
|
||||
func (b *Broadcast) Reg() chan interface{} { |
||||
listener := make(chan interface{}) |
||||
b.reg <- listener |
||||
return listener |
||||
} |
||||
|
||||
func (b *Broadcast) UnReg(listener chan interface{}) { |
||||
b.unreg <- listener |
||||
} |
||||
|
||||
func (b *Broadcast) Close() { |
||||
if b.stopStatus == false { |
||||
b.stop <- 1 |
||||
} |
||||
} |
@ -1,77 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package broadcast |
||||
|
||||
import ( |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
var ( |
||||
totalNum int = 5 |
||||
succNum int = 0 |
||||
mutex sync.Mutex |
||||
) |
||||
|
||||
func TestBroadcast(t *testing.T) { |
||||
b := NewBroadcast() |
||||
if b == nil { |
||||
t.Fatalf("New Broadcast error, nil return") |
||||
} |
||||
defer b.Close() |
||||
|
||||
var wait sync.WaitGroup |
||||
wait.Add(totalNum) |
||||
for i := 0; i < totalNum; i++ { |
||||
go worker(b, &wait) |
||||
} |
||||
|
||||
time.Sleep(1e6 * 20) |
||||
msg := "test" |
||||
b.In() <- msg |
||||
|
||||
wait.Wait() |
||||
if succNum != totalNum { |
||||
t.Fatalf("TotalNum %d, FailNum(timeout) %d", totalNum, totalNum-succNum) |
||||
} |
||||
} |
||||
|
||||
func worker(b *Broadcast, wait *sync.WaitGroup) { |
||||
defer wait.Done() |
||||
msgChan := b.Reg() |
||||
|
||||
// exit if nothing got in 2 seconds
|
||||
timeout := make(chan bool, 1) |
||||
go func() { |
||||
time.Sleep(time.Duration(2) * time.Second) |
||||
timeout <- true |
||||
}() |
||||
|
||||
select { |
||||
case item := <-msgChan: |
||||
msg := item.(string) |
||||
if msg == "test" { |
||||
mutex.Lock() |
||||
succNum++ |
||||
mutex.Unlock() |
||||
} else { |
||||
break |
||||
} |
||||
|
||||
case <-timeout: |
||||
break |
||||
} |
||||
} |
@ -1,314 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package conn |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/base64" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"net/http" |
||||
"net/url" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/src/utils/pool" |
||||
) |
||||
|
||||
type Listener struct { |
||||
addr net.Addr |
||||
l *net.TCPListener |
||||
accept chan *Conn |
||||
closeFlag bool |
||||
} |
||||
|
||||
func Listen(bindAddr string, bindPort int64) (l *Listener, err error) { |
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) |
||||
if err != nil { |
||||
return l, err |
||||
} |
||||
listener, err := net.ListenTCP("tcp", tcpAddr) |
||||
if err != nil { |
||||
return l, err |
||||
} |
||||
|
||||
l = &Listener{ |
||||
addr: listener.Addr(), |
||||
l: listener, |
||||
accept: make(chan *Conn), |
||||
closeFlag: false, |
||||
} |
||||
|
||||
go func() { |
||||
for { |
||||
conn, err := l.l.AcceptTCP() |
||||
if err != nil { |
||||
if l.closeFlag { |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
|
||||
c := NewConn(conn) |
||||
l.accept <- c |
||||
} |
||||
}() |
||||
return l, err |
||||
} |
||||
|
||||
// wait util get one new connection or listener is closed
|
||||
// if listener is closed, err returned
|
||||
func (l *Listener) Accept() (*Conn, error) { |
||||
conn, ok := <-l.accept |
||||
if !ok { |
||||
return conn, fmt.Errorf("channel close") |
||||
} |
||||
return conn, nil |
||||
} |
||||
|
||||
func (l *Listener) Close() error { |
||||
if l.l != nil && l.closeFlag == false { |
||||
l.closeFlag = true |
||||
l.l.Close() |
||||
close(l.accept) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// wrap for TCPConn
|
||||
type Conn struct { |
||||
TcpConn net.Conn |
||||
Reader *bufio.Reader |
||||
buffer *bytes.Buffer |
||||
closeFlag bool |
||||
|
||||
mutex sync.RWMutex |
||||
} |
||||
|
||||
func NewConn(conn net.Conn) (c *Conn) { |
||||
c = &Conn{ |
||||
TcpConn: conn, |
||||
buffer: nil, |
||||
closeFlag: false, |
||||
} |
||||
c.Reader = bufio.NewReader(c.TcpConn) |
||||
return |
||||
} |
||||
|
||||
func ConnectServer(addr string) (c *Conn, err error) { |
||||
servertAddr, err := net.ResolveTCPAddr("tcp", addr) |
||||
if err != nil { |
||||
return |
||||
} |
||||
conn, err := net.DialTCP("tcp", nil, servertAddr) |
||||
if err != nil { |
||||
return |
||||
} |
||||
c = NewConn(conn) |
||||
return c, nil |
||||
} |
||||
|
||||
func ConnectServerByHttpProxy(httpProxy string, serverAddr string) (c *Conn, err error) { |
||||
var proxyUrl *url.URL |
||||
if proxyUrl, err = url.Parse(httpProxy); err != nil { |
||||
return |
||||
} |
||||
|
||||
var proxyAuth string |
||||
if proxyUrl.User != nil { |
||||
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(proxyUrl.User.String())) |
||||
} |
||||
|
||||
if proxyUrl.Scheme != "http" { |
||||
err = fmt.Errorf("Proxy URL scheme must be http, not [%s]", proxyUrl.Scheme) |
||||
return |
||||
} |
||||
|
||||
if c, err = ConnectServer(proxyUrl.Host); err != nil { |
||||
return |
||||
} |
||||
|
||||
req, err := http.NewRequest("CONNECT", "http://"+serverAddr, nil) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if proxyAuth != "" { |
||||
req.Header.Set("Proxy-Authorization", proxyAuth) |
||||
} |
||||
req.Header.Set("User-Agent", "Mozilla/5.0") |
||||
req.Write(c.TcpConn) |
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(c), req) |
||||
if err != nil { |
||||
return |
||||
} |
||||
resp.Body.Close() |
||||
if resp.StatusCode != 200 { |
||||
err = fmt.Errorf("ConnectServer using proxy error, StatusCode [%d]", resp.StatusCode) |
||||
return |
||||
} |
||||
|
||||
return |
||||
} |
||||
|
||||
// if the tcpConn is different with c.TcpConn
|
||||
// you should call c.Close() first
|
||||
func (c *Conn) SetTcpConn(tcpConn net.Conn) { |
||||
c.mutex.Lock() |
||||
defer c.mutex.Unlock() |
||||
c.TcpConn = tcpConn |
||||
c.closeFlag = false |
||||
c.Reader = bufio.NewReader(c.TcpConn) |
||||
} |
||||
|
||||
func (c *Conn) GetRemoteAddr() (addr string) { |
||||
return c.TcpConn.RemoteAddr().String() |
||||
} |
||||
|
||||
func (c *Conn) GetLocalAddr() (addr string) { |
||||
return c.TcpConn.LocalAddr().String() |
||||
} |
||||
|
||||
func (c *Conn) Read(p []byte) (n int, err error) { |
||||
c.mutex.RLock() |
||||
if c.buffer == nil { |
||||
c.mutex.RUnlock() |
||||
return c.Reader.Read(p) |
||||
} |
||||
c.mutex.RUnlock() |
||||
|
||||
n, err = c.buffer.Read(p) |
||||
if err == io.EOF { |
||||
c.mutex.Lock() |
||||
c.buffer = nil |
||||
c.mutex.Unlock() |
||||
var n2 int |
||||
n2, err = c.Reader.Read(p[n:]) |
||||
|
||||
n += n2 |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *Conn) ReadLine() (buff string, err error) { |
||||
buff, err = c.Reader.ReadString('\n') |
||||
if err != nil { |
||||
// wsarecv error in windows means connection closed?
|
||||
if err == io.EOF || strings.Contains(err.Error(), "wsarecv") { |
||||
c.mutex.Lock() |
||||
c.closeFlag = true |
||||
c.mutex.Unlock() |
||||
} |
||||
} |
||||
return buff, err |
||||
} |
||||
|
||||
func (c *Conn) Write(content []byte) (n int, err error) { |
||||
n, err = c.TcpConn.Write(content) |
||||
return |
||||
} |
||||
|
||||
func (c *Conn) WriteString(content string) (err error) { |
||||
_, err = c.TcpConn.Write([]byte(content)) |
||||
return err |
||||
} |
||||
|
||||
func (c *Conn) AppendReaderBuffer(content []byte) { |
||||
c.mutex.Lock() |
||||
defer c.mutex.Unlock() |
||||
|
||||
if c.buffer == nil { |
||||
c.buffer = bytes.NewBuffer(make([]byte, 0, 2048)) |
||||
} |
||||
c.buffer.Write(content) |
||||
} |
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error { |
||||
return c.TcpConn.SetDeadline(t) |
||||
} |
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error { |
||||
return c.TcpConn.SetReadDeadline(t) |
||||
} |
||||
|
||||
func (c *Conn) Close() error { |
||||
c.mutex.Lock() |
||||
defer c.mutex.Unlock() |
||||
if c.TcpConn != nil && c.closeFlag == false { |
||||
c.closeFlag = true |
||||
c.TcpConn.Close() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *Conn) IsClosed() (closeFlag bool) { |
||||
c.mutex.RLock() |
||||
defer c.mutex.RUnlock() |
||||
closeFlag = c.closeFlag |
||||
return |
||||
} |
||||
|
||||
// when you call this function, you should make sure that
|
||||
// no bytes were read before
|
||||
func (c *Conn) CheckClosed() bool { |
||||
c.mutex.RLock() |
||||
if c.closeFlag { |
||||
c.mutex.RUnlock() |
||||
return true |
||||
} |
||||
c.mutex.RUnlock() |
||||
|
||||
tmp := pool.GetBuf(2048) |
||||
defer pool.PutBuf(tmp) |
||||
err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond)) |
||||
if err != nil { |
||||
c.Close() |
||||
return true |
||||
} |
||||
|
||||
n, err := c.TcpConn.Read(tmp) |
||||
if err == io.EOF { |
||||
return true |
||||
} |
||||
|
||||
var tmp2 []byte = make([]byte, 1) |
||||
err = c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond)) |
||||
if err != nil { |
||||
c.Close() |
||||
return true |
||||
} |
||||
|
||||
n2, err := c.TcpConn.Read(tmp2) |
||||
if err == io.EOF { |
||||
return true |
||||
} |
||||
|
||||
err = c.TcpConn.SetReadDeadline(time.Time{}) |
||||
if err != nil { |
||||
c.Close() |
||||
return true |
||||
} |
||||
|
||||
if n > 0 { |
||||
c.AppendReaderBuffer(tmp[:n]) |
||||
} |
||||
if n2 > 0 { |
||||
c.AppendReaderBuffer(tmp2[:n2]) |
||||
} |
||||
return false |
||||
} |
@ -1,139 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pcrypto |
||||
|
||||
import ( |
||||
"bytes" |
||||
"compress/gzip" |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/md5" |
||||
"crypto/rand" |
||||
"encoding/hex" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
) |
||||
|
||||
type Pcrypto struct { |
||||
pkey []byte |
||||
paes cipher.Block |
||||
} |
||||
|
||||
func (pc *Pcrypto) Init(key []byte) error { |
||||
var err error |
||||
pc.pkey = pkKeyPadding(key) |
||||
pc.paes, err = aes.NewCipher(pc.pkey) |
||||
return err |
||||
} |
||||
|
||||
func (pc *Pcrypto) Encrypt(src []byte) ([]byte, error) { |
||||
// aes
|
||||
src = pKCS5Padding(src, aes.BlockSize) |
||||
ciphertext := make([]byte, aes.BlockSize+len(src)) |
||||
|
||||
// The IV needs to be unique, but not secure. Therefore it's common to
|
||||
// include it at the beginning of the ciphertext.
|
||||
iv := ciphertext[:aes.BlockSize] |
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil { |
||||
return nil, err |
||||
} |
||||
blockMode := cipher.NewCBCEncrypter(pc.paes, iv) |
||||
blockMode.CryptBlocks(ciphertext[aes.BlockSize:], src) |
||||
return ciphertext, nil |
||||
} |
||||
|
||||
func (pc *Pcrypto) Decrypt(str []byte) ([]byte, error) { |
||||
// aes
|
||||
ciphertext, err := hex.DecodeString(fmt.Sprintf("%x", str)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if len(ciphertext) < aes.BlockSize { |
||||
return nil, fmt.Errorf("ciphertext too short") |
||||
} |
||||
iv := ciphertext[:aes.BlockSize] |
||||
ciphertext = ciphertext[aes.BlockSize:] |
||||
|
||||
if len(ciphertext)%aes.BlockSize != 0 { |
||||
return nil, fmt.Errorf("crypto/cipher: ciphertext is not a multiple of the block size") |
||||
} |
||||
|
||||
blockMode := cipher.NewCBCDecrypter(pc.paes, iv) |
||||
blockMode.CryptBlocks(ciphertext, ciphertext) |
||||
return pKCS5UnPadding(ciphertext), nil |
||||
} |
||||
|
||||
func (pc *Pcrypto) Compression(src []byte) ([]byte, error) { |
||||
var zbuf bytes.Buffer |
||||
zwr, err := gzip.NewWriterLevel(&zbuf, gzip.DefaultCompression) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer zwr.Close() |
||||
zwr.Write(src) |
||||
zwr.Flush() |
||||
return zbuf.Bytes(), nil |
||||
} |
||||
|
||||
func (pc *Pcrypto) Decompression(src []byte) ([]byte, error) { |
||||
zbuf := bytes.NewBuffer(src) |
||||
zrd, err := gzip.NewReader(zbuf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer zrd.Close() |
||||
str, _ := ioutil.ReadAll(zrd) |
||||
return str, nil |
||||
} |
||||
|
||||
func pkKeyPadding(key []byte) []byte { |
||||
l := len(key) |
||||
if l == 16 || l == 24 || l == 32 { |
||||
return key |
||||
} |
||||
if l < 16 { |
||||
return append(key, bytes.Repeat([]byte{byte(0)}, 16-l)...) |
||||
} else if l < 24 { |
||||
return append(key, bytes.Repeat([]byte{byte(0)}, 24-l)...) |
||||
} else if l < 32 { |
||||
return append(key, bytes.Repeat([]byte{byte(0)}, 32-l)...) |
||||
} else { |
||||
md5Ctx := md5.New() |
||||
md5Ctx.Write(key) |
||||
md5Str := md5Ctx.Sum(nil) |
||||
return []byte(hex.EncodeToString(md5Str)) |
||||
} |
||||
} |
||||
|
||||
func pKCS5Padding(ciphertext []byte, blockSize int) []byte { |
||||
padding := blockSize - len(ciphertext)%blockSize |
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding) |
||||
return append(ciphertext, padtext...) |
||||
} |
||||
|
||||
func pKCS5UnPadding(origData []byte) []byte { |
||||
length := len(origData) |
||||
unpadding := int(origData[length-1]) |
||||
return origData[:(length - unpadding)] |
||||
} |
||||
|
||||
func GetAuthKey(str string) (authKey string) { |
||||
md5Ctx := md5.New() |
||||
md5Ctx.Write([]byte(str)) |
||||
md5Str := md5Ctx.Sum(nil) |
||||
return hex.EncodeToString(md5Str) |
||||
} |
@ -1,77 +0,0 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pcrypto |
||||
|
||||
import ( |
||||
"testing" |
||||
) |
||||
|
||||
var ( |
||||
pp *Pcrypto |
||||
) |
||||
|
||||
func init() { |
||||
pp = &Pcrypto{} |
||||
pp.Init([]byte("12234567890123451223456789012345321:wq")) |
||||
} |
||||
|
||||
func TestEncrypt(t *testing.T) { |
||||
testStr := "Test Encrypt!" |
||||
res, err := pp.Encrypt([]byte(testStr)) |
||||
if err != nil { |
||||
t.Fatalf("encrypt error: %v", err) |
||||
} |
||||
|
||||
res, err = pp.Decrypt([]byte(res)) |
||||
if err != nil { |
||||
t.Fatalf("decrypt error: %v", err) |
||||
} |
||||
|
||||
if string(res) != testStr { |
||||
t.Fatalf("test encrypt error, from [%s] to [%s]", testStr, string(res)) |
||||
} |
||||
} |
||||
|
||||
func TestCompression(t *testing.T) { |
||||
testStr := "Test Compression!" |
||||
res, err := pp.Compression([]byte(testStr)) |
||||
if err != nil { |
||||
t.Fatalf("compression error: %v", err) |
||||
} |
||||
|
||||
res, err = pp.Decompression(res) |
||||
if err != nil { |
||||
t.Fatalf("decompression error: %v", err) |
||||
} |
||||
|
||||
if string(res) != testStr { |
||||
t.Fatalf("test compression error, from [%s] to [%s]", testStr, string(res)) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkEncrypt(b *testing.B) { |
||||
testStr := "Test Encrypt!" |
||||
for i := 0; i < b.N; i++ { |
||||
pp.Encrypt([]byte(testStr)) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkDecrypt(b *testing.B) { |
||||
testStr := "Test Encrypt!" |
||||
res, _ := pp.Encrypt([]byte(testStr)) |
||||
for i := 0; i < b.N; i++ { |
||||
pp.Decrypt([]byte(res)) |
||||
} |
||||
} |
@ -0,0 +1,52 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package crypto |
||||
|
||||
import ( |
||||
"bytes" |
||||
"io" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestWriter(t *testing.T) { |
||||
// Empty key.
|
||||
assert := assert.New(t) |
||||
key := "" |
||||
buffer := bytes.NewBuffer(nil) |
||||
_, err := NewWriter(buffer, []byte(key)) |
||||
assert.NoError(err) |
||||
} |
||||
|
||||
func TestCrypto(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
text := "1234567890abcdefghigklmnopqrstuvwxyzeeeeeeeeeeeeeeeeeeeeeewwwwwwwwwwwwwwwwwwwwwwwwwwzzzzzzzzzzzzzzzzzzzzzzzzdddddddddddddddddddddddddddddddddddddrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrllllllllllllllllllllllllllllllllllqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeewwwwwwwwwwwwwwwwwwwwww" |
||||
key := "123456" |
||||
|
||||
buffer := bytes.NewBuffer(nil) |
||||
encWriter, err := NewWriter(buffer, []byte(key)) |
||||
assert.NoError(err) |
||||
|
||||
encWriter.Write([]byte(text)) |
||||
|
||||
decReader, err := NewReader(buffer, []byte(key)) |
||||
assert.NoError(err) |
||||
|
||||
c := bytes.NewBuffer(nil) |
||||
io.Copy(c, decReader) |
||||
assert.Equal(text, string(c.Bytes())) |
||||
} |
@ -0,0 +1,75 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package crypto |
||||
|
||||
import ( |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/sha1" |
||||
"io" |
||||
|
||||
"golang.org/x/crypto/pbkdf2" |
||||
) |
||||
|
||||
// NewReader returns a new Reader that decrypts bytes from r
|
||||
func NewReader(r io.Reader, key []byte) (*Reader, error) { |
||||
key = pbkdf2.Key(key, []byte(salt), 64, aes.BlockSize, sha1.New) |
||||
|
||||
return &Reader{ |
||||
r: r, |
||||
key: key, |
||||
}, nil |
||||
} |
||||
|
||||
// Reader is an io.Reader that can read encrypted bytes.
|
||||
// Now it only supports aes-128-cfb.
|
||||
type Reader struct { |
||||
r io.Reader |
||||
dec *cipher.StreamReader |
||||
key []byte |
||||
iv []byte |
||||
err error |
||||
} |
||||
|
||||
// Read satisfies the io.Reader interface.
|
||||
func (r *Reader) Read(p []byte) (nRet int, errRet error) { |
||||
if r.err != nil { |
||||
return 0, r.err |
||||
} |
||||
|
||||
if r.dec == nil { |
||||
iv := make([]byte, aes.BlockSize) |
||||
if _, errRet = io.ReadFull(r.r, iv); errRet != nil { |
||||
return |
||||
} |
||||
r.iv = iv |
||||
|
||||
block, err := aes.NewCipher(r.key) |
||||
if err != nil { |
||||
errRet = err |
||||
return |
||||
} |
||||
r.dec = &cipher.StreamReader{ |
||||
S: cipher.NewCFBDecrypter(block, iv), |
||||
R: r.r, |
||||
} |
||||
} |
||||
|
||||
nRet, errRet = r.dec.Read(p) |
||||
if errRet != nil { |
||||
r.err = errRet |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,93 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package crypto |
||||
|
||||
import ( |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/rand" |
||||
"crypto/sha1" |
||||
"io" |
||||
|
||||
"golang.org/x/crypto/pbkdf2" |
||||
) |
||||
|
||||
const ( |
||||
salt = "frp" |
||||
) |
||||
|
||||
// NewWriter returns a new Writer that encrypts bytes to w.
|
||||
func NewWriter(w io.Writer, key []byte) (*Writer, error) { |
||||
key = pbkdf2.Key(key, []byte(salt), 64, aes.BlockSize, sha1.New) |
||||
|
||||
// random iv
|
||||
iv := make([]byte, aes.BlockSize) |
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
block, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &Writer{ |
||||
w: w, |
||||
enc: &cipher.StreamWriter{ |
||||
S: cipher.NewCFBEncrypter(block, iv), |
||||
W: w, |
||||
}, |
||||
key: key, |
||||
iv: iv, |
||||
}, nil |
||||
} |
||||
|
||||
// Writer is an io.Writer that can write encrypted bytes.
|
||||
// Now it only support aes-128-cfb.
|
||||
type Writer struct { |
||||
w io.Writer |
||||
enc *cipher.StreamWriter |
||||
key []byte |
||||
iv []byte |
||||
ivSend bool |
||||
err error |
||||
} |
||||
|
||||
// Write satisfies the io.Writer interface.
|
||||
func (w *Writer) Write(p []byte) (nRet int, errRet error) { |
||||
return w.write(p) |
||||
} |
||||
|
||||
func (w *Writer) write(p []byte) (nRet int, errRet error) { |
||||
if w.err != nil { |
||||
return 0, w.err |
||||
} |
||||
|
||||
// When write is first called, iv will be written to w.w
|
||||
if !w.ivSend { |
||||
w.ivSend = true |
||||
_, errRet = w.w.Write(w.iv) |
||||
if errRet != nil { |
||||
w.err = errRet |
||||
return |
||||
} |
||||
} |
||||
|
||||
nRet, errRet = w.enc.Write(p) |
||||
if errRet != nil { |
||||
w.err = errRet |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,162 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net |
||||
|
||||
import ( |
||||
"bufio" |
||||
"encoding/base64" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"net/url" |
||||
|
||||
"github.com/fatedier/frp/utils/log" |
||||
) |
||||
|
||||
type TcpListener struct { |
||||
net.Addr |
||||
listener net.Listener |
||||
accept chan Conn |
||||
closeFlag bool |
||||
} |
||||
|
||||
func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) { |
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) |
||||
if err != nil { |
||||
return l, err |
||||
} |
||||
listener, err := net.ListenTCP("tcp", tcpAddr) |
||||
if err != nil { |
||||
return l, err |
||||
} |
||||
|
||||
l = &TcpListener{ |
||||
Addr: listener.Addr(), |
||||
listener: listener, |
||||
accept: make(chan Conn), |
||||
closeFlag: false, |
||||
} |
||||
|
||||
go func() { |
||||
for { |
||||
conn, err := listener.AcceptTCP() |
||||
if err != nil { |
||||
if l.closeFlag { |
||||
close(l.accept) |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
|
||||
c := NewTcpConn(conn) |
||||
l.accept <- c |
||||
} |
||||
}() |
||||
return l, err |
||||
} |
||||
|
||||
// Wait util get one new connection or listener is closed
|
||||
// if listener is closed, err returned.
|
||||
func (l *TcpListener) Accept() (Conn, error) { |
||||
conn, ok := <-l.accept |
||||
if !ok { |
||||
return conn, fmt.Errorf("channel for tcp listener closed") |
||||
} |
||||
return conn, nil |
||||
} |
||||
|
||||
func (l *TcpListener) Close() error { |
||||
if !l.closeFlag { |
||||
l.closeFlag = true |
||||
l.listener.Close() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Wrap for TCPConn.
|
||||
type TcpConn struct { |
||||
net.Conn |
||||
log.Logger |
||||
} |
||||
|
||||
func NewTcpConn(conn *net.TCPConn) (c *TcpConn) { |
||||
c = &TcpConn{ |
||||
Conn: conn, |
||||
Logger: log.NewPrefixLogger(""), |
||||
} |
||||
return |
||||
} |
||||
|
||||
func ConnectTcpServer(addr string) (c Conn, err error) { |
||||
servertAddr, err := net.ResolveTCPAddr("tcp", addr) |
||||
if err != nil { |
||||
return |
||||
} |
||||
conn, err := net.DialTCP("tcp", nil, servertAddr) |
||||
if err != nil { |
||||
return |
||||
} |
||||
c = NewTcpConn(conn) |
||||
return |
||||
} |
||||
|
||||
// ConnectTcpServerByHttpProxy try to connect remote server by http proxy.
|
||||
// If httpProxy is empty, it will connect server directly.
|
||||
func ConnectTcpServerByHttpProxy(httpProxy string, serverAddr string) (c Conn, err error) { |
||||
if httpProxy == "" { |
||||
return ConnectTcpServer(serverAddr) |
||||
} |
||||
|
||||
var proxyUrl *url.URL |
||||
if proxyUrl, err = url.Parse(httpProxy); err != nil { |
||||
return |
||||
} |
||||
|
||||
var proxyAuth string |
||||
if proxyUrl.User != nil { |
||||
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(proxyUrl.User.String())) |
||||
} |
||||
|
||||
if proxyUrl.Scheme != "http" { |
||||
err = fmt.Errorf("Proxy URL scheme must be http, not [%s]", proxyUrl.Scheme) |
||||
return |
||||
} |
||||
|
||||
if c, err = ConnectTcpServer(proxyUrl.Host); err != nil { |
||||
return |
||||
} |
||||
|
||||
req, err := http.NewRequest("CONNECT", "http://"+serverAddr, nil) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if proxyAuth != "" { |
||||
req.Header.Set("Proxy-Authorization", proxyAuth) |
||||
} |
||||
req.Header.Set("User-Agent", "Mozilla/5.0") |
||||
req.Write(c) |
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(c), req) |
||||
if err != nil { |
||||
return |
||||
} |
||||
resp.Body.Close() |
||||
if resp.StatusCode != 200 { |
||||
err = fmt.Errorf("ConnectTcpServer using proxy error, StatusCode [%d]", resp.StatusCode) |
||||
return |
||||
} |
||||
|
||||
return |
||||
} |
@ -0,0 +1,243 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
|
||||
flog "github.com/fatedier/frp/utils/log" |
||||
"github.com/fatedier/frp/utils/pool" |
||||
) |
||||
|
||||
type UdpPacket struct { |
||||
Buf []byte |
||||
LocalAddr net.Addr |
||||
RemoteAddr net.Addr |
||||
} |
||||
|
||||
type FakeUdpConn struct { |
||||
flog.Logger |
||||
l *UdpListener |
||||
|
||||
localAddr net.Addr |
||||
remoteAddr net.Addr |
||||
packets chan []byte |
||||
closeFlag bool |
||||
|
||||
lastActive time.Time |
||||
mu sync.RWMutex |
||||
} |
||||
|
||||
func NewFakeUdpConn(l *UdpListener, laddr, raddr net.Addr) *FakeUdpConn { |
||||
fc := &FakeUdpConn{ |
||||
Logger: flog.NewPrefixLogger(""), |
||||
l: l, |
||||
localAddr: laddr, |
||||
remoteAddr: raddr, |
||||
packets: make(chan []byte, 20), |
||||
} |
||||
|
||||
go func() { |
||||
for { |
||||
time.Sleep(5 * time.Second) |
||||
fc.mu.RLock() |
||||
if time.Now().Sub(fc.lastActive) > 10*time.Second { |
||||
fc.mu.RUnlock() |
||||
fc.Close() |
||||
break |
||||
} |
||||
fc.mu.RUnlock() |
||||
} |
||||
}() |
||||
return fc |
||||
} |
||||
|
||||
func (c *FakeUdpConn) putPacket(content []byte) { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
} |
||||
}() |
||||
|
||||
select { |
||||
case c.packets <- content: |
||||
default: |
||||
} |
||||
} |
||||
|
||||
func (c *FakeUdpConn) Read(b []byte) (n int, err error) { |
||||
content, ok := <-c.packets |
||||
if !ok { |
||||
return 0, io.EOF |
||||
} |
||||
c.mu.Lock() |
||||
c.lastActive = time.Now() |
||||
c.mu.Unlock() |
||||
|
||||
if len(b) < len(content) { |
||||
n = len(b) |
||||
} else { |
||||
n = len(content) |
||||
} |
||||
copy(b, content) |
||||
return n, nil |
||||
} |
||||
|
||||
func (c *FakeUdpConn) Write(b []byte) (n int, err error) { |
||||
c.mu.RLock() |
||||
if c.closeFlag { |
||||
c.mu.RUnlock() |
||||
return 0, io.ErrClosedPipe |
||||
} |
||||
c.mu.RUnlock() |
||||
|
||||
packet := &UdpPacket{ |
||||
Buf: b, |
||||
LocalAddr: c.localAddr, |
||||
RemoteAddr: c.remoteAddr, |
||||
} |
||||
c.l.writeUdpPacket(packet) |
||||
|
||||
c.mu.Lock() |
||||
c.lastActive = time.Now() |
||||
c.mu.Unlock() |
||||
return len(b), nil |
||||
} |
||||
|
||||
func (c *FakeUdpConn) Close() error { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if !c.closeFlag { |
||||
c.closeFlag = true |
||||
close(c.packets) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *FakeUdpConn) IsClosed() bool { |
||||
c.mu.RLock() |
||||
defer c.mu.RUnlock() |
||||
return c.closeFlag |
||||
} |
||||
|
||||
func (c *FakeUdpConn) LocalAddr() net.Addr { |
||||
return c.localAddr |
||||
} |
||||
|
||||
func (c *FakeUdpConn) RemoteAddr() net.Addr { |
||||
return c.remoteAddr |
||||
} |
||||
|
||||
func (c *FakeUdpConn) SetDeadline(t time.Time) error { |
||||
return nil |
||||
} |
||||
|
||||
func (c *FakeUdpConn) SetReadDeadline(t time.Time) error { |
||||
return nil |
||||
} |
||||
|
||||
func (c *FakeUdpConn) SetWriteDeadline(t time.Time) error { |
||||
return nil |
||||
} |
||||
|
||||
type UdpListener struct { |
||||
net.Addr |
||||
accept chan Conn |
||||
writeCh chan *UdpPacket |
||||
readConn net.Conn |
||||
closeFlag bool |
||||
|
||||
fakeConns map[string]*FakeUdpConn |
||||
} |
||||
|
||||
func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) { |
||||
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) |
||||
if err != nil { |
||||
return l, err |
||||
} |
||||
readConn, err := net.ListenUDP("udp", udpAddr) |
||||
|
||||
l = &UdpListener{ |
||||
Addr: udpAddr, |
||||
accept: make(chan Conn), |
||||
writeCh: make(chan *UdpPacket, 1000), |
||||
fakeConns: make(map[string]*FakeUdpConn), |
||||
} |
||||
|
||||
// for reading
|
||||
go func() { |
||||
for { |
||||
buf := pool.GetBuf(1450) |
||||
n, remoteAddr, err := readConn.ReadFromUDP(buf) |
||||
if err != nil { |
||||
close(l.accept) |
||||
close(l.writeCh) |
||||
return |
||||
} |
||||
|
||||
fakeConn, exist := l.fakeConns[remoteAddr.String()] |
||||
if !exist || fakeConn.IsClosed() { |
||||
fakeConn = NewFakeUdpConn(l, l.Addr, remoteAddr) |
||||
l.fakeConns[remoteAddr.String()] = fakeConn |
||||
} |
||||
fakeConn.putPacket(buf[:n]) |
||||
|
||||
l.accept <- fakeConn |
||||
} |
||||
}() |
||||
|
||||
// for writing
|
||||
go func() { |
||||
for { |
||||
packet, ok := <-l.writeCh |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
if addr, ok := packet.RemoteAddr.(*net.UDPAddr); ok { |
||||
readConn.WriteToUDP(packet.Buf, addr) |
||||
} |
||||
} |
||||
}() |
||||
|
||||
return |
||||
} |
||||
|
||||
func (l *UdpListener) writeUdpPacket(packet *UdpPacket) { |
||||
defer func() { |
||||
if err := recover(); err != nil { |
||||
} |
||||
}() |
||||
l.writeCh <- packet |
||||
} |
||||
|
||||
func (l *UdpListener) Accept() (Conn, error) { |
||||
conn, ok := <-l.accept |
||||
if !ok { |
||||
return conn, fmt.Errorf("channel for udp listener closed") |
||||
} |
||||
return conn, nil |
||||
} |
||||
|
||||
func (l *UdpListener) Close() error { |
||||
if !l.closeFlag { |
||||
l.closeFlag = true |
||||
l.readConn.Close() |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,27 @@
|
||||
package net |
||||
|
||||
import ( |
||||
"fmt" |
||||
"testing" |
||||
) |
||||
|
||||
func TestA(t *testing.T) { |
||||
l, err := ListenUDP("0.0.0.0", 9000) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
} |
||||
for { |
||||
c, _ := l.Accept() |
||||
go func() { |
||||
for { |
||||
buf := make([]byte, 1450) |
||||
n, err := c.Read(buf) |
||||
if err != nil { |
||||
fmt.Println(buf[:n]) |
||||
} |
||||
|
||||
c.Write(buf[:n]) |
||||
} |
||||
}() |
||||
} |
||||
} |
@ -0,0 +1,51 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pool |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestPutBuf(t *testing.T) { |
||||
buf := make([]byte, 512) |
||||
PutBuf(buf) |
||||
|
||||
buf = make([]byte, 1025) |
||||
PutBuf(buf) |
||||
|
||||
buf = make([]byte, 2*1025) |
||||
PutBuf(buf) |
||||
|
||||
buf = make([]byte, 5*1025) |
||||
PutBuf(buf) |
||||
} |
||||
|
||||
func TestGetBuf(t *testing.T) { |
||||
assert := assert.New(t) |
||||
|
||||
buf := GetBuf(200) |
||||
assert.Len(buf, 200) |
||||
|
||||
buf = GetBuf(1025) |
||||
assert.Len(buf, 1025) |
||||
|
||||
buf = GetBuf(2 * 1024) |
||||
assert.Len(buf, 2*1024) |
||||
|
||||
buf = GetBuf(5 * 2000) |
||||
assert.Len(buf, 5*2000) |
||||
} |
@ -0,0 +1,62 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shutdown |
||||
|
||||
import ( |
||||
"sync" |
||||
) |
||||
|
||||
type Shutdown struct { |
||||
doing bool |
||||
ending bool |
||||
start chan struct{} |
||||
down chan struct{} |
||||
mu sync.Mutex |
||||
} |
||||
|
||||
func New() *Shutdown { |
||||
return &Shutdown{ |
||||
doing: false, |
||||
ending: false, |
||||
start: make(chan struct{}), |
||||
down: make(chan struct{}), |
||||
} |
||||
} |
||||
|
||||
func (s *Shutdown) Start() { |
||||
s.mu.Lock() |
||||
defer s.mu.Unlock() |
||||
if !s.doing { |
||||
s.doing = true |
||||
close(s.start) |
||||
} |
||||
} |
||||
|
||||
func (s *Shutdown) WaitStart() { |
||||
<-s.start |
||||
} |
||||
|
||||
func (s *Shutdown) Done() { |
||||
s.mu.Lock() |
||||
defer s.mu.Unlock() |
||||
if !s.ending { |
||||
s.ending = true |
||||
close(s.down) |
||||
} |
||||
} |
||||
|
||||
func (s *Shutdown) WaitDown() { |
||||
<-s.down |
||||
} |
@ -0,0 +1,21 @@
|
||||
package shutdown |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestShutdown(t *testing.T) { |
||||
s := New() |
||||
go func() { |
||||
time.Sleep(time.Millisecond) |
||||
s.Start() |
||||
}() |
||||
s.WaitStart() |
||||
|
||||
go func() { |
||||
time.Sleep(time.Millisecond) |
||||
s.Done() |
||||
}() |
||||
s.WaitDown() |
||||
} |
@ -0,0 +1,47 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package util |
||||
|
||||
import ( |
||||
"crypto/md5" |
||||
"crypto/rand" |
||||
"encoding/hex" |
||||
"fmt" |
||||
) |
||||
|
||||
// RandId return a rand string used in frp.
|
||||
func RandId() (id string, err error) { |
||||
return RandIdWithLen(8) |
||||
} |
||||
|
||||
// RandIdWithLen return a rand string with idLen length.
|
||||
func RandIdWithLen(idLen int) (id string, err error) { |
||||
b := make([]byte, idLen) |
||||
_, err = rand.Read(b) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
id = fmt.Sprintf("%x", b) |
||||
return |
||||
} |
||||
|
||||
func GetAuthKey(token string, timestamp int64) (key string) { |
||||
token = token + fmt.Sprintf("%d", timestamp) |
||||
md5Ctx := md5.New() |
||||
md5Ctx.Write([]byte(token)) |
||||
data := md5Ctx.Sum(nil) |
||||
return hex.EncodeToString(data) |
||||
} |
@ -0,0 +1,22 @@
|
||||
package util |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestRandId(t *testing.T) { |
||||
assert := assert.New(t) |
||||
id, err := RandId() |
||||
assert.NoError(err) |
||||
t.Log(id) |
||||
assert.Equal(16, len(id)) |
||||
} |
||||
|
||||
func TestGetAuthKey(t *testing.T) { |
||||
assert := assert.New(t) |
||||
key := GetAuthKey("1234", 1488720000) |
||||
t.Log(key) |
||||
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) |
||||
} |
@ -0,0 +1,78 @@
|
||||
package logs |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"net/http" |
||||
"net/url" |
||||
"time" |
||||
) |
||||
|
||||
// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook
|
||||
type JLWriter struct { |
||||
AuthorName string `json:"authorname"` |
||||
Title string `json:"title"` |
||||
WebhookURL string `json:"webhookurl"` |
||||
RedirectURL string `json:"redirecturl,omitempty"` |
||||
ImageURL string `json:"imageurl,omitempty"` |
||||
Level int `json:"level"` |
||||
} |
||||
|
||||
// newJLWriter create jiaoliao writer.
|
||||
func newJLWriter() Logger { |
||||
return &JLWriter{Level: LevelTrace} |
||||
} |
||||
|
||||
// Init JLWriter with json config string
|
||||
func (s *JLWriter) Init(jsonconfig string) error { |
||||
err := json.Unmarshal([]byte(jsonconfig), s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// WriteMsg write message in smtp writer.
|
||||
// it will send an email with subject and only this message.
|
||||
func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { |
||||
if level > s.Level { |
||||
return nil |
||||
} |
||||
|
||||
text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) |
||||
|
||||
form := url.Values{} |
||||
form.Add("authorName", s.AuthorName) |
||||
form.Add("title", s.Title) |
||||
form.Add("text", text) |
||||
if s.RedirectURL != "" { |
||||
form.Add("redirectUrl", s.RedirectURL) |
||||
} |
||||
if s.ImageURL != "" { |
||||
form.Add("imageUrl", s.ImageURL) |
||||
} |
||||
|
||||
resp, err := http.PostForm(s.WebhookURL, form) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer resp.Body.Close() |
||||
if resp.StatusCode != http.StatusOK { |
||||
return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Flush implementing method. empty.
|
||||
func (s *JLWriter) Flush() { |
||||
return |
||||
} |
||||
|
||||
// Destroy implementing method. empty.
|
||||
func (s *JLWriter) Destroy() { |
||||
return |
||||
} |
||||
|
||||
func init() { |
||||
Register(AdapterJianLiao, newJLWriter) |
||||
} |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue