diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 7a6edc9e..52dc95b4 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -6,6 +6,10 @@ "./..." ], "Deps": [ + { + "ImportPath": "github.com/armon/go-socks5", + "Rev": "e75332964ef517daa070d7c38a9466a0d687e0a5" + }, { "ImportPath": "github.com/davecgh/go-spew/spew", "Comment": "v1.1.0", @@ -110,6 +114,10 @@ "ImportPath": "golang.org/x/net/bpf", "Rev": "e4fa1c5465ad6111f206fc92186b8c83d64adbe1" }, + { + "ImportPath": "golang.org/x/net/context", + "Rev": "e4fa1c5465ad6111f206fc92186b8c83d64adbe1" + }, { "ImportPath": "golang.org/x/net/internal/iana", "Rev": "e4fa1c5465ad6111f206fc92186b8c83d64adbe1" diff --git a/README.md b/README.md index 005f7498..1a0dd081 100644 --- a/README.md +++ b/README.md @@ -237,14 +237,14 @@ Configure frps same as above. [http_proxy] type = tcp remote_port = 6000 - plugin = http_proxy + plugin = http_proxy # or socks5 ``` 4. Start frpc: `./frpc -c ./frpc.ini` -5. Set http proxy `x.x.x.x:6000` in your browser and visit website through frpc's network. +5. Set http proxy or socks5 proxy `x.x.x.x:6000` in your browser and visit website through frpc's network. ## Features @@ -469,7 +469,7 @@ http_proxy = http://user:pwd@192.168.1.128:8080 frpc only forward request to local tcp or udp port by default. -Plugin is used for providing rich features. There are built-in plugins such as **unix_domain_socket**, **http_proxy** and you can see [example usage](#example-usage). +Plugin is used for providing rich features. There are built-in plugins such as **unix_domain_socket**, **http_proxy**, **socks5** and you can see [example usage](#example-usage). Specify which plugin to use by `plugin` parameter. Configuration parameters of plugin should be started with `plugin_`. `local_ip` and `local_port` is useless for plugin. diff --git a/README_zh.md b/README_zh.md index be8187a2..c1aad4d8 100644 --- a/README_zh.md +++ b/README_zh.md @@ -225,11 +225,11 @@ DNS 查询请求通常使用 UDP 协议,frp 支持对内网 UDP 服务的穿 ### 通过 frpc 所在机器访问外网 -frpc 内置了 http proxy 插件,可以使其他机器通过 frpc 的网络访问互联网。 +frpc 内置了 http proxy 和 socks5 插件,可以使其他机器通过 frpc 的网络访问互联网。 frps 的部署步骤同上。 -1. 修改 frpc.ini 文件,启用 http_proxy 插件: +1. 修改 frpc.ini 文件,启用 http_proxy 或 socks5 插件(plugin 换为 socks5 即可): ```ini # frpc.ini @@ -247,7 +247,7 @@ frps 的部署步骤同上。 `./frpc -c ./frpc.ini` -5. 浏览器设置 http 代理地址为 `x.x.x.x:6000`,通过 frpc 机器的网络访问互联网。 +5. 浏览器设置 http 或 socks5 代理地址为 `x.x.x.x:6000`,通过 frpc 机器的网络访问互联网。 ## 功能说明 @@ -486,7 +486,7 @@ http_proxy = http://user:pwd@192.168.1.128:8080 默认情况下,frpc 只会转发请求到本地 tcp 或 udp 端口。 -插件模式是为了在客户端提供更加丰富的功能,目前内置的插件有 **unix_domain_socket**、**http_proxy**。具体使用方式请查看[使用示例](#使用示例)。 +插件模式是为了在客户端提供更加丰富的功能,目前内置的插件有 **unix_domain_socket**、**http_proxy**、**socks5**。具体使用方式请查看[使用示例](#使用示例)。 通过 `plugin` 指定需要使用的插件,插件的配置参数都以 `plugin_` 开头。使用插件后 `local_ip` 和 `local_port` 不再需要配置。 diff --git a/client/admin.go b/client/admin.go new file mode 100644 index 00000000..f728483e --- /dev/null +++ b/client/admin.go @@ -0,0 +1,60 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "fmt" + "net" + "net/http" + "time" + + "github.com/fatedier/frp/models/config" + frpNet "github.com/fatedier/frp/utils/net" + + "github.com/julienschmidt/httprouter" +) + +var ( + httpServerReadTimeout = 10 * time.Second + httpServerWriteTimeout = 10 * time.Second +) + +func (svr *Service) RunAdminServer(addr string, port int64) (err error) { + // url router + router := httprouter.New() + + user, passwd := config.ClientCommonCfg.AdminUser, config.ClientCommonCfg.AdminPwd + + // api, see dashboard_api.go + router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd)) + + address := fmt.Sprintf("%s:%d", addr, port) + server := &http.Server{ + Addr: address, + Handler: router, + ReadTimeout: httpServerReadTimeout, + WriteTimeout: httpServerWriteTimeout, + } + if address == "" { + address = ":http" + } + ln, err := net.Listen("tcp", address) + if err != nil { + return err + } + + go server.Serve(ln) + return +} diff --git a/client/admin_api.go b/client/admin_api.go new file mode 100644 index 00000000..72fae04e --- /dev/null +++ b/client/admin_api.go @@ -0,0 +1,78 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "encoding/json" + "net/http" + + "github.com/julienschmidt/httprouter" + ini "github.com/vaughan0/go-ini" + + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/utils/log" +) + +type GeneralResponse struct { + Code int64 `json:"code"` + Msg string `json:"msg"` +} + +// api/reload +type ReloadResp struct { + GeneralResponse +} + +func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var ( + buf []byte + res ReloadResp + ) + defer func() { + log.Info("Http response [/api/reload]: code [%d]", res.Code) + buf, _ = json.Marshal(&res) + w.Write(buf) + }() + + log.Info("Http request: [/api/reload]") + + conf, err := ini.LoadFile(config.ClientCommonCfg.ConfigFile) + if err != nil { + res.Code = 1 + res.Msg = err.Error() + log.Error("reload frpc config file error: %v", err) + return + } + + newCommonCfg, err := config.LoadClientCommonConf(conf) + if err != nil { + res.Code = 2 + res.Msg = err.Error() + log.Error("reload frpc common section error: %v", err) + return + } + + pxyCfgs, vistorCfgs, err := config.LoadProxyConfFromFile(newCommonCfg.User, conf, newCommonCfg.Start) + if err != nil { + res.Code = 3 + res.Msg = err.Error() + log.Error("reload frpc proxy config error: %v", err) + return + } + + svr.ctl.reloadConf(pxyCfgs, vistorCfgs) + log.Info("success reload conf") + return +} diff --git a/client/control.go b/client/control.go index c7020580..29dca60c 100644 --- a/client/control.go +++ b/client/control.go @@ -24,8 +24,9 @@ import ( "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/utils/crypto" + "github.com/fatedier/frp/utils/errors" "github.com/fatedier/frp/utils/log" - "github.com/fatedier/frp/utils/net" + frpNet "github.com/fatedier/frp/utils/net" "github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/version" "github.com/xtaci/smux" @@ -48,8 +49,14 @@ type Control struct { // proxies proxies map[string]Proxy + // vistor configures + vistorCfgs map[string]config.ProxyConf + + // vistors + vistors map[string]Vistor + // control connection - conn net.Conn + conn frpNet.Conn // tcp stream multiplexing, if enabled session *smux.Session @@ -63,8 +70,8 @@ type Control struct { // run id got from server runId string - // connection or other error happens , control will try to reconnect to server - closed int32 + // if we call close() in control, do not reconnect to server + exit bool // goroutines can block by reading from this channel, it will be closed only in reader() when control connection is closed closedCh chan int @@ -77,7 +84,7 @@ type Control struct { log.Logger } -func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf) *Control { +func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf, vistorCfgs map[string]config.ProxyConf) *Control { loginMsg := &msg.Login{ Arch: runtime.GOARCH, Os: runtime.GOOS, @@ -86,14 +93,16 @@ func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf) *Control { Version: version.Full(), } return &Control{ - svr: svr, - loginMsg: loginMsg, - pxyCfgs: pxyCfgs, - proxies: make(map[string]Proxy), - sendCh: make(chan msg.Message, 10), - readCh: make(chan msg.Message, 10), - closedCh: make(chan int), - Logger: log.NewPrefixLogger(""), + svr: svr, + loginMsg: loginMsg, + pxyCfgs: pxyCfgs, + vistorCfgs: vistorCfgs, + proxies: make(map[string]Proxy), + vistors: make(map[string]Vistor), + sendCh: make(chan msg.Message, 10), + readCh: make(chan msg.Message, 10), + closedCh: make(chan int), + Logger: log.NewPrefixLogger(""), } } @@ -105,16 +114,17 @@ func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf) *Control { // 6. In controler(): ini readCh, sendCh, closedCh // 7. In controler(): start new reader(), writer(), manager() // controler() will keep running -func (ctl *Control) Run() error { +func (ctl *Control) Run() (err error) { for { - err := ctl.login() + err = ctl.login() if err != nil { + ctl.Warn("login to server failed: %v", err) + // if login_fail_exit is true, just exit this program // otherwise sleep a while and continues relogin to server if config.ClientCommonCfg.LoginFailExit { - return err + return } else { - ctl.Warn("login to server fail: %v", err) time.Sleep(30 * time.Second) } } else { @@ -127,6 +137,18 @@ func (ctl *Control) Run() error { go ctl.writer() go ctl.reader() + // start all local vistors + for _, cfg := range ctl.vistorCfgs { + vistor := NewVistor(ctl, cfg) + err = vistor.Run() + if err != nil { + vistor.Warn("start error: %v", err) + continue + } + ctl.vistors[cfg.GetName()] = vistor + vistor.Info("start vistor success") + } + // send NewProxy message for all configured proxies for _, cfg := range ctl.pxyCfgs { var newProxyMsg msg.NewProxy @@ -137,29 +159,13 @@ func (ctl *Control) Run() error { } func (ctl *Control) NewWorkConn() { - var ( - workConn net.Conn - err error - ) - if config.ClientCommonCfg.TcpMux { - stream, err := ctl.session.OpenStream() - if err != nil { - ctl.Warn("start new work connection error: %v", err) - return - } - workConn = net.WrapConn(stream) - - } else { - workConn, err = net.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol, - fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) - if err != nil { - ctl.Warn("start new work connection error: %v", err) - return - } + workConn, err := ctl.connectServer() + if err != nil { + return } m := &msg.NewWorkConn{ - RunId: ctl.runId, + RunId: ctl.getRunId(), } if err = msg.WriteMsg(workConn, m); err != nil { ctl.Warn("work connection write to server error: %v", err) @@ -176,7 +182,8 @@ func (ctl *Control) NewWorkConn() { workConn.AddLogPrefix(startMsg.ProxyName) // dispatch this work connection to related proxy - if pxy, ok := ctl.proxies[startMsg.ProxyName]; ok { + pxy, ok := ctl.getProxy(startMsg.ProxyName) + if ok { workConn.Debug("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String()) go pxy.InWorkConn(workConn) } else { @@ -184,6 +191,20 @@ func (ctl *Control) NewWorkConn() { } } +func (ctl *Control) Close() error { + ctl.mu.Lock() + ctl.exit = true + err := errors.PanicToError(func() { + for name, _ := range ctl.proxies { + ctl.sendCh <- &msg.CloseProxy{ + ProxyName: name, + } + } + }) + ctl.mu.Unlock() + return err +} + func (ctl *Control) init() { ctl.sendCh = make(chan msg.Message, 10) ctl.readCh = make(chan msg.Message, 10) @@ -199,7 +220,7 @@ func (ctl *Control) login() (err error) { ctl.session.Close() } - conn, err := net.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol, + conn, err := frpNet.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol, fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) if err != nil { return err @@ -221,14 +242,14 @@ func (ctl *Control) login() (err error) { session.Close() return errRet } - conn = net.WrapConn(stream) + conn = frpNet.WrapConn(stream) ctl.session = session } now := time.Now().Unix() ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now) ctl.loginMsg.Timestamp = now - ctl.loginMsg.RunId = ctl.runId + ctl.loginMsg.RunId = ctl.getRunId() if err = msg.WriteMsg(conn, ctl.loginMsg); err != nil { return err @@ -249,7 +270,7 @@ func (ctl *Control) login() (err error) { ctl.conn = conn // update runId got from server - ctl.runId = loginRespMsg.RunId + ctl.setRunId(loginRespMsg.RunId) ctl.ClearLogPrefix() ctl.AddLogPrefix(loginRespMsg.RunId) ctl.Info("login to server success, get run id [%s]", loginRespMsg.RunId) @@ -261,6 +282,27 @@ func (ctl *Control) login() (err error) { return nil } +func (ctl *Control) connectServer() (conn frpNet.Conn, err error) { + if config.ClientCommonCfg.TcpMux { + stream, errRet := ctl.session.OpenStream() + if errRet != nil { + err = errRet + ctl.Warn("start new connection to server error: %v", err) + return + } + conn = frpNet.WrapConn(stream) + + } else { + conn, err = frpNet.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol, + fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) + if err != nil { + ctl.Warn("start new connection to server error: %v", err) + return + } + } + return +} + func (ctl *Control) reader() { defer func() { if err := recover(); err != nil { @@ -305,6 +347,7 @@ func (ctl *Control) writer() { } } +// manager handles all channel events and do corresponding process func (ctl *Control) manager() { defer func() { if err := recover(); err != nil { @@ -345,13 +388,14 @@ func (ctl *Control) manager() { ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error) continue } - cfg, ok := ctl.pxyCfgs[m.ProxyName] + cfg, ok := ctl.getProxyConf(m.ProxyName) if !ok { // it will never go to this branch now ctl.Warn("[%s] no proxy conf found", m.ProxyName) continue } - oldPxy, ok := ctl.proxies[m.ProxyName] + + oldPxy, ok := ctl.getProxy(m.ProxyName) if ok { oldPxy.Close() } @@ -363,7 +407,7 @@ func (ctl *Control) manager() { } continue } - ctl.proxies[m.ProxyName] = pxy + ctl.addProxy(m.ProxyName, pxy) ctl.Info("[%s] start proxy success", m.ProxyName) case *msg.Pong: ctl.lastPong = time.Now() @@ -373,26 +417,43 @@ func (ctl *Control) manager() { } } -// control keep watching closedCh, start a new connection if previous control connection is closed +// controler keep watching closedCh, start a new connection if previous control connection is closed. +// If controler is notified by closedCh, reader and writer and manager will exit, then recall these functions. func (ctl *Control) controler() { var err error maxDelayTime := 30 * time.Second delayTime := time.Second - checkInterval := 30 * time.Second + checkInterval := 10 * time.Second checkProxyTicker := time.NewTicker(checkInterval) for { select { case <-checkProxyTicker.C: - // Every 30 seconds, check which proxy registered failed and reregister it to server. + // Every 10 seconds, check which proxy registered failed and reregister it to server. + ctl.mu.RLock() for _, cfg := range ctl.pxyCfgs { if _, exist := ctl.proxies[cfg.GetName()]; !exist { - ctl.Info("try to reregister proxy [%s]", cfg.GetName()) + ctl.Info("try to register proxy [%s]", cfg.GetName()) var newProxyMsg msg.NewProxy cfg.UnMarshalToMsg(&newProxyMsg) ctl.sendCh <- &newProxyMsg } } + + for _, cfg := range ctl.vistorCfgs { + if _, exist := ctl.vistors[cfg.GetName()]; !exist { + ctl.Info("try to start vistor [%s]", cfg.GetName()) + vistor := NewVistor(ctl, cfg) + err = vistor.Run() + if err != nil { + vistor.Warn("start error: %v", err) + continue + } + ctl.vistors[cfg.GetName()] = vistor + vistor.Info("start vistor success") + } + } + ctl.mu.RUnlock() case _, ok := <-ctl.closedCh: // we won't get any variable from this channel if !ok { @@ -403,6 +464,14 @@ func (ctl *Control) controler() { for _, pxy := range ctl.proxies { pxy.Close() } + // if ctl.exit is true, just exit + ctl.mu.RLock() + exit := ctl.exit + ctl.mu.RUnlock() + if exit { + return + } + time.Sleep(time.Second) // loop util reconnect to server success @@ -432,11 +501,13 @@ func (ctl *Control) controler() { go ctl.reader() // send NewProxy message for all configured proxies + ctl.mu.RLock() for _, cfg := range ctl.pxyCfgs { var newProxyMsg msg.NewProxy cfg.UnMarshalToMsg(&newProxyMsg) ctl.sendCh <- &newProxyMsg } + ctl.mu.RUnlock() checkProxyTicker.Stop() checkProxyTicker = time.NewTicker(checkInterval) @@ -444,3 +515,107 @@ func (ctl *Control) controler() { } } } + +func (ctl *Control) setRunId(runId string) { + ctl.mu.Lock() + defer ctl.mu.Unlock() + ctl.runId = runId +} + +func (ctl *Control) getRunId() string { + ctl.mu.RLock() + defer ctl.mu.RUnlock() + return ctl.runId +} + +func (ctl *Control) getProxy(name string) (pxy Proxy, ok bool) { + ctl.mu.RLock() + defer ctl.mu.RUnlock() + pxy, ok = ctl.proxies[name] + return +} + +func (ctl *Control) addProxy(name string, pxy Proxy) { + ctl.mu.Lock() + defer ctl.mu.Unlock() + ctl.proxies[name] = pxy +} + +func (ctl *Control) getProxyConf(name string) (conf config.ProxyConf, ok bool) { + ctl.mu.RLock() + defer ctl.mu.RUnlock() + conf, ok = ctl.pxyCfgs[name] + return +} + +func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, vistorCfgs map[string]config.ProxyConf) { + ctl.mu.Lock() + defer ctl.mu.Unlock() + + removedPxyNames := make([]string, 0) + for name, oldCfg := range ctl.pxyCfgs { + del := false + cfg, ok := pxyCfgs[name] + if !ok { + del = true + } else { + if !oldCfg.Compare(cfg) { + del = true + } + } + + if del { + removedPxyNames = append(removedPxyNames, name) + delete(ctl.pxyCfgs, name) + if pxy, ok := ctl.proxies[name]; ok { + pxy.Close() + } + delete(ctl.proxies, name) + ctl.sendCh <- &msg.CloseProxy{ + ProxyName: name, + } + } + } + ctl.Info("proxy removed: %v", removedPxyNames) + + addedPxyNames := make([]string, 0) + for name, cfg := range pxyCfgs { + if _, ok := ctl.pxyCfgs[name]; !ok { + ctl.pxyCfgs[name] = cfg + addedPxyNames = append(addedPxyNames, name) + } + } + ctl.Info("proxy added: %v", addedPxyNames) + + removedVistorName := make([]string, 0) + for name, oldVistorCfg := range ctl.vistorCfgs { + del := false + cfg, ok := vistorCfgs[name] + if !ok { + del = true + } else { + if !oldVistorCfg.Compare(cfg) { + del = true + } + } + + if del { + removedVistorName = append(removedVistorName, name) + delete(ctl.vistorCfgs, name) + if vistor, ok := ctl.vistors[name]; ok { + vistor.Close() + } + delete(ctl.vistors, name) + } + } + ctl.Info("vistor removed: %v", removedVistorName) + + addedVistorName := make([]string, 0) + for name, vistorCfg := range vistorCfgs { + if _, ok := ctl.vistorCfgs[name]; !ok { + ctl.vistorCfgs[name] = vistorCfg + addedVistorName = append(addedVistorName, name) + } + } + ctl.Info("vistor added: %v", addedVistorName) +} diff --git a/client/proxy.go b/client/proxy.go index cd7994c6..147a3fbd 100644 --- a/client/proxy.go +++ b/client/proxy.go @@ -31,7 +31,7 @@ import ( frpNet "github.com/fatedier/frp/utils/net" ) -// Proxy defines how to work for different proxy type. +// Proxy defines how to deal with work connections for different proxy type. type Proxy interface { Run() error @@ -67,6 +67,11 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) { BaseProxy: baseProxy, cfg: cfg, } + case *config.StcpProxyConf: + pxy = &StcpProxy{ + BaseProxy: baseProxy, + cfg: cfg, + } } return } @@ -162,6 +167,34 @@ func (pxy *HttpsProxy) InWorkConn(conn frpNet.Conn) { HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, conn) } +// STCP +type StcpProxy struct { + BaseProxy + + cfg *config.StcpProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *StcpProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *StcpProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *StcpProxy) InWorkConn(conn frpNet.Conn) { + HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, conn) +} + // UDP type UdpProxy struct { BaseProxy diff --git a/client/service.go b/client/service.go index 36d0a7e8..241a435c 100644 --- a/client/service.go +++ b/client/service.go @@ -14,7 +14,10 @@ package client -import "github.com/fatedier/frp/models/config" +import ( + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/utils/log" +) type Service struct { // manager control connection with server @@ -23,11 +26,11 @@ type Service struct { closedCh chan int } -func NewService(pxyCfgs map[string]config.ProxyConf) (svr *Service) { +func NewService(pxyCfgs map[string]config.ProxyConf, vistorCfgs map[string]config.ProxyConf) (svr *Service) { svr = &Service{ closedCh: make(chan int), } - ctl := NewControl(svr, pxyCfgs) + ctl := NewControl(svr, pxyCfgs, vistorCfgs) svr.ctl = ctl return } @@ -38,6 +41,18 @@ func (svr *Service) Run() error { return err } + if config.ClientCommonCfg.AdminPort != 0 { + err = svr.RunAdminServer(config.ClientCommonCfg.AdminAddr, config.ClientCommonCfg.AdminPort) + if err != nil { + log.Warn("run admin server error: %v", err) + } + log.Info("admin server listen on %s:%d", config.ClientCommonCfg.AdminAddr, config.ClientCommonCfg.AdminPort) + } + <-svr.closedCh return nil } + +func (svr *Service) Close() error { + return svr.ctl.Close() +} diff --git a/client/vistor.go b/client/vistor.go new file mode 100644 index 00000000..8787ebfe --- /dev/null +++ b/client/vistor.go @@ -0,0 +1,145 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "io" + "sync" + "time" + + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/models/msg" + frpIo "github.com/fatedier/frp/utils/io" + "github.com/fatedier/frp/utils/log" + frpNet "github.com/fatedier/frp/utils/net" + "github.com/fatedier/frp/utils/util" +) + +// Vistor is used for forward traffics from local port tot remote service. +type Vistor interface { + Run() error + Close() + log.Logger +} + +func NewVistor(ctl *Control, pxyConf config.ProxyConf) (vistor Vistor) { + baseVistor := BaseVistor{ + ctl: ctl, + Logger: log.NewPrefixLogger(pxyConf.GetName()), + } + switch cfg := pxyConf.(type) { + case *config.StcpProxyConf: + vistor = &StcpVistor{ + BaseVistor: baseVistor, + cfg: cfg, + } + } + return +} + +type BaseVistor struct { + ctl *Control + l frpNet.Listener + closed bool + mu sync.RWMutex + log.Logger +} + +type StcpVistor struct { + BaseVistor + + cfg *config.StcpProxyConf +} + +func (sv *StcpVistor) Run() (err error) { + sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) + if err != nil { + return + } + + go sv.worker() + return +} + +func (sv *StcpVistor) Close() { + sv.l.Close() +} + +func (sv *StcpVistor) worker() { + for { + conn, err := sv.l.Accept() + if err != nil { + sv.Warn("stcp local listener closed") + return + } + + go sv.handleConn(conn) + } +} + +func (sv *StcpVistor) handleConn(userConn frpNet.Conn) { + defer userConn.Close() + + sv.Debug("get a new stcp user connection") + vistorConn, err := sv.ctl.connectServer() + if err != nil { + return + } + defer vistorConn.Close() + + now := time.Now().Unix() + newVistorConnMsg := &msg.NewVistorConn{ + ProxyName: sv.cfg.ServerName, + SignKey: util.GetAuthKey(sv.cfg.Sk, now), + Timestamp: now, + UseEncryption: sv.cfg.UseEncryption, + UseCompression: sv.cfg.UseCompression, + } + err = msg.WriteMsg(vistorConn, newVistorConnMsg) + if err != nil { + sv.Warn("send newVistorConnMsg to server error: %v", err) + return + } + + var newVistorConnRespMsg msg.NewVistorConnResp + vistorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + err = msg.ReadMsgInto(vistorConn, &newVistorConnRespMsg) + if err != nil { + sv.Warn("get newVistorConnRespMsg error: %v", err) + return + } + vistorConn.SetReadDeadline(time.Time{}) + + if newVistorConnRespMsg.Error != "" { + sv.Warn("start new vistor connection error: %s", newVistorConnRespMsg.Error) + return + } + + var remote io.ReadWriteCloser + remote = vistorConn + if sv.cfg.UseEncryption { + remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) + if err != nil { + sv.Error("create encryption stream error: %v", err) + return + } + } + + if sv.cfg.UseCompression { + remote = frpIo.WithCompression(remote) + } + + frpIo.Join(userConn, remote) +} diff --git a/cmd/frpc/main.go b/cmd/frpc/main.go index dcbc3a53..5aa6bec8 100644 --- a/cmd/frpc/main.go +++ b/cmd/frpc/main.go @@ -15,10 +15,17 @@ package main import ( + "encoding/base64" + "encoding/json" "fmt" + "io/ioutil" + "net/http" "os" + "os/signal" "strconv" "strings" + "syscall" + "time" docopt "github.com/docopt/docopt-go" ini "github.com/vaughan0/go-ini" @@ -37,6 +44,7 @@ var usage string = `frpc is the client of frp Usage: frpc [-c config_file] [-L log_file] [--log-level=] [--server-addr=] + frpc [-c config_file] --reload frpc -h | --help frpc -v | --version @@ -45,13 +53,14 @@ Options: -L log_file set output log file, including console --log-level= set log level: debug, info, warn, error --server-addr= addr which frps is listening for, example: 0.0.0.0:7000 + --reload reload configure file without program exit -h --help show this screen -v --version show version ` func main() { var err error - confFile := "./frpc.ini" + confFile := "./frps.ini" // the configures parsed from file will be replaced by those from command line if exist args, err := docopt.Parse(usage, nil, true, version.Full(), false) @@ -70,6 +79,47 @@ func main() { fmt.Println(err) os.Exit(1) } + config.ClientCommonCfg.ConfigFile = confFile + + // check if reload command + if args["--reload"] != nil { + if args["--reload"].(bool) { + req, err := http.NewRequest("GET", "http://"+ + config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/reload", nil) + if err != nil { + fmt.Printf("frps reload error: %v\n", err) + os.Exit(1) + } + + authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+ + config.ClientCommonCfg.AdminPwd)) + + req.Header.Add("Authorization", authStr) + resp, err := http.DefaultClient.Do(req) + if err != nil { + fmt.Printf("frpc reload error: %v\n", err) + os.Exit(1) + } else { + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("frpc reload error: %v\n", err) + os.Exit(1) + } + res := &client.GeneralResponse{} + err = json.Unmarshal(body, &res) + if err != nil { + fmt.Printf("http response error: %s\n", strings.TrimSpace(string(body))) + os.Exit(1) + } else if res.Code != 0 { + fmt.Printf("reload error: %s\n", res.Msg) + os.Exit(1) + } + fmt.Printf("reload success\n") + os.Exit(0) + } + } + } if args["-L"] != nil { if args["-L"].(string) == "console" { @@ -106,7 +156,7 @@ func main() { } } - pxyCfgs, err := config.LoadProxyConfFromFile(config.ClientCommonCfg.User, conf, config.ClientCommonCfg.Start) + pxyCfgs, vistorCfgs, err := config.LoadProxyConfFromFile(config.ClientCommonCfg.User, conf, config.ClientCommonCfg.Start) if err != nil { fmt.Println(err) os.Exit(1) @@ -115,10 +165,25 @@ func main() { log.InitLog(config.ClientCommonCfg.LogWay, config.ClientCommonCfg.LogFile, config.ClientCommonCfg.LogLevel, config.ClientCommonCfg.LogMaxDays) - svr := client.NewService(pxyCfgs) + svr := client.NewService(pxyCfgs, vistorCfgs) + + // Capture the exit signal if we use kcp. + if config.ClientCommonCfg.Protocol == "kcp" { + go HandleSignal(svr) + } + err = svr.Run() if err != nil { fmt.Println(err) os.Exit(1) } } + +func HandleSignal(svr *client.Service) { + ch := make(chan os.Signal) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + <-ch + svr.Close() + time.Sleep(250 * time.Millisecond) + os.Exit(0) +} diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index a653fc24..5ae97447 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -20,6 +20,12 @@ log_max_days = 3 # for authentication privilege_token = 12345678 +# set admin address for control frpc's action by http api such as reload +admin_addr = 127.0.0.1 +admin_port = 7400 +admin_user = admin +admin_pwd = admin + # connections will be established in advance, default value is zero pool_count = 5 @@ -110,3 +116,28 @@ remote_port = 6004 plugin = http_proxy plugin_http_user = abc plugin_http_passwd = abc + +[secret_tcp] +# If the type is secret tcp, remote_port is useless +# Who want to connect local port should deploy another frpc with stcp proxy and role is vistor +type = stcp +# sk used for authentication for vistors +sk = abcdefg +local_ip = 127.0.0.1 +local_port = 22 +use_encryption = false +use_compression = false + +# user of frpc should be same in both stcp server and stcp vistor +[secret_tcp_vistor] +# frpc role vistor -> frps -> frpc role server +role = vistor +type = stcp +# the server name you want to vistor +server_name = secret_tcp +sk = abcdefg +# connect this address to vistor stcp server +bind_addr = 127.0.0.1 +bind_port = 9000 +use_encryption = false +use_compression = false diff --git a/conf/frps_full.ini b/conf/frps_full.ini index fa728383..3b4740ed 100644 --- a/conf/frps_full.ini +++ b/conf/frps_full.ini @@ -9,11 +9,14 @@ bind_port = 7000 # if not set, kcp is disabled in frps kcp_bind_port = 7000 +# specify which address proxy will listen for, default value is same with bind_addr +# proxy_bind_addr = 127.0.0.1 + # if you want to support virtual host, you must set the http port for listening (optional) vhost_http_port = 80 vhost_https_port = 443 -# if you want to configure or reload frps by dashboard, dashboard_port must be set +# set dashboard_port to view dashboard of frps dashboard_port = 7500 # dashboard user and pwd for basic auth protect, if not set, both default value is admin diff --git a/models/config/client_common.go b/models/config/client_common.go index 8ec6cd89..749b6b13 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -36,6 +36,10 @@ type ClientCommonConf struct { LogLevel string LogMaxDays int64 PrivilegeToken string + AdminAddr string + AdminPort int64 + AdminUser string + AdminPwd string PoolCount int TcpMux bool User string @@ -57,6 +61,10 @@ func GetDeaultClientCommonConf() *ClientCommonConf { LogLevel: "info", LogMaxDays: 3, PrivilegeToken: "", + AdminAddr: "127.0.0.1", + AdminPort: 0, + AdminUser: "", + AdminPwd: "", PoolCount: 1, TcpMux: true, User: "", @@ -111,7 +119,9 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { tmpStr, ok = conf.Get("common", "log_max_days") if ok { - cfg.LogMaxDays, _ = strconv.ParseInt(tmpStr, 10, 64) + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.LogMaxDays = v + } } tmpStr, ok = conf.Get("common", "privilege_token") @@ -119,6 +129,28 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { cfg.PrivilegeToken = tmpStr } + tmpStr, ok = conf.Get("common", "admin_addr") + if ok { + cfg.AdminAddr = tmpStr + } + + tmpStr, ok = conf.Get("common", "admin_port") + if ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.AdminPort = v + } + } + + tmpStr, ok = conf.Get("common", "admin_user") + if ok { + cfg.AdminUser = tmpStr + } + + tmpStr, ok = conf.Get("common", "admin_pwd") + if ok { + cfg.AdminPwd = tmpStr + } + tmpStr, ok = conf.Get("common", "pool_count") if ok { v, err = strconv.ParseInt(tmpStr, 10, 64) @@ -145,7 +177,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { if ok { proxyNames := strings.Split(tmpStr, ",") for _, name := range proxyNames { - cfg.Start[name] = struct{}{} + cfg.Start[strings.TrimSpace(name)] = struct{}{} } } diff --git a/models/config/proxy.go b/models/config/proxy.go index 27fd3f64..b42f416c 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -35,6 +35,7 @@ func init() { proxyConfTypeMap[consts.UdpProxy] = reflect.TypeOf(UdpProxyConf{}) proxyConfTypeMap[consts.HttpProxy] = reflect.TypeOf(HttpProxyConf{}) proxyConfTypeMap[consts.HttpsProxy] = reflect.TypeOf(HttpsProxyConf{}) + proxyConfTypeMap[consts.StcpProxy] = reflect.TypeOf(StcpProxyConf{}) } // NewConfByType creates a empty ProxyConf object by proxyType. @@ -55,6 +56,7 @@ type ProxyConf interface { LoadFromFile(name string, conf ini.Section) error UnMarshalToMsg(pMsg *msg.NewProxy) Check() error + Compare(conf ProxyConf) bool } func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) { @@ -104,6 +106,16 @@ func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { return cfg } +func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool { + if cfg.ProxyName != cmp.ProxyName || + cfg.ProxyType != cmp.ProxyType || + cfg.UseEncryption != cmp.UseEncryption || + cfg.UseCompression != cmp.UseCompression { + return false + } + return true +} + func (cfg *BaseProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.ProxyName = pMsg.ProxyName cfg.ProxyType = pMsg.ProxyType @@ -148,8 +160,16 @@ type BindInfoConf struct { RemotePort int64 `json:"remote_port"` } +func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool { + if cfg.BindAddr != cmp.BindAddr || + cfg.RemotePort != cmp.RemotePort { + return false + } + return true +} + func (cfg *BindInfoConf) LoadFromMsg(pMsg *msg.NewProxy) { - cfg.BindAddr = ServerCommonCfg.BindAddr + cfg.BindAddr = ServerCommonCfg.ProxyBindAddr cfg.RemotePort = pMsg.RemotePort } @@ -187,6 +207,14 @@ type DomainConf struct { SubDomain string `json:"sub_domain"` } +func (cfg *DomainConf) compare(cmp *DomainConf) bool { + if strings.Join(cfg.CustomDomains, " ") != strings.Join(cmp.CustomDomains, " ") || + cfg.SubDomain != cmp.SubDomain { + return false + } + return true +} + func (cfg *DomainConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.CustomDomains = pMsg.CustomDomains cfg.SubDomain = pMsg.SubDomain @@ -245,6 +273,14 @@ type LocalSvrConf struct { LocalPort int `json:"-"` } +func (cfg *LocalSvrConf) compare(cmp *LocalSvrConf) bool { + if cfg.LocalIp != cmp.LocalIp || + cfg.LocalPort != cmp.LocalPort { + return false + } + return true +} + func (cfg *LocalSvrConf) LoadFromFile(name string, section ini.Section) (err error) { if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" { cfg.LocalIp = "127.0.0.1" @@ -265,6 +301,20 @@ type PluginConf struct { PluginParams map[string]string `json:"-"` } +func (cfg *PluginConf) compare(cmp *PluginConf) bool { + if cfg.Plugin != cmp.Plugin || + len(cfg.PluginParams) != len(cmp.PluginParams) { + return false + } + for k, v := range cfg.PluginParams { + value, ok := cmp.PluginParams[k] + if !ok || v != value { + return false + } + } + return true +} + func (cfg *PluginConf) LoadFromFile(name string, section ini.Section) (err error) { cfg.Plugin = section["plugin"] cfg.PluginParams = make(map[string]string) @@ -290,6 +340,21 @@ type TcpProxyConf struct { PluginConf } +func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*TcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) { + return false + } + return true +} + func (cfg *TcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BindInfoConf.LoadFromMsg(pMsg) @@ -329,6 +394,20 @@ type UdpProxyConf struct { LocalSvrConf } +func (cfg *UdpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*UdpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) { + return false + } + return true +} + func (cfg *UdpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BindInfoConf.LoadFromMsg(pMsg) @@ -371,6 +450,25 @@ type HttpProxyConf struct { HttpPwd string `json:"-"` } +func (cfg *HttpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*HttpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) || + strings.Join(cfg.Locations, " ") != strings.Join(cmpConf.Locations, " ") || + cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite || + cfg.HttpUser != cmpConf.HttpUser || + cfg.HttpPwd != cmpConf.HttpPwd { + return false + } + return true +} + func (cfg *HttpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.DomainConf.LoadFromMsg(pMsg) @@ -388,8 +486,10 @@ func (cfg *HttpProxyConf) LoadFromFile(name string, section ini.Section) (err er if err = cfg.DomainConf.LoadFromFile(name, section); err != nil { return } - if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { - return + if err = cfg.PluginConf.LoadFromFile(name, section); err != nil { + if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { + return + } } var ( @@ -435,6 +535,21 @@ type HttpsProxyConf struct { PluginConf } +func (cfg *HttpsProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*HttpsProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) { + return false + } + return true +} + func (cfg *HttpsProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.DomainConf.LoadFromMsg(pMsg) @@ -447,8 +562,10 @@ func (cfg *HttpsProxyConf) LoadFromFile(name string, section ini.Section) (err e if err = cfg.DomainConf.LoadFromFile(name, section); err != nil { return } - if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { - return + if err = cfg.PluginConf.LoadFromFile(name, section); err != nil { + if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { + return + } } return } @@ -466,9 +583,100 @@ func (cfg *HttpsProxyConf) Check() (err error) { return } +// STCP +type StcpProxyConf struct { + BaseProxyConf + + Role string `json:"role"` + Sk string `json:"sk"` + + // used in role server + LocalSvrConf + PluginConf + + // used in role vistor + ServerName string `json:"server_name"` + BindAddr string `json:"bind_addr"` + BindPort int `json:"bind_port"` +} + +func (cfg *StcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*StcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) || + cfg.Role != cmpConf.Role || + cfg.Sk != cmpConf.Sk || + cfg.ServerName != cmpConf.ServerName || + cfg.BindAddr != cmpConf.BindAddr || + cfg.BindPort != cmpConf.BindPort { + return false + } + return true +} + +// Only for role server. +func (cfg *StcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.LoadFromMsg(pMsg) + cfg.Sk = pMsg.Sk +} + +func (cfg *StcpProxyConf) LoadFromFile(name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.LoadFromFile(name, section); err != nil { + return + } + + tmpStr := section["role"] + if tmpStr == "server" || tmpStr == "vistor" { + cfg.Role = tmpStr + } else { + cfg.Role = "server" + } + + cfg.Sk = section["sk"] + + if tmpStr == "vistor" { + prefix := section["prefix"] + cfg.ServerName = prefix + section["server_name"] + if cfg.BindAddr = section["bind_addr"]; cfg.BindAddr == "" { + cfg.BindAddr = "127.0.0.1" + } + + if tmpStr, ok := section["bind_port"]; ok { + if cfg.BindPort, err = strconv.Atoi(tmpStr); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] bind_port error", name) + } + } else { + return fmt.Errorf("Parse conf error: proxy [%s] bind_port not found", name) + } + } else { + if err = cfg.PluginConf.LoadFromFile(name, section); err != nil { + if err = cfg.LocalSvrConf.LoadFromFile(name, section); err != nil { + return + } + } + } + return +} + +func (cfg *StcpProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnMarshalToMsg(pMsg) + pMsg.Sk = cfg.Sk +} + +func (cfg *StcpProxyConf) Check() (err error) { + return +} + // if len(startProxy) is 0, start all // otherwise just start proxies in startProxy map -func LoadProxyConfFromFile(prefix string, conf ini.File, startProxy map[string]struct{}) (proxyConfs map[string]ProxyConf, err error) { +func LoadProxyConfFromFile(prefix string, conf ini.File, startProxy map[string]struct{}) ( + proxyConfs map[string]ProxyConf, vistorConfs map[string]ProxyConf, err error) { + if prefix != "" { prefix += "." } @@ -478,14 +686,23 @@ func LoadProxyConfFromFile(prefix string, conf ini.File, startProxy map[string]s startAll = false } proxyConfs = make(map[string]ProxyConf) + vistorConfs = make(map[string]ProxyConf) for name, section := range conf { _, shouldStart := startProxy[name] if name != "common" && (startAll || shouldStart) { + // some proxy or visotr configure may be used this prefix + section["prefix"] = prefix cfg, err := NewProxyConfFromFile(name, section) if err != nil { - return proxyConfs, err + return proxyConfs, vistorConfs, err + } + + role := section["role"] + if role == "vistor" { + vistorConfs[prefix+name] = cfg + } else { + proxyConfs[prefix+name] = cfg } - proxyConfs[prefix+name] = cfg } } return diff --git a/models/config/server_common.go b/models/config/server_common.go index 1bf1f256..1795a2a5 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -27,10 +27,11 @@ var ServerCommonCfg *ServerCommonConf // common config type ServerCommonConf struct { - ConfigFile string - BindAddr string - BindPort int64 - KcpBindPort int64 + ConfigFile string + BindAddr string + BindPort int64 + KcpBindPort int64 + ProxyBindAddr string // If VhostHttpPort equals 0, don't listen a public port for http protocol. VhostHttpPort int64 @@ -66,6 +67,7 @@ func GetDefaultServerCommonConf() *ServerCommonConf { BindAddr: "0.0.0.0", BindPort: 7000, KcpBindPort: 0, + ProxyBindAddr: "0.0.0.0", VhostHttpPort: 0, VhostHttpsPort: 0, DashboardPort: 0, @@ -117,6 +119,13 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { } } + tmpStr, ok = conf.Get("common", "proxy_bind_addr") + if ok { + cfg.ProxyBindAddr = tmpStr + } else { + cfg.ProxyBindAddr = cfg.BindAddr + } + tmpStr, ok = conf.Get("common", "vhost_http_port") if ok { cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64) diff --git a/models/consts/consts.go b/models/consts/consts.go index 170bd240..5a4bc264 100644 --- a/models/consts/consts.go +++ b/models/consts/consts.go @@ -27,4 +27,5 @@ var ( UdpProxy string = "udp" HttpProxy string = "http" HttpsProxy string = "https" + StcpProxy string = "stcp" ) diff --git a/models/msg/msg.go b/models/msg/msg.go index d961befa..59736b6d 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -20,17 +20,19 @@ import ( ) const ( - TypeLogin = 'o' - TypeLoginResp = '1' - TypeNewProxy = 'p' - TypeNewProxyResp = '2' - TypeCloseProxy = 'c' - TypeNewWorkConn = 'w' - TypeReqWorkConn = 'r' - TypeStartWorkConn = 's' - TypePing = 'h' - TypePong = '4' - TypeUdpPacket = 'u' + TypeLogin = 'o' + TypeLoginResp = '1' + TypeNewProxy = 'p' + TypeNewProxyResp = '2' + TypeCloseProxy = 'c' + TypeNewWorkConn = 'w' + TypeReqWorkConn = 'r' + TypeStartWorkConn = 's' + TypeNewVistorConn = 'v' + TypeNewVistorConnResp = '3' + TypePing = 'h' + TypePong = '4' + TypeUdpPacket = 'u' ) var ( @@ -50,6 +52,8 @@ func init() { TypeMap[TypeNewWorkConn] = reflect.TypeOf(NewWorkConn{}) TypeMap[TypeReqWorkConn] = reflect.TypeOf(ReqWorkConn{}) TypeMap[TypeStartWorkConn] = reflect.TypeOf(StartWorkConn{}) + TypeMap[TypeNewVistorConn] = reflect.TypeOf(NewVistorConn{}) + TypeMap[TypeNewVistorConnResp] = reflect.TypeOf(NewVistorConnResp{}) TypeMap[TypePing] = reflect.TypeOf(Ping{}) TypeMap[TypePong] = reflect.TypeOf(Pong{}) TypeMap[TypeUdpPacket] = reflect.TypeOf(UdpPacket{}) @@ -100,6 +104,9 @@ type NewProxy struct { HostHeaderRewrite string `json:"host_header_rewrite"` HttpUser string `json:"http_user"` HttpPwd string `json:"http_pwd"` + + // stcp + Sk string `json:"sk"` } type NewProxyResp struct { @@ -122,6 +129,19 @@ type StartWorkConn struct { ProxyName string `json:"proxy_name"` } +type NewVistorConn struct { + ProxyName string `json:"proxy_name"` + SignKey string `json:"sign_key"` + Timestamp int64 `json:"timestamp"` + UseEncryption bool `json:"use_encryption"` + UseCompression bool `json:"use_compression"` +} + +type NewVistorConnResp struct { + ProxyName string `json:"proxy_name"` + Error string `json:"error"` +} + type Ping struct { } diff --git a/models/plugin/socks5.go b/models/plugin/socks5.go new file mode 100644 index 00000000..d3b82e12 --- /dev/null +++ b/models/plugin/socks5.go @@ -0,0 +1,65 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "io" + "io/ioutil" + "log" + + frpNet "github.com/fatedier/frp/utils/net" + + gosocks5 "github.com/armon/go-socks5" +) + +const PluginSocks5 = "socks5" + +func init() { + Register(PluginSocks5, NewSocks5Plugin) +} + +type Socks5Plugin struct { + Server *gosocks5.Server +} + +func NewSocks5Plugin(params map[string]string) (p Plugin, err error) { + sp := &Socks5Plugin{} + sp.Server, err = gosocks5.New(&gosocks5.Config{ + Logger: log.New(ioutil.Discard, "", log.LstdFlags), + }) + p = sp + return +} + +func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser) { + defer conn.Close() + + var wrapConn frpNet.Conn + if realConn, ok := conn.(frpNet.Conn); ok { + wrapConn = realConn + } else { + wrapConn = frpNet.WrapReadWriteCloserToConn(conn) + } + + sp.Server.ServeConn(wrapConn) +} + +func (sp *Socks5Plugin) Name() string { + return PluginSocks5 +} + +func (sp *Socks5Plugin) Close() error { + return nil +} diff --git a/server/control.go b/server/control.go index d6b2c2c6..5a84394a 100644 --- a/server/control.go +++ b/server/control.go @@ -378,6 +378,7 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) { pxy.Close() ctl.svr.DelProxy(pxy.GetName()) + delete(ctl.proxies, closeMsg.ProxyName) StatsCloseProxy(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType) return } diff --git a/server/dashboard.go b/server/dashboard.go index 84eac81d..01f71591 100644 --- a/server/dashboard.go +++ b/server/dashboard.go @@ -15,16 +15,14 @@ package server import ( - "compress/gzip" "fmt" - "io" "net" "net/http" - "strings" "time" "github.com/fatedier/frp/assets" "github.com/fatedier/frp/models/config" + frpNet "github.com/fatedier/frp/utils/net" "github.com/julienschmidt/httprouter" ) @@ -38,20 +36,24 @@ func RunDashboardServer(addr string, port int64) (err error) { // url router router := httprouter.New() + user, passwd := config.ServerCommonCfg.DashboardUser, config.ServerCommonCfg.DashboardPwd + // api, see dashboard_api.go - router.GET("/api/serverinfo", httprouterBasicAuth(apiServerInfo)) - router.GET("/api/proxy/tcp", httprouterBasicAuth(apiProxyTcp)) - router.GET("/api/proxy/udp", httprouterBasicAuth(apiProxyUdp)) - router.GET("/api/proxy/http", httprouterBasicAuth(apiProxyHttp)) - router.GET("/api/proxy/https", httprouterBasicAuth(apiProxyHttps)) - router.GET("/api/proxy/traffic/:name", httprouterBasicAuth(apiProxyTraffic)) + router.GET("/api/serverinfo", frpNet.HttprouterBasicAuth(apiServerInfo, user, passwd)) + router.GET("/api/proxy/tcp", frpNet.HttprouterBasicAuth(apiProxyTcp, user, passwd)) + router.GET("/api/proxy/udp", frpNet.HttprouterBasicAuth(apiProxyUdp, user, passwd)) + router.GET("/api/proxy/http", frpNet.HttprouterBasicAuth(apiProxyHttp, user, passwd)) + router.GET("/api/proxy/https", frpNet.HttprouterBasicAuth(apiProxyHttps, user, passwd)) + router.GET("/api/proxy/traffic/:name", frpNet.HttprouterBasicAuth(apiProxyTraffic, user, passwd)) // view router.Handler("GET", "/favicon.ico", http.FileServer(assets.FileSystem)) - router.Handler("GET", "/static/*filepath", MakeGzipHandler(basicAuthWraper(http.StripPrefix("/static/", http.FileServer(assets.FileSystem))))) - router.HandlerFunc("GET", "/", basicAuth(func(w http.ResponseWriter, r *http.Request) { + router.Handler("GET", "/static/*filepath", frpNet.MakeHttpGzipHandler( + frpNet.NewHttpBasicAuthWraper(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)), user, passwd))) + + router.HandlerFunc("GET", "/", frpNet.HttpBasicAuth(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/static/", http.StatusMovedPermanently) - })) + }, user, passwd)) address := fmt.Sprintf("%s:%d", addr, port) server := &http.Server{ @@ -71,91 +73,3 @@ func RunDashboardServer(addr string, port int64) (err error) { go server.Serve(ln) return } - -func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc { - for _, m := range middleware { - h = m(h) - } - return h -} - -type AuthWraper struct { - h http.Handler - user string - passwd string -} - -func (aw *AuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - user, passwd, hasAuth := r.BasicAuth() - if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) { - aw.h.ServeHTTP(w, r) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } -} - -func basicAuthWraper(h http.Handler) http.Handler { - return &AuthWraper{ - h: h, - user: config.ServerCommonCfg.DashboardUser, - passwd: config.ServerCommonCfg.DashboardPwd, - } -} - -func basicAuth(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, passwd, hasAuth := r.BasicAuth() - if (config.ServerCommonCfg.DashboardUser == "" && config.ServerCommonCfg.DashboardPwd == "") || - (hasAuth && user == config.ServerCommonCfg.DashboardUser && passwd == config.ServerCommonCfg.DashboardPwd) { - h.ServeHTTP(w, r) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } - } -} - -func httprouterBasicAuth(h httprouter.Handle) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - user, passwd, hasAuth := r.BasicAuth() - if (config.ServerCommonCfg.DashboardUser == "" && config.ServerCommonCfg.DashboardPwd == "") || - (hasAuth && user == config.ServerCommonCfg.DashboardUser && passwd == config.ServerCommonCfg.DashboardPwd) { - h(w, r, ps) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } - } -} - -type GzipWraper struct { - h http.Handler -} - -func (gw *GzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - gw.h.ServeHTTP(w, r) - return - } - w.Header().Set("Content-Encoding", "gzip") - gz := gzip.NewWriter(w) - defer gz.Close() - gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} - gw.h.ServeHTTP(gzr, r) -} - -func MakeGzipHandler(h http.Handler) http.Handler { - return &GzipWraper{ - h: h, - } -} - -type gzipResponseWriter struct { - io.Writer - http.ResponseWriter -} - -func (w gzipResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) -} diff --git a/server/manager.go b/server/manager.go index 47456c00..c78037dd 100644 --- a/server/manager.go +++ b/server/manager.go @@ -16,7 +16,12 @@ package server import ( "fmt" + "io" "sync" + + frpIo "github.com/fatedier/frp/utils/io" + frpNet "github.com/fatedier/frp/utils/net" + "github.com/fatedier/frp/utils/util" ) type ControlManager struct { @@ -87,3 +92,72 @@ func (pm *ProxyManager) GetByName(name string) (pxy Proxy, ok bool) { pxy, ok = pm.pxys[name] return } + +// Manager for vistor listeners. +type VistorManager struct { + vistorListeners map[string]*frpNet.CustomListener + skMap map[string]string + + mu sync.RWMutex +} + +func NewVistorManager() *VistorManager { + return &VistorManager{ + vistorListeners: make(map[string]*frpNet.CustomListener), + skMap: make(map[string]string), + } +} + +func (vm *VistorManager) Listen(name string, sk string) (l *frpNet.CustomListener, err error) { + vm.mu.Lock() + defer vm.mu.Unlock() + + if _, ok := vm.vistorListeners[name]; ok { + err = fmt.Errorf("custom listener for [%s] is repeated", name) + return + } + + l = frpNet.NewCustomListener() + vm.vistorListeners[name] = l + vm.skMap[name] = sk + return +} + +func (vm *VistorManager) NewConn(name string, conn frpNet.Conn, timestamp int64, signKey string, + useEncryption bool, useCompression bool) (err error) { + + vm.mu.RLock() + defer vm.mu.RUnlock() + + if l, ok := vm.vistorListeners[name]; ok { + var sk string + if sk = vm.skMap[name]; util.GetAuthKey(sk, timestamp) != signKey { + err = fmt.Errorf("vistor connection of [%s] auth failed", name) + return + } + + var rwc io.ReadWriteCloser = conn + if useEncryption { + if rwc, err = frpIo.WithEncryption(rwc, []byte(sk)); err != nil { + err = fmt.Errorf("create encryption connection failed: %v", err) + return + } + } + if useCompression { + rwc = frpIo.WithCompression(rwc) + } + err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc)) + } else { + err = fmt.Errorf("custom listener for [%s] doesn't exist", name) + return + } + return +} + +func (vm *VistorManager) CloseListener(name string) { + vm.mu.Lock() + defer vm.mu.Unlock() + + delete(vm.vistorListeners, name) + delete(vm.skMap, name) +} diff --git a/server/proxy.go b/server/proxy.go index 6751e245..ecc58751 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -143,6 +143,11 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) { BaseProxy: basePxy, cfg: cfg, } + case *config.StcpProxyConf: + pxy = &StcpProxy{ + BaseProxy: basePxy, + cfg: cfg, + } default: return pxy, fmt.Errorf("proxy type not support") } @@ -156,7 +161,7 @@ type TcpProxy struct { } func (pxy *TcpProxy) Run() error { - listener, err := frpNet.ListenTcp(config.ServerCommonCfg.BindAddr, pxy.cfg.RemotePort) + listener, err := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort) if err != nil { return err } @@ -274,6 +279,33 @@ func (pxy *HttpsProxy) Close() { pxy.BaseProxy.Close() } +type StcpProxy struct { + BaseProxy + cfg *config.StcpProxyConf +} + +func (pxy *StcpProxy) Run() error { + listener, err := pxy.ctl.svr.vistorManager.Listen(pxy.GetName(), pxy.cfg.Sk) + if err != nil { + return err + } + listener.AddLogPrefix(pxy.name) + pxy.listeners = append(pxy.listeners, listener) + pxy.Info("stcp proxy custom listen success") + + pxy.startListenHandler(pxy, HandleUserTcpConnection) + return nil +} + +func (pxy *StcpProxy) GetConf() config.ProxyConf { + return pxy.cfg +} + +func (pxy *StcpProxy) Close() { + pxy.BaseProxy.Close() + pxy.ctl.svr.vistorManager.CloseListener(pxy.GetName()) +} + type UdpProxy struct { BaseProxy cfg *config.UdpProxyConf @@ -298,7 +330,7 @@ type UdpProxy struct { } func (pxy *UdpProxy) Run() (err error) { - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.BindAddr, pxy.cfg.RemotePort)) + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)) if err != nil { return err } diff --git a/server/service.go b/server/service.go index 6583b1c6..e06447dc 100644 --- a/server/service.go +++ b/server/service.go @@ -55,12 +55,16 @@ type Service struct { // Manage all proxies. pxyManager *ProxyManager + + // Manage all vistor listeners. + vistorManager *VistorManager } func NewService() (svr *Service, err error) { svr = &Service{ - ctlManager: NewControlManager(), - pxyManager: NewProxyManager(), + ctlManager: NewControlManager(), + pxyManager: NewProxyManager(), + vistorManager: NewVistorManager(), } // Init assets. @@ -91,7 +95,7 @@ func NewService() (svr *Service, err error) { // Create http vhost muxer. if config.ServerCommonCfg.VhostHttpPort > 0 { var l frpNet.Listener - l, err = frpNet.ListenTcp(config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.VhostHttpPort) + l, err = frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, config.ServerCommonCfg.VhostHttpPort) if err != nil { err = fmt.Errorf("Create vhost http listener error, %v", err) return @@ -101,13 +105,13 @@ func NewService() (svr *Service, err error) { err = fmt.Errorf("Create vhost httpMuxer error, %v", err) return } - log.Info("http service listen on %s:%d", config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.VhostHttpPort) + log.Info("http service listen on %s:%d", config.ServerCommonCfg.ProxyBindAddr, config.ServerCommonCfg.VhostHttpPort) } // Create https vhost muxer. if config.ServerCommonCfg.VhostHttpsPort > 0 { var l frpNet.Listener - l, err = frpNet.ListenTcp(config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.VhostHttpsPort) + l, err = frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, config.ServerCommonCfg.VhostHttpsPort) if err != nil { err = fmt.Errorf("Create vhost https listener error, %v", err) return @@ -117,7 +121,7 @@ func NewService() (svr *Service, err error) { err = fmt.Errorf("Create vhost httpsMuxer error, %v", err) return } - log.Info("https service listen on %s:%d", config.ServerCommonCfg.BindAddr, config.ServerCommonCfg.VhostHttpsPort) + log.Info("https service listen on %s:%d", config.ServerCommonCfg.ProxyBindAddr, config.ServerCommonCfg.VhostHttpsPort) } // Create dashboard web server. @@ -176,6 +180,20 @@ func (svr *Service) HandleListener(l frpNet.Listener) { } case *msg.NewWorkConn: svr.RegisterWorkConn(conn, m) + case *msg.NewVistorConn: + if err = svr.RegisterVistorConn(conn, m); err != nil { + conn.Warn("%v", err) + msg.WriteMsg(conn, &msg.NewVistorConnResp{ + ProxyName: m.ProxyName, + Error: err.Error(), + }) + conn.Close() + } else { + msg.WriteMsg(conn, &msg.NewVistorConnResp{ + ProxyName: m.ProxyName, + Error: "", + }) + } default: log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String()) conn.Close() @@ -262,9 +280,13 @@ func (svr *Service) RegisterWorkConn(workConn frpNet.Conn, newMsg *msg.NewWorkCo return } +func (svr *Service) RegisterVistorConn(vistorConn frpNet.Conn, newMsg *msg.NewVistorConn) error { + return svr.vistorManager.NewConn(newMsg.ProxyName, vistorConn, newMsg.Timestamp, newMsg.SignKey, + newMsg.UseEncryption, newMsg.UseCompression) +} + func (svr *Service) RegisterProxy(name string, pxy Proxy) error { - err := svr.pxyManager.Add(name, pxy) - return err + return svr.pxyManager.Add(name, pxy) } func (svr *Service) DelProxy(name string) { diff --git a/utils/log/log.go b/utils/log/log.go index ec6e0775..a0e42b8f 100644 --- a/utils/log/log.go +++ b/utils/log/log.go @@ -88,6 +88,7 @@ func Trace(format string, v ...interface{}) { // Logger type Logger interface { AddLogPrefix(string) + GetPrefixStr() string GetAllPrefix() []string ClearLogPrefix() Error(string, ...interface{}) @@ -119,6 +120,10 @@ func (pl *PrefixLogger) AddLogPrefix(prefix string) { pl.allPrefix = append(pl.allPrefix, prefix) } +func (pl *PrefixLogger) GetPrefixStr() string { + return pl.prefix +} + func (pl *PrefixLogger) GetAllPrefix() []string { return pl.allPrefix } diff --git a/utils/net/http.go b/utils/net/http.go new file mode 100644 index 00000000..acc0f43e --- /dev/null +++ b/utils/net/http.go @@ -0,0 +1,105 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "compress/gzip" + "io" + "net/http" + "strings" + + "github.com/julienschmidt/httprouter" +) + +type HttpAuthWraper struct { + h http.Handler + user string + passwd string +} + +func NewHttpBasicAuthWraper(h http.Handler, user, passwd string) http.Handler { + return &HttpAuthWraper{ + h: h, + user: user, + passwd: passwd, + } +} + +func (aw *HttpAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + user, passwd, hasAuth := r.BasicAuth() + if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) { + aw.h.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } +} + +func HttpBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (user == "" && passwd == "") || + (hasAuth && reqUser == user && reqPasswd == passwd) { + h.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + } +} + +func HttprouterBasicAuth(h httprouter.Handle, user, passwd string) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (user == "" && passwd == "") || + (hasAuth && reqUser == user && reqPasswd == passwd) { + h(w, r, ps) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + } +} + +type HttpGzipWraper struct { + h http.Handler +} + +func (gw *HttpGzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + gw.h.ServeHTTP(w, r) + return + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + defer gz.Close() + gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} + gw.h.ServeHTTP(gzr, r) +} + +func MakeHttpGzipHandler(h http.Handler) http.Handler { + return &HttpGzipWraper{ + h: h, + } +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} diff --git a/utils/net/listener.go b/utils/net/listener.go index f9345b92..cb3847e5 100644 --- a/utils/net/listener.go +++ b/utils/net/listener.go @@ -15,8 +15,11 @@ package net import ( + "fmt" "net" + "sync" + "github.com/fatedier/frp/utils/errors" "github.com/fatedier/frp/utils/log" ) @@ -44,3 +47,53 @@ func (logL *LogListener) Accept() (Conn, error) { c, err := logL.l.Accept() return WrapConn(c), err } + +// Custom listener +type CustomListener struct { + conns chan Conn + closed bool + mu sync.Mutex + + log.Logger +} + +func NewCustomListener() *CustomListener { + return &CustomListener{ + conns: make(chan Conn, 64), + Logger: log.NewPrefixLogger(""), + } +} + +func (l *CustomListener) Accept() (Conn, error) { + conn, ok := <-l.conns + if !ok { + return nil, fmt.Errorf("listener closed") + } + conn.AddLogPrefix(l.GetPrefixStr()) + return conn, nil +} + +func (l *CustomListener) PutConn(conn Conn) error { + err := errors.PanicToError(func() { + select { + case l.conns <- conn: + default: + conn.Close() + } + }) + return err +} + +func (l *CustomListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if !l.closed { + close(l.conns) + l.closed = true + } + return nil +} + +func (l *CustomListener) Addr() net.Addr { + return (*net.TCPAddr)(nil) +} diff --git a/utils/version/version.go b/utils/version/version.go index 1a1e749f..b631dba7 100644 --- a/utils/version/version.go +++ b/utils/version/version.go @@ -19,7 +19,7 @@ import ( "strings" ) -var version string = "0.12.0" +var version string = "0.13.0" func Full() string { return version diff --git a/utils/vhost/http.go b/utils/vhost/http.go index 2e8208cf..5ecee78c 100644 --- a/utils/vhost/http.go +++ b/utils/vhost/http.go @@ -57,30 +57,35 @@ func GetHttpRequestInfo(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err } func NewHttpMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpMuxer, error) { - mux, err := NewVhostMuxer(listener, GetHttpRequestInfo, HttpAuthFunc, HttpHostNameRewrite, timeout) + mux, err := NewVhostMuxer(listener, GetHttpRequestInfo, HttpAuthFunc, ModifyHttpRequest, timeout) return &HttpMuxer{mux}, err } -func HttpHostNameRewrite(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) { +func ModifyHttpRequest(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) { sc, rd := frpNet.NewShareConn(c) var buff []byte - if buff, err = hostNameRewrite(rd, rewriteHost); err != nil { + remoteIP := strings.Split(c.RemoteAddr().String(), ":")[0] + if buff, err = hostNameRewrite(rd, rewriteHost, remoteIP); err != nil { return sc, err } err = sc.WriteBuff(buff) return sc, err } -func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) { +func hostNameRewrite(request io.Reader, rewriteHost string, remoteIP string) (_ []byte, err error) { buf := pool.GetBuf(1024) defer pool.PutBuf(buf) - request.Read(buf) - retBuffer, err := parseRequest(buf, rewriteHost) + var n int + n, err = request.Read(buf) + if err != nil { + return + } + retBuffer, err := parseRequest(buf[:n], rewriteHost, remoteIP) return retBuffer, err } -func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { +func parseRequest(org []byte, rewriteHost string, remoteIP string) (ret []byte, err error) { tp := bytes.NewBuffer(org) // First line: GET /index.html HTTP/1.0 var b []byte @@ -106,10 +111,19 @@ func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { // GET /index.html HTTP/1.1 // Host: www.google.com if req.URL.Host == "" { - changedBuf, err := changeHostName(tp, rewriteHost) + var changedBuf []byte + if rewriteHost != "" { + changedBuf, err = changeHostName(tp, rewriteHost) + } buf := new(bytes.Buffer) buf.Write(b) - buf.Write(changedBuf) + buf.WriteString(fmt.Sprintf("X-Forwarded-For: %s\r\n", remoteIP)) + buf.WriteString(fmt.Sprintf("X-Real-IP: %s\r\n", remoteIP)) + if len(changedBuf) == 0 { + tp.WriteTo(buf) + } else { + buf.Write(changedBuf) + } return buf.Bytes(), err } @@ -117,18 +131,21 @@ func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { // GET http://www.google.com/index.html HTTP/1.1 // Host: doesntmatter // In this case, any Host line is ignored. - hostPort := strings.Split(req.URL.Host, ":") - if len(hostPort) == 1 { - req.URL.Host = rewriteHost - } else if len(hostPort) == 2 { - req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1]) + if rewriteHost != "" { + hostPort := strings.Split(req.URL.Host, ":") + if len(hostPort) == 1 { + req.URL.Host = rewriteHost + } else if len(hostPort) == 2 { + req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1]) + } } firstLine := req.Method + " " + req.URL.String() + " " + req.Proto buf := new(bytes.Buffer) buf.WriteString(firstLine) + buf.WriteString(fmt.Sprintf("X-Forwarded-For: %s\r\n", remoteIP)) + buf.WriteString(fmt.Sprintf("X-Real-IP: %s\r\n", remoteIP)) tp.WriteTo(buf) return buf.Bytes(), err - } // parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. @@ -162,9 +179,9 @@ func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error var hostHeader string portPos := bytes.IndexByte(kv[j+1:], ':') if portPos == -1 { - hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost) + hostHeader = fmt.Sprintf("Host: %s\r\n", rewriteHost) } else { - hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:]) + hostHeader = fmt.Sprintf("Host: %s:%s\r\n", rewriteHost, kv[j+portPos+2:]) } retBuf.WriteString(hostHeader) peek = peek[i+1:] diff --git a/utils/vhost/resource.go b/utils/vhost/resource.go index 6189b35a..ed8149c6 100644 --- a/utils/vhost/resource.go +++ b/utils/vhost/resource.go @@ -49,9 +49,10 @@ Please try again later.

func notFoundResponse() *http.Response { header := make(http.Header) header.Set("server", "frp/"+version.Full()) + header.Set("Content-Type", "text/html") res := &http.Response{ Status: "Not Found", - StatusCode: 400, + StatusCode: 404, Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 21771d04..bb2b4ad5 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -182,9 +182,10 @@ func (l *Listener) Accept() (frpNet.Conn, error) { return nil, fmt.Errorf("Listener closed") } - // if rewriteFunc is exist and rewriteHost is set + // if rewriteFunc is exist // rewrite http requests with a modified host header - if l.mux.rewriteFunc != nil && l.rewriteHost != "" { + // if l.rewriteHost is empty, nothing to do + if l.mux.rewriteFunc != nil { sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost) if err != nil { l.Warn("host header rewrite failed: %v", err) diff --git a/vendor/github.com/armon/go-socks5/.gitignore b/vendor/github.com/armon/go-socks5/.gitignore new file mode 100644 index 00000000..00268614 --- /dev/null +++ b/vendor/github.com/armon/go-socks5/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/vendor/github.com/armon/go-socks5/.travis.yml b/vendor/github.com/armon/go-socks5/.travis.yml new file mode 100644 index 00000000..8d61700e --- /dev/null +++ b/vendor/github.com/armon/go-socks5/.travis.yml @@ -0,0 +1,4 @@ +language: go +go: + - 1.1 + - tip diff --git a/vendor/github.com/armon/go-socks5/LICENSE b/vendor/github.com/armon/go-socks5/LICENSE new file mode 100644 index 00000000..a5df10e6 --- /dev/null +++ b/vendor/github.com/armon/go-socks5/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014 Armon Dadgar + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/armon/go-socks5/README.md b/vendor/github.com/armon/go-socks5/README.md new file mode 100644 index 00000000..9cd15635 --- /dev/null +++ b/vendor/github.com/armon/go-socks5/README.md @@ -0,0 +1,45 @@ +go-socks5 [![Build Status](https://travis-ci.org/armon/go-socks5.png)](https://travis-ci.org/armon/go-socks5) +========= + +Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS). +SOCKS (Secure Sockets) is used to route traffic between a client and server through +an intermediate proxy layer. This can be used to bypass firewalls or NATs. + +Feature +======= + +The package has the following features: +* "No Auth" mode +* User/Password authentication +* Support for the CONNECT command +* Rules to do granular filtering of commands +* Custom DNS resolution +* Unit tests + +TODO +==== + +The package still needs the following: +* Support for the BIND command +* Support for the ASSOCIATE command + + +Example +======= + +Below is a simple example of usage + +```go +// Create a SOCKS5 server +conf := &socks5.Config{} +server, err := socks5.New(conf) +if err != nil { + panic(err) +} + +// Create SOCKS5 proxy on localhost port 8000 +if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil { + panic(err) +} +``` + diff --git a/vendor/github.com/armon/go-socks5/auth.go b/vendor/github.com/armon/go-socks5/auth.go new file mode 100644 index 00000000..7811e2aa --- /dev/null +++ b/vendor/github.com/armon/go-socks5/auth.go @@ -0,0 +1,151 @@ +package socks5 + +import ( + "fmt" + "io" +) + +const ( + NoAuth = uint8(0) + noAcceptable = uint8(255) + UserPassAuth = uint8(2) + userAuthVersion = uint8(1) + authSuccess = uint8(0) + authFailure = uint8(1) +) + +var ( + UserAuthFailed = fmt.Errorf("User authentication failed") + NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") +) + +// A Request encapsulates authentication state provided +// during negotiation +type AuthContext struct { + // Provided auth method + Method uint8 + // Payload provided during negotiation. + // Keys depend on the used auth method. + // For UserPassauth contains Username + Payload map[string]string +} + +type Authenticator interface { + Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) + GetCode() uint8 +} + +// NoAuthAuthenticator is used to handle the "No Authentication" mode +type NoAuthAuthenticator struct{} + +func (a NoAuthAuthenticator) GetCode() uint8 { + return NoAuth +} + +func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + _, err := writer.Write([]byte{socks5Version, NoAuth}) + return &AuthContext{NoAuth, nil}, err +} + +// UserPassAuthenticator is used to handle username/password based +// authentication +type UserPassAuthenticator struct { + Credentials CredentialStore +} + +func (a UserPassAuthenticator) GetCode() uint8 { + return UserPassAuth +} + +func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + // Tell the client to use user/pass auth + if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { + return nil, err + } + + // Get the version and username length + header := []byte{0, 0} + if _, err := io.ReadAtLeast(reader, header, 2); err != nil { + return nil, err + } + + // Ensure we are compatible + if header[0] != userAuthVersion { + return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) + } + + // Get the user name + userLen := int(header[1]) + user := make([]byte, userLen) + if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { + return nil, err + } + + // Get the password length + if _, err := reader.Read(header[:1]); err != nil { + return nil, err + } + + // Get the password + passLen := int(header[0]) + pass := make([]byte, passLen) + if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { + return nil, err + } + + // Verify the password + if a.Credentials.Valid(string(user), string(pass)) { + if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { + return nil, err + } + } else { + if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { + return nil, err + } + return nil, UserAuthFailed + } + + // Done + return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil +} + +// authenticate is used to handle connection authentication +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { + // Get the methods + methods, err := readMethods(bufConn) + if err != nil { + return nil, fmt.Errorf("Failed to get auth methods: %v", err) + } + + // Select a usable method + for _, method := range methods { + cator, found := s.authMethods[method] + if found { + return cator.Authenticate(bufConn, conn) + } + } + + // No usable method found + return nil, noAcceptableAuth(conn) +} + +// noAcceptableAuth is used to handle when we have no eligible +// authentication mechanism +func noAcceptableAuth(conn io.Writer) error { + conn.Write([]byte{socks5Version, noAcceptable}) + return NoSupportedAuth +} + +// readMethods is used to read the number of methods +// and proceeding auth methods +func readMethods(r io.Reader) ([]byte, error) { + header := []byte{0} + if _, err := r.Read(header); err != nil { + return nil, err + } + + numMethods := int(header[0]) + methods := make([]byte, numMethods) + _, err := io.ReadAtLeast(r, methods, numMethods) + return methods, err +} diff --git a/vendor/github.com/armon/go-socks5/credentials.go b/vendor/github.com/armon/go-socks5/credentials.go new file mode 100644 index 00000000..96664273 --- /dev/null +++ b/vendor/github.com/armon/go-socks5/credentials.go @@ -0,0 +1,17 @@ +package socks5 + +// CredentialStore is used to support user/pass authentication +type CredentialStore interface { + Valid(user, password string) bool +} + +// StaticCredentials enables using a map directly as a credential store +type StaticCredentials map[string]string + +func (s StaticCredentials) Valid(user, password string) bool { + pass, ok := s[user] + if !ok { + return false + } + return password == pass +} diff --git a/vendor/github.com/armon/go-socks5/request.go b/vendor/github.com/armon/go-socks5/request.go new file mode 100644 index 00000000..b615fcbe --- /dev/null +++ b/vendor/github.com/armon/go-socks5/request.go @@ -0,0 +1,364 @@ +package socks5 + +import ( + "fmt" + "io" + "net" + "strconv" + "strings" + + "golang.org/x/net/context" +) + +const ( + ConnectCommand = uint8(1) + BindCommand = uint8(2) + AssociateCommand = uint8(3) + ipv4Address = uint8(1) + fqdnAddress = uint8(3) + ipv6Address = uint8(4) +) + +const ( + successReply uint8 = iota + serverFailure + ruleFailure + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addrTypeNotSupported +) + +var ( + unrecognizedAddrType = fmt.Errorf("Unrecognized address type") +) + +// AddressRewriter is used to rewrite a destination transparently +type AddressRewriter interface { + Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) +} + +// AddrSpec is used to return the target AddrSpec +// which may be specified as IPv4, IPv6, or a FQDN +type AddrSpec struct { + FQDN string + IP net.IP + Port int +} + +func (a *AddrSpec) String() string { + if a.FQDN != "" { + return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) + } + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} + +// Address returns a string suitable to dial; prefer returning IP-based +// address, fallback to FQDN +func (a AddrSpec) Address() string { + if 0 != len(a.IP) { + return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) + } + return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) +} + +// A Request represents request received by a server +type Request struct { + // Protocol version + Version uint8 + // Requested command + Command uint8 + // AuthContext provided during negotiation + AuthContext *AuthContext + // AddrSpec of the the network that sent the request + RemoteAddr *AddrSpec + // AddrSpec of the desired destination + DestAddr *AddrSpec + // AddrSpec of the actual destination (might be affected by rewrite) + realDestAddr *AddrSpec + bufConn io.Reader +} + +type conn interface { + Write([]byte) (int, error) + RemoteAddr() net.Addr +} + +// NewRequest creates a new Request from the tcp connection +func NewRequest(bufConn io.Reader) (*Request, error) { + // Read the version byte + header := []byte{0, 0, 0} + if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { + return nil, fmt.Errorf("Failed to get command version: %v", err) + } + + // Ensure we are compatible + if header[0] != socks5Version { + return nil, fmt.Errorf("Unsupported command version: %v", header[0]) + } + + // Read in the destination address + dest, err := readAddrSpec(bufConn) + if err != nil { + return nil, err + } + + request := &Request{ + Version: socks5Version, + Command: header[1], + DestAddr: dest, + bufConn: bufConn, + } + + return request, nil +} + +// handleRequest is used for request processing after authentication +func (s *Server) handleRequest(req *Request, conn conn) error { + ctx := context.Background() + + // Resolve the address if we have a FQDN + dest := req.DestAddr + if dest.FQDN != "" { + ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) + if err != nil { + if err := sendReply(conn, hostUnreachable, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) + } + ctx = ctx_ + dest.IP = addr + } + + // Apply any address rewrites + req.realDestAddr = req.DestAddr + if s.config.Rewriter != nil { + ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) + } + + // Switch on the command + switch req.Command { + case ConnectCommand: + return s.handleConnect(ctx, conn, req) + case BindCommand: + return s.handleBind(ctx, conn, req) + case AssociateCommand: + return s.handleAssociate(ctx, conn, req) + default: + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Unsupported command: %v", req.Command) + } +} + +// handleConnect is used to handle a connect command +func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // Attempt to connect + dial := s.config.Dial + if dial == nil { + dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { + return net.Dial(net_, addr) + } + } + target, err := dial(ctx, "tcp", req.realDestAddr.Address()) + if err != nil { + msg := err.Error() + resp := hostUnreachable + if strings.Contains(msg, "refused") { + resp = connectionRefused + } else if strings.Contains(msg, "network is unreachable") { + resp = networkUnreachable + } + if err := sendReply(conn, resp, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) + } + defer target.Close() + + // Send success + local := target.LocalAddr().(*net.TCPAddr) + bind := AddrSpec{IP: local.IP, Port: local.Port} + if err := sendReply(conn, successReply, &bind); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + + // Start proxying + errCh := make(chan error, 2) + go proxy(target, req.bufConn, errCh) + go proxy(conn, target, errCh) + + // Wait + for i := 0; i < 2; i++ { + e := <-errCh + if e != nil { + // return from this function closes target (and conn). + return e + } + } + return nil +} + +// handleBind is used to handle a connect command +func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // TODO: Support bind + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return nil +} + +// handleAssociate is used to handle a connect command +func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // TODO: Support associate + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return nil +} + +// readAddrSpec is used to read AddrSpec. +// Expects an address type byte, follwed by the address and port +func readAddrSpec(r io.Reader) (*AddrSpec, error) { + d := &AddrSpec{} + + // Get the address type + addrType := []byte{0} + if _, err := r.Read(addrType); err != nil { + return nil, err + } + + // Handle on a per type basis + switch addrType[0] { + case ipv4Address: + addr := make([]byte, 4) + if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { + return nil, err + } + d.IP = net.IP(addr) + + case ipv6Address: + addr := make([]byte, 16) + if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { + return nil, err + } + d.IP = net.IP(addr) + + case fqdnAddress: + if _, err := r.Read(addrType); err != nil { + return nil, err + } + addrLen := int(addrType[0]) + fqdn := make([]byte, addrLen) + if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { + return nil, err + } + d.FQDN = string(fqdn) + + default: + return nil, unrecognizedAddrType + } + + // Read the port + port := []byte{0, 0} + if _, err := io.ReadAtLeast(r, port, 2); err != nil { + return nil, err + } + d.Port = (int(port[0]) << 8) | int(port[1]) + + return d, nil +} + +// sendReply is used to send a reply message +func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { + // Format the address + var addrType uint8 + var addrBody []byte + var addrPort uint16 + switch { + case addr == nil: + addrType = ipv4Address + addrBody = []byte{0, 0, 0, 0} + addrPort = 0 + + case addr.FQDN != "": + addrType = fqdnAddress + addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) + addrPort = uint16(addr.Port) + + case addr.IP.To4() != nil: + addrType = ipv4Address + addrBody = []byte(addr.IP.To4()) + addrPort = uint16(addr.Port) + + case addr.IP.To16() != nil: + addrType = ipv6Address + addrBody = []byte(addr.IP.To16()) + addrPort = uint16(addr.Port) + + default: + return fmt.Errorf("Failed to format address: %v", addr) + } + + // Format the message + msg := make([]byte, 6+len(addrBody)) + msg[0] = socks5Version + msg[1] = resp + msg[2] = 0 // Reserved + msg[3] = addrType + copy(msg[4:], addrBody) + msg[4+len(addrBody)] = byte(addrPort >> 8) + msg[4+len(addrBody)+1] = byte(addrPort & 0xff) + + // Send the message + _, err := w.Write(msg) + return err +} + +type closeWriter interface { + CloseWrite() error +} + +// proxy is used to suffle data from src to destination, and sends errors +// down a dedicated channel +func proxy(dst io.Writer, src io.Reader, errCh chan error) { + _, err := io.Copy(dst, src) + if tcpConn, ok := dst.(closeWriter); ok { + tcpConn.CloseWrite() + } + errCh <- err +} diff --git a/vendor/github.com/armon/go-socks5/resolver.go b/vendor/github.com/armon/go-socks5/resolver.go new file mode 100644 index 00000000..b75a5c4d --- /dev/null +++ b/vendor/github.com/armon/go-socks5/resolver.go @@ -0,0 +1,23 @@ +package socks5 + +import ( + "net" + + "golang.org/x/net/context" +) + +// NameResolver is used to implement custom name resolution +type NameResolver interface { + Resolve(ctx context.Context, name string) (context.Context, net.IP, error) +} + +// DNSResolver uses the system DNS to resolve host names +type DNSResolver struct{} + +func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { + addr, err := net.ResolveIPAddr("ip", name) + if err != nil { + return ctx, nil, err + } + return ctx, addr.IP, err +} diff --git a/vendor/github.com/armon/go-socks5/ruleset.go b/vendor/github.com/armon/go-socks5/ruleset.go new file mode 100644 index 00000000..ba0e3538 --- /dev/null +++ b/vendor/github.com/armon/go-socks5/ruleset.go @@ -0,0 +1,41 @@ +package socks5 + +import ( + "golang.org/x/net/context" +) + +// RuleSet is used to provide custom rules to allow or prohibit actions +type RuleSet interface { + Allow(ctx context.Context, req *Request) (context.Context, bool) +} + +// PermitAll returns a RuleSet which allows all types of connections +func PermitAll() RuleSet { + return &PermitCommand{true, true, true} +} + +// PermitNone returns a RuleSet which disallows all types of connections +func PermitNone() RuleSet { + return &PermitCommand{false, false, false} +} + +// PermitCommand is an implementation of the RuleSet which +// enables filtering supported commands +type PermitCommand struct { + EnableConnect bool + EnableBind bool + EnableAssociate bool +} + +func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { + switch req.Command { + case ConnectCommand: + return ctx, p.EnableConnect + case BindCommand: + return ctx, p.EnableBind + case AssociateCommand: + return ctx, p.EnableAssociate + } + + return ctx, false +} diff --git a/vendor/github.com/armon/go-socks5/socks5.go b/vendor/github.com/armon/go-socks5/socks5.go new file mode 100644 index 00000000..a17be68f --- /dev/null +++ b/vendor/github.com/armon/go-socks5/socks5.go @@ -0,0 +1,169 @@ +package socks5 + +import ( + "bufio" + "fmt" + "log" + "net" + "os" + + "golang.org/x/net/context" +) + +const ( + socks5Version = uint8(5) +) + +// Config is used to setup and configure a Server +type Config struct { + // AuthMethods can be provided to implement custom authentication + // By default, "auth-less" mode is enabled. + // For password-based auth use UserPassAuthenticator. + AuthMethods []Authenticator + + // If provided, username/password authentication is enabled, + // by appending a UserPassAuthenticator to AuthMethods. If not provided, + // and AUthMethods is nil, then "auth-less" mode is enabled. + Credentials CredentialStore + + // Resolver can be provided to do custom name resolution. + // Defaults to DNSResolver if not provided. + Resolver NameResolver + + // Rules is provided to enable custom logic around permitting + // various commands. If not provided, PermitAll is used. + Rules RuleSet + + // Rewriter can be used to transparently rewrite addresses. + // This is invoked before the RuleSet is invoked. + // Defaults to NoRewrite. + Rewriter AddressRewriter + + // BindIP is used for bind or udp associate + BindIP net.IP + + // Logger can be used to provide a custom log target. + // Defaults to stdout. + Logger *log.Logger + + // Optional function for dialing out + Dial func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// Server is reponsible for accepting connections and handling +// the details of the SOCKS5 protocol +type Server struct { + config *Config + authMethods map[uint8]Authenticator +} + +// New creates a new Server and potentially returns an error +func New(conf *Config) (*Server, error) { + // Ensure we have at least one authentication method enabled + if len(conf.AuthMethods) == 0 { + if conf.Credentials != nil { + conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} + } else { + conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} + } + } + + // Ensure we have a DNS resolver + if conf.Resolver == nil { + conf.Resolver = DNSResolver{} + } + + // Ensure we have a rule set + if conf.Rules == nil { + conf.Rules = PermitAll() + } + + // Ensure we have a log target + if conf.Logger == nil { + conf.Logger = log.New(os.Stdout, "", log.LstdFlags) + } + + server := &Server{ + config: conf, + } + + server.authMethods = make(map[uint8]Authenticator) + + for _, a := range conf.AuthMethods { + server.authMethods[a.GetCode()] = a + } + + return server, nil +} + +// ListenAndServe is used to create a listener and serve on it +func (s *Server) ListenAndServe(network, addr string) error { + l, err := net.Listen(network, addr) + if err != nil { + return err + } + return s.Serve(l) +} + +// Serve is used to serve connections from a listener +func (s *Server) Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.ServeConn(conn) + } + return nil +} + +// ServeConn is used to serve a single connection. +func (s *Server) ServeConn(conn net.Conn) error { + defer conn.Close() + bufConn := bufio.NewReader(conn) + + // Read the version byte + version := []byte{0} + if _, err := bufConn.Read(version); err != nil { + s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err) + return err + } + + // Ensure we are compatible + if version[0] != socks5Version { + err := fmt.Errorf("Unsupported SOCKS version: %v", version) + s.config.Logger.Printf("[ERR] socks: %v", err) + return err + } + + // Authenticate the connection + authContext, err := s.authenticate(conn, bufConn) + if err != nil { + err = fmt.Errorf("Failed to authenticate: %v", err) + s.config.Logger.Printf("[ERR] socks: %v", err) + return err + } + + request, err := NewRequest(bufConn) + if err != nil { + if err == unrecognizedAddrType { + if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + } + return fmt.Errorf("Failed to read destination address: %v", err) + } + request.AuthContext = authContext + if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} + } + + // Process the client request + if err := s.handleRequest(request, conn); err != nil { + err = fmt.Errorf("Failed to handle request: %v", err) + s.config.Logger.Printf("[ERR] socks: %v", err) + return err + } + + return nil +} diff --git a/vendor/golang.org/x/net/context/context.go b/vendor/golang.org/x/net/context/context.go new file mode 100644 index 00000000..27dcb951 --- /dev/null +++ b/vendor/golang.org/x/net/context/context.go @@ -0,0 +1,156 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancelation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a Context, and outgoing calls to +// servers should accept a Context. The chain of function calls between must +// propagate the Context, optionally replacing it with a modified copy created +// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See http://blog.golang.org/context for example code for a server that uses +// Contexts. +package context + +import "time" + +// A Context carries a deadline, a cancelation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out chan<- Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See http://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancelation. + Done() <-chan struct{} + + // Err returns a non-nil error value after Done is closed. Err returns + // Canceled if the context was canceled or DeadlineExceeded if the + // context's deadline passed. No other values for Err are defined. + // After Done is closed, successive calls to Err return the same value. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stores using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "golang.org/x/net/context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key = 0 + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key interface{}) interface{} +} + +// Background returns a non-nil, empty Context. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return background +} + +// TODO returns a non-nil, empty Context. Code should use context.TODO when +// it's unclear which Context to use or it is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). TODO is recognized by static analysis tools that determine +// whether Contexts are propagated correctly in a program. +func TODO() Context { + return todo +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() diff --git a/vendor/golang.org/x/net/context/go17.go b/vendor/golang.org/x/net/context/go17.go new file mode 100644 index 00000000..d20f52b7 --- /dev/null +++ b/vendor/golang.org/x/net/context/go17.go @@ -0,0 +1,72 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package context + +import ( + "context" // standard library's context, as of Go 1.7 + "time" +) + +var ( + todo = context.TODO() + background = context.Background() +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = context.Canceled + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = context.DeadlineExceeded + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + ctx, f := context.WithCancel(parent) + return ctx, CancelFunc(f) +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + ctx, f := context.WithDeadline(parent, deadline) + return ctx, CancelFunc(f) +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return context.WithValue(parent, key, val) +} diff --git a/vendor/golang.org/x/net/context/pre_go17.go b/vendor/golang.org/x/net/context/pre_go17.go new file mode 100644 index 00000000..0f35592d --- /dev/null +++ b/vendor/golang.org/x/net/context/pre_go17.go @@ -0,0 +1,300 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package context + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// An emptyCtx is never canceled, has no values, and has no deadline. It is not +// struct{}, since vars of this type must have distinct addresses. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + switch e { + case background: + return "context.Background" + case todo: + return "context.TODO" + } + return "unknown empty Context" +} + +var ( + background = new(emptyCtx) + todo = new(emptyCtx) +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = errors.New("context deadline exceeded") + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := newCancelCtx(parent) + propagateCancel(parent, c) + return c, func() { c.cancel(true, Canceled) } +} + +// newCancelCtx returns an initialized cancelCtx. +func newCancelCtx(parent Context) *cancelCtx { + return &cancelCtx{ + Context: parent, + done: make(chan struct{}), + } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent Context, child canceler) { + if parent.Done() == nil { + return // parent is never canceled + } + if p, ok := parentCancelCtx(parent); ok { + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err) + } else { + if p.children == nil { + p.children = make(map[canceler]bool) + } + p.children[child] = true + } + p.mu.Unlock() + } else { + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err()) + case <-child.Done(): + } + }() + } +} + +// parentCancelCtx follows a chain of parent references until it finds a +// *cancelCtx. This function understands how each of the concrete types in this +// package represents its parent. +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + for { + switch c := parent.(type) { + case *cancelCtx: + return c, true + case *timerCtx: + return c.cancelCtx, true + case *valueCtx: + parent = c.Context + default: + return nil, false + } + } +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err error) + Done() <-chan struct{} +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + done chan struct{} // closed by the first cancel call. + + mu sync.Mutex + children map[canceler]bool // set to nil by the first cancel call + err error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Done() <-chan struct{} { + return c.done +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *cancelCtx) String() string { + return fmt.Sprintf("%v.WithCancel", c.Context) +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +func (c *cancelCtx) cancel(removeFromParent bool, err error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + close(c.done) + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return WithCancel(parent) + } + c := &timerCtx{ + cancelCtx: newCancelCtx(parent), + deadline: deadline, + } + propagateCancel(parent, c) + d := deadline.Sub(time.Now()) + if d <= 0 { + c.cancel(true, DeadlineExceeded) // deadline has already passed + return c, func() { c.cancel(true, Canceled) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(d, func() { + c.cancel(true, DeadlineExceeded) + }) + } + return c, func() { c.cancel(true, Canceled) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + *cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) +} + +func (c *timerCtx) cancel(removeFromParent bool, err error) { + c.cancelCtx.cancel(false, err) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. +type valueCtx struct { + Context + key, val interface{} +} + +func (c *valueCtx) String() string { + return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) +} + +func (c *valueCtx) Value(key interface{}) interface{} { + if c.key == key { + return c.val + } + return c.Context.Value(key) +}