mirror of https://github.com/fatedier/frp
fatedier
7 years ago
committed by
GitHub
44 changed files with 2958 additions and 270 deletions
@ -0,0 +1,60 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
frpNet "github.com/fatedier/frp/utils/net" |
||||
|
||||
"github.com/julienschmidt/httprouter" |
||||
) |
||||
|
||||
var ( |
||||
httpServerReadTimeout = 10 * time.Second |
||||
httpServerWriteTimeout = 10 * time.Second |
||||
) |
||||
|
||||
func (svr *Service) RunAdminServer(addr string, port int64) (err error) { |
||||
// url router
|
||||
router := httprouter.New() |
||||
|
||||
user, passwd := config.ClientCommonCfg.AdminUser, config.ClientCommonCfg.AdminPwd |
||||
|
||||
// api, see dashboard_api.go
|
||||
router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd)) |
||||
|
||||
address := fmt.Sprintf("%s:%d", addr, port) |
||||
server := &http.Server{ |
||||
Addr: address, |
||||
Handler: router, |
||||
ReadTimeout: httpServerReadTimeout, |
||||
WriteTimeout: httpServerWriteTimeout, |
||||
} |
||||
if address == "" { |
||||
address = ":http" |
||||
} |
||||
ln, err := net.Listen("tcp", address) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
go server.Serve(ln) |
||||
return |
||||
} |
@ -0,0 +1,78 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
|
||||
"github.com/julienschmidt/httprouter" |
||||
ini "github.com/vaughan0/go-ini" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/utils/log" |
||||
) |
||||
|
||||
type GeneralResponse struct { |
||||
Code int64 `json:"code"` |
||||
Msg string `json:"msg"` |
||||
} |
||||
|
||||
// api/reload
|
||||
type ReloadResp struct { |
||||
GeneralResponse |
||||
} |
||||
|
||||
func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { |
||||
var ( |
||||
buf []byte |
||||
res ReloadResp |
||||
) |
||||
defer func() { |
||||
log.Info("Http response [/api/reload]: code [%d]", res.Code) |
||||
buf, _ = json.Marshal(&res) |
||||
w.Write(buf) |
||||
}() |
||||
|
||||
log.Info("Http request: [/api/reload]") |
||||
|
||||
conf, err := ini.LoadFile(config.ClientCommonCfg.ConfigFile) |
||||
if err != nil { |
||||
res.Code = 1 |
||||
res.Msg = err.Error() |
||||
log.Error("reload frpc config file error: %v", err) |
||||
return |
||||
} |
||||
|
||||
newCommonCfg, err := config.LoadClientCommonConf(conf) |
||||
if err != nil { |
||||
res.Code = 2 |
||||
res.Msg = err.Error() |
||||
log.Error("reload frpc common section error: %v", err) |
||||
return |
||||
} |
||||
|
||||
pxyCfgs, vistorCfgs, err := config.LoadProxyConfFromFile(config.ClientCommonCfg.User, conf, newCommonCfg.Start) |
||||
if err != nil { |
||||
res.Code = 3 |
||||
res.Msg = err.Error() |
||||
log.Error("reload frpc proxy config error: %v", err) |
||||
return |
||||
} |
||||
|
||||
svr.ctl.reloadConf(pxyCfgs, vistorCfgs) |
||||
log.Info("success reload conf") |
||||
return |
||||
} |
@ -0,0 +1,145 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/fatedier/frp/models/config" |
||||
"github.com/fatedier/frp/models/msg" |
||||
frpIo "github.com/fatedier/frp/utils/io" |
||||
"github.com/fatedier/frp/utils/log" |
||||
frpNet "github.com/fatedier/frp/utils/net" |
||||
"github.com/fatedier/frp/utils/util" |
||||
) |
||||
|
||||
// Vistor is used for forward traffics from local port tot remote service.
|
||||
type Vistor interface { |
||||
Run() error |
||||
Close() |
||||
log.Logger |
||||
} |
||||
|
||||
func NewVistor(ctl *Control, pxyConf config.ProxyConf) (vistor Vistor) { |
||||
baseVistor := BaseVistor{ |
||||
ctl: ctl, |
||||
Logger: log.NewPrefixLogger(pxyConf.GetName()), |
||||
} |
||||
switch cfg := pxyConf.(type) { |
||||
case *config.StcpProxyConf: |
||||
vistor = &StcpVistor{ |
||||
BaseVistor: baseVistor, |
||||
cfg: cfg, |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
type BaseVistor struct { |
||||
ctl *Control |
||||
l frpNet.Listener |
||||
closed bool |
||||
mu sync.RWMutex |
||||
log.Logger |
||||
} |
||||
|
||||
type StcpVistor struct { |
||||
BaseVistor |
||||
|
||||
cfg *config.StcpProxyConf |
||||
} |
||||
|
||||
func (sv *StcpVistor) Run() (err error) { |
||||
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
go sv.worker() |
||||
return |
||||
} |
||||
|
||||
func (sv *StcpVistor) Close() { |
||||
sv.l.Close() |
||||
} |
||||
|
||||
func (sv *StcpVistor) worker() { |
||||
for { |
||||
conn, err := sv.l.Accept() |
||||
if err != nil { |
||||
sv.Warn("stcp local listener closed") |
||||
return |
||||
} |
||||
|
||||
go sv.handleConn(conn) |
||||
} |
||||
} |
||||
|
||||
func (sv *StcpVistor) handleConn(userConn frpNet.Conn) { |
||||
defer userConn.Close() |
||||
|
||||
sv.Debug("get a new stcp user connection") |
||||
vistorConn, err := sv.ctl.connectServer() |
||||
if err != nil { |
||||
return |
||||
} |
||||
defer vistorConn.Close() |
||||
|
||||
now := time.Now().Unix() |
||||
newVistorConnMsg := &msg.NewVistorConn{ |
||||
ProxyName: sv.cfg.ServerName, |
||||
SignKey: util.GetAuthKey(sv.cfg.Sk, now), |
||||
Timestamp: now, |
||||
UseEncryption: sv.cfg.UseEncryption, |
||||
UseCompression: sv.cfg.UseCompression, |
||||
} |
||||
err = msg.WriteMsg(vistorConn, newVistorConnMsg) |
||||
if err != nil { |
||||
sv.Warn("send newVistorConnMsg to server error: %v", err) |
||||
return |
||||
} |
||||
|
||||
var newVistorConnRespMsg msg.NewVistorConnResp |
||||
vistorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) |
||||
err = msg.ReadMsgInto(vistorConn, &newVistorConnRespMsg) |
||||
if err != nil { |
||||
sv.Warn("get newVistorConnRespMsg error: %v", err) |
||||
return |
||||
} |
||||
vistorConn.SetReadDeadline(time.Time{}) |
||||
|
||||
if newVistorConnRespMsg.Error != "" { |
||||
sv.Warn("start new vistor connection error: %s", newVistorConnRespMsg.Error) |
||||
return |
||||
} |
||||
|
||||
var remote io.ReadWriteCloser |
||||
remote = vistorConn |
||||
if sv.cfg.UseEncryption { |
||||
remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) |
||||
if err != nil { |
||||
sv.Error("create encryption stream error: %v", err) |
||||
return |
||||
} |
||||
} |
||||
|
||||
if sv.cfg.UseCompression { |
||||
remote = frpIo.WithCompression(remote) |
||||
} |
||||
|
||||
frpIo.Join(userConn, remote) |
||||
} |
@ -0,0 +1,65 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin |
||||
|
||||
import ( |
||||
"io" |
||||
"io/ioutil" |
||||
"log" |
||||
|
||||
frpNet "github.com/fatedier/frp/utils/net" |
||||
|
||||
gosocks5 "github.com/armon/go-socks5" |
||||
) |
||||
|
||||
const PluginSocks5 = "socks5" |
||||
|
||||
func init() { |
||||
Register(PluginSocks5, NewSocks5Plugin) |
||||
} |
||||
|
||||
type Socks5Plugin struct { |
||||
Server *gosocks5.Server |
||||
} |
||||
|
||||
func NewSocks5Plugin(params map[string]string) (p Plugin, err error) { |
||||
sp := &Socks5Plugin{} |
||||
sp.Server, err = gosocks5.New(&gosocks5.Config{ |
||||
Logger: log.New(ioutil.Discard, "", log.LstdFlags), |
||||
}) |
||||
p = sp |
||||
return |
||||
} |
||||
|
||||
func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser) { |
||||
defer conn.Close() |
||||
|
||||
var wrapConn frpNet.Conn |
||||
if realConn, ok := conn.(frpNet.Conn); ok { |
||||
wrapConn = realConn |
||||
} else { |
||||
wrapConn = frpNet.WrapReadWriteCloserToConn(conn) |
||||
} |
||||
|
||||
sp.Server.ServeConn(wrapConn) |
||||
} |
||||
|
||||
func (sp *Socks5Plugin) Name() string { |
||||
return PluginSocks5 |
||||
} |
||||
|
||||
func (sp *Socks5Plugin) Close() error { |
||||
return nil |
||||
} |
@ -0,0 +1,105 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net |
||||
|
||||
import ( |
||||
"compress/gzip" |
||||
"io" |
||||
"net/http" |
||||
"strings" |
||||
|
||||
"github.com/julienschmidt/httprouter" |
||||
) |
||||
|
||||
type HttpAuthWraper struct { |
||||
h http.Handler |
||||
user string |
||||
passwd string |
||||
} |
||||
|
||||
func NewHttpBasicAuthWraper(h http.Handler, user, passwd string) http.Handler { |
||||
return &HttpAuthWraper{ |
||||
h: h, |
||||
user: user, |
||||
passwd: passwd, |
||||
} |
||||
} |
||||
|
||||
func (aw *HttpAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
||||
user, passwd, hasAuth := r.BasicAuth() |
||||
if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) { |
||||
aw.h.ServeHTTP(w, r) |
||||
} else { |
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) |
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) |
||||
} |
||||
} |
||||
|
||||
func HttpBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc { |
||||
return func(w http.ResponseWriter, r *http.Request) { |
||||
reqUser, reqPasswd, hasAuth := r.BasicAuth() |
||||
if (user == "" && passwd == "") || |
||||
(hasAuth && reqUser == user && reqPasswd == passwd) { |
||||
h.ServeHTTP(w, r) |
||||
} else { |
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) |
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func HttprouterBasicAuth(h httprouter.Handle, user, passwd string) httprouter.Handle { |
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { |
||||
reqUser, reqPasswd, hasAuth := r.BasicAuth() |
||||
if (user == "" && passwd == "") || |
||||
(hasAuth && reqUser == user && reqPasswd == passwd) { |
||||
h(w, r, ps) |
||||
} else { |
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) |
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type HttpGzipWraper struct { |
||||
h http.Handler |
||||
} |
||||
|
||||
func (gw *HttpGzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
||||
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { |
||||
gw.h.ServeHTTP(w, r) |
||||
return |
||||
} |
||||
w.Header().Set("Content-Encoding", "gzip") |
||||
gz := gzip.NewWriter(w) |
||||
defer gz.Close() |
||||
gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} |
||||
gw.h.ServeHTTP(gzr, r) |
||||
} |
||||
|
||||
func MakeHttpGzipHandler(h http.Handler) http.Handler { |
||||
return &HttpGzipWraper{ |
||||
h: h, |
||||
} |
||||
} |
||||
|
||||
type gzipResponseWriter struct { |
||||
io.Writer |
||||
http.ResponseWriter |
||||
} |
||||
|
||||
func (w gzipResponseWriter) Write(b []byte) (int, error) { |
||||
return w.Writer.Write(b) |
||||
} |
@ -0,0 +1,22 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects) |
||||
*.o |
||||
*.a |
||||
*.so |
||||
|
||||
# Folders |
||||
_obj |
||||
_test |
||||
|
||||
# Architecture specific extensions/prefixes |
||||
*.[568vq] |
||||
[568vq].out |
||||
|
||||
*.cgo1.go |
||||
*.cgo2.c |
||||
_cgo_defun.c |
||||
_cgo_gotypes.go |
||||
_cgo_export.* |
||||
|
||||
_testmain.go |
||||
|
||||
*.exe |
@ -0,0 +1,4 @@
|
||||
language: go |
||||
go: |
||||
- 1.1 |
||||
- tip |
@ -0,0 +1,20 @@
|
||||
The MIT License (MIT) |
||||
|
||||
Copyright (c) 2014 Armon Dadgar |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of |
||||
this software and associated documentation files (the "Software"), to deal in |
||||
the Software without restriction, including without limitation the rights to |
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
||||
the Software, and to permit persons to whom the Software is furnished to do so, |
||||
subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice shall be included in all |
||||
copies or substantial portions of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@ -0,0 +1,45 @@
|
||||
go-socks5 [![Build Status](https://travis-ci.org/armon/go-socks5.png)](https://travis-ci.org/armon/go-socks5) |
||||
========= |
||||
|
||||
Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS). |
||||
SOCKS (Secure Sockets) is used to route traffic between a client and server through |
||||
an intermediate proxy layer. This can be used to bypass firewalls or NATs. |
||||
|
||||
Feature |
||||
======= |
||||
|
||||
The package has the following features: |
||||
* "No Auth" mode |
||||
* User/Password authentication |
||||
* Support for the CONNECT command |
||||
* Rules to do granular filtering of commands |
||||
* Custom DNS resolution |
||||
* Unit tests |
||||
|
||||
TODO |
||||
==== |
||||
|
||||
The package still needs the following: |
||||
* Support for the BIND command |
||||
* Support for the ASSOCIATE command |
||||
|
||||
|
||||
Example |
||||
======= |
||||
|
||||
Below is a simple example of usage |
||||
|
||||
```go |
||||
// Create a SOCKS5 server |
||||
conf := &socks5.Config{} |
||||
server, err := socks5.New(conf) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
// Create SOCKS5 proxy on localhost port 8000 |
||||
if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil { |
||||
panic(err) |
||||
} |
||||
``` |
||||
|
@ -0,0 +1,151 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
const ( |
||||
NoAuth = uint8(0) |
||||
noAcceptable = uint8(255) |
||||
UserPassAuth = uint8(2) |
||||
userAuthVersion = uint8(1) |
||||
authSuccess = uint8(0) |
||||
authFailure = uint8(1) |
||||
) |
||||
|
||||
var ( |
||||
UserAuthFailed = fmt.Errorf("User authentication failed") |
||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") |
||||
) |
||||
|
||||
// A Request encapsulates authentication state provided
|
||||
// during negotiation
|
||||
type AuthContext struct { |
||||
// Provided auth method
|
||||
Method uint8 |
||||
// Payload provided during negotiation.
|
||||
// Keys depend on the used auth method.
|
||||
// For UserPassauth contains Username
|
||||
Payload map[string]string |
||||
} |
||||
|
||||
type Authenticator interface { |
||||
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) |
||||
GetCode() uint8 |
||||
} |
||||
|
||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||
type NoAuthAuthenticator struct{} |
||||
|
||||
func (a NoAuthAuthenticator) GetCode() uint8 { |
||||
return NoAuth |
||||
} |
||||
|
||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { |
||||
_, err := writer.Write([]byte{socks5Version, NoAuth}) |
||||
return &AuthContext{NoAuth, nil}, err |
||||
} |
||||
|
||||
// UserPassAuthenticator is used to handle username/password based
|
||||
// authentication
|
||||
type UserPassAuthenticator struct { |
||||
Credentials CredentialStore |
||||
} |
||||
|
||||
func (a UserPassAuthenticator) GetCode() uint8 { |
||||
return UserPassAuth |
||||
} |
||||
|
||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { |
||||
// Tell the client to use user/pass auth
|
||||
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Get the version and username length
|
||||
header := []byte{0, 0} |
||||
if _, err := io.ReadAtLeast(reader, header, 2); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != userAuthVersion { |
||||
return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) |
||||
} |
||||
|
||||
// Get the user name
|
||||
userLen := int(header[1]) |
||||
user := make([]byte, userLen) |
||||
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Get the password length
|
||||
if _, err := reader.Read(header[:1]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Get the password
|
||||
passLen := int(header[0]) |
||||
pass := make([]byte, passLen) |
||||
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Verify the password
|
||||
if a.Credentials.Valid(string(user), string(pass)) { |
||||
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { |
||||
return nil, err |
||||
} |
||||
} else { |
||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { |
||||
return nil, err |
||||
} |
||||
return nil, UserAuthFailed |
||||
} |
||||
|
||||
// Done
|
||||
return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil |
||||
} |
||||
|
||||
// authenticate is used to handle connection authentication
|
||||
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { |
||||
// Get the methods
|
||||
methods, err := readMethods(bufConn) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("Failed to get auth methods: %v", err) |
||||
} |
||||
|
||||
// Select a usable method
|
||||
for _, method := range methods { |
||||
cator, found := s.authMethods[method] |
||||
if found { |
||||
return cator.Authenticate(bufConn, conn) |
||||
} |
||||
} |
||||
|
||||
// No usable method found
|
||||
return nil, noAcceptableAuth(conn) |
||||
} |
||||
|
||||
// noAcceptableAuth is used to handle when we have no eligible
|
||||
// authentication mechanism
|
||||
func noAcceptableAuth(conn io.Writer) error { |
||||
conn.Write([]byte{socks5Version, noAcceptable}) |
||||
return NoSupportedAuth |
||||
} |
||||
|
||||
// readMethods is used to read the number of methods
|
||||
// and proceeding auth methods
|
||||
func readMethods(r io.Reader) ([]byte, error) { |
||||
header := []byte{0} |
||||
if _, err := r.Read(header); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
numMethods := int(header[0]) |
||||
methods := make([]byte, numMethods) |
||||
_, err := io.ReadAtLeast(r, methods, numMethods) |
||||
return methods, err |
||||
} |
@ -0,0 +1,17 @@
|
||||
package socks5 |
||||
|
||||
// CredentialStore is used to support user/pass authentication
|
||||
type CredentialStore interface { |
||||
Valid(user, password string) bool |
||||
} |
||||
|
||||
// StaticCredentials enables using a map directly as a credential store
|
||||
type StaticCredentials map[string]string |
||||
|
||||
func (s StaticCredentials) Valid(user, password string) bool { |
||||
pass, ok := s[user] |
||||
if !ok { |
||||
return false |
||||
} |
||||
return password == pass |
||||
} |
@ -0,0 +1,364 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
const ( |
||||
ConnectCommand = uint8(1) |
||||
BindCommand = uint8(2) |
||||
AssociateCommand = uint8(3) |
||||
ipv4Address = uint8(1) |
||||
fqdnAddress = uint8(3) |
||||
ipv6Address = uint8(4) |
||||
) |
||||
|
||||
const ( |
||||
successReply uint8 = iota |
||||
serverFailure |
||||
ruleFailure |
||||
networkUnreachable |
||||
hostUnreachable |
||||
connectionRefused |
||||
ttlExpired |
||||
commandNotSupported |
||||
addrTypeNotSupported |
||||
) |
||||
|
||||
var ( |
||||
unrecognizedAddrType = fmt.Errorf("Unrecognized address type") |
||||
) |
||||
|
||||
// AddressRewriter is used to rewrite a destination transparently
|
||||
type AddressRewriter interface { |
||||
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) |
||||
} |
||||
|
||||
// AddrSpec is used to return the target AddrSpec
|
||||
// which may be specified as IPv4, IPv6, or a FQDN
|
||||
type AddrSpec struct { |
||||
FQDN string |
||||
IP net.IP |
||||
Port int |
||||
} |
||||
|
||||
func (a *AddrSpec) String() string { |
||||
if a.FQDN != "" { |
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) |
||||
} |
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port) |
||||
} |
||||
|
||||
// Address returns a string suitable to dial; prefer returning IP-based
|
||||
// address, fallback to FQDN
|
||||
func (a AddrSpec) Address() string { |
||||
if 0 != len(a.IP) { |
||||
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) |
||||
} |
||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) |
||||
} |
||||
|
||||
// A Request represents request received by a server
|
||||
type Request struct { |
||||
// Protocol version
|
||||
Version uint8 |
||||
// Requested command
|
||||
Command uint8 |
||||
// AuthContext provided during negotiation
|
||||
AuthContext *AuthContext |
||||
// AddrSpec of the the network that sent the request
|
||||
RemoteAddr *AddrSpec |
||||
// AddrSpec of the desired destination
|
||||
DestAddr *AddrSpec |
||||
// AddrSpec of the actual destination (might be affected by rewrite)
|
||||
realDestAddr *AddrSpec |
||||
bufConn io.Reader |
||||
} |
||||
|
||||
type conn interface { |
||||
Write([]byte) (int, error) |
||||
RemoteAddr() net.Addr |
||||
} |
||||
|
||||
// NewRequest creates a new Request from the tcp connection
|
||||
func NewRequest(bufConn io.Reader) (*Request, error) { |
||||
// Read the version byte
|
||||
header := []byte{0, 0, 0} |
||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { |
||||
return nil, fmt.Errorf("Failed to get command version: %v", err) |
||||
} |
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != socks5Version { |
||||
return nil, fmt.Errorf("Unsupported command version: %v", header[0]) |
||||
} |
||||
|
||||
// Read in the destination address
|
||||
dest, err := readAddrSpec(bufConn) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
request := &Request{ |
||||
Version: socks5Version, |
||||
Command: header[1], |
||||
DestAddr: dest, |
||||
bufConn: bufConn, |
||||
} |
||||
|
||||
return request, nil |
||||
} |
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(req *Request, conn conn) error { |
||||
ctx := context.Background() |
||||
|
||||
// Resolve the address if we have a FQDN
|
||||
dest := req.DestAddr |
||||
if dest.FQDN != "" { |
||||
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) |
||||
if err != nil { |
||||
if err := sendReply(conn, hostUnreachable, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) |
||||
} |
||||
ctx = ctx_ |
||||
dest.IP = addr |
||||
} |
||||
|
||||
// Apply any address rewrites
|
||||
req.realDestAddr = req.DestAddr |
||||
if s.config.Rewriter != nil { |
||||
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) |
||||
} |
||||
|
||||
// Switch on the command
|
||||
switch req.Command { |
||||
case ConnectCommand: |
||||
return s.handleConnect(ctx, conn, req) |
||||
case BindCommand: |
||||
return s.handleBind(ctx, conn, req) |
||||
case AssociateCommand: |
||||
return s.handleAssociate(ctx, conn, req) |
||||
default: |
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Unsupported command: %v", req.Command) |
||||
} |
||||
} |
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { |
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { |
||||
if err := sendReply(conn, ruleFailure, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) |
||||
} else { |
||||
ctx = ctx_ |
||||
} |
||||
|
||||
// Attempt to connect
|
||||
dial := s.config.Dial |
||||
if dial == nil { |
||||
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { |
||||
return net.Dial(net_, addr) |
||||
} |
||||
} |
||||
target, err := dial(ctx, "tcp", req.realDestAddr.Address()) |
||||
if err != nil { |
||||
msg := err.Error() |
||||
resp := hostUnreachable |
||||
if strings.Contains(msg, "refused") { |
||||
resp = connectionRefused |
||||
} else if strings.Contains(msg, "network is unreachable") { |
||||
resp = networkUnreachable |
||||
} |
||||
if err := sendReply(conn, resp, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) |
||||
} |
||||
defer target.Close() |
||||
|
||||
// Send success
|
||||
local := target.LocalAddr().(*net.TCPAddr) |
||||
bind := AddrSpec{IP: local.IP, Port: local.Port} |
||||
if err := sendReply(conn, successReply, &bind); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
|
||||
// Start proxying
|
||||
errCh := make(chan error, 2) |
||||
go proxy(target, req.bufConn, errCh) |
||||
go proxy(conn, target, errCh) |
||||
|
||||
// Wait
|
||||
for i := 0; i < 2; i++ { |
||||
e := <-errCh |
||||
if e != nil { |
||||
// return from this function closes target (and conn).
|
||||
return e |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { |
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { |
||||
if err := sendReply(conn, ruleFailure, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) |
||||
} else { |
||||
ctx = ctx_ |
||||
} |
||||
|
||||
// TODO: Support bind
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { |
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { |
||||
if err := sendReply(conn, ruleFailure, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) |
||||
} else { |
||||
ctx = ctx_ |
||||
} |
||||
|
||||
// TODO: Support associate
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// readAddrSpec is used to read AddrSpec.
|
||||
// Expects an address type byte, follwed by the address and port
|
||||
func readAddrSpec(r io.Reader) (*AddrSpec, error) { |
||||
d := &AddrSpec{} |
||||
|
||||
// Get the address type
|
||||
addrType := []byte{0} |
||||
if _, err := r.Read(addrType); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Handle on a per type basis
|
||||
switch addrType[0] { |
||||
case ipv4Address: |
||||
addr := make([]byte, 4) |
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { |
||||
return nil, err |
||||
} |
||||
d.IP = net.IP(addr) |
||||
|
||||
case ipv6Address: |
||||
addr := make([]byte, 16) |
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { |
||||
return nil, err |
||||
} |
||||
d.IP = net.IP(addr) |
||||
|
||||
case fqdnAddress: |
||||
if _, err := r.Read(addrType); err != nil { |
||||
return nil, err |
||||
} |
||||
addrLen := int(addrType[0]) |
||||
fqdn := make([]byte, addrLen) |
||||
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { |
||||
return nil, err |
||||
} |
||||
d.FQDN = string(fqdn) |
||||
|
||||
default: |
||||
return nil, unrecognizedAddrType |
||||
} |
||||
|
||||
// Read the port
|
||||
port := []byte{0, 0} |
||||
if _, err := io.ReadAtLeast(r, port, 2); err != nil { |
||||
return nil, err |
||||
} |
||||
d.Port = (int(port[0]) << 8) | int(port[1]) |
||||
|
||||
return d, nil |
||||
} |
||||
|
||||
// sendReply is used to send a reply message
|
||||
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { |
||||
// Format the address
|
||||
var addrType uint8 |
||||
var addrBody []byte |
||||
var addrPort uint16 |
||||
switch { |
||||
case addr == nil: |
||||
addrType = ipv4Address |
||||
addrBody = []byte{0, 0, 0, 0} |
||||
addrPort = 0 |
||||
|
||||
case addr.FQDN != "": |
||||
addrType = fqdnAddress |
||||
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) |
||||
addrPort = uint16(addr.Port) |
||||
|
||||
case addr.IP.To4() != nil: |
||||
addrType = ipv4Address |
||||
addrBody = []byte(addr.IP.To4()) |
||||
addrPort = uint16(addr.Port) |
||||
|
||||
case addr.IP.To16() != nil: |
||||
addrType = ipv6Address |
||||
addrBody = []byte(addr.IP.To16()) |
||||
addrPort = uint16(addr.Port) |
||||
|
||||
default: |
||||
return fmt.Errorf("Failed to format address: %v", addr) |
||||
} |
||||
|
||||
// Format the message
|
||||
msg := make([]byte, 6+len(addrBody)) |
||||
msg[0] = socks5Version |
||||
msg[1] = resp |
||||
msg[2] = 0 // Reserved
|
||||
msg[3] = addrType |
||||
copy(msg[4:], addrBody) |
||||
msg[4+len(addrBody)] = byte(addrPort >> 8) |
||||
msg[4+len(addrBody)+1] = byte(addrPort & 0xff) |
||||
|
||||
// Send the message
|
||||
_, err := w.Write(msg) |
||||
return err |
||||
} |
||||
|
||||
type closeWriter interface { |
||||
CloseWrite() error |
||||
} |
||||
|
||||
// proxy is used to suffle data from src to destination, and sends errors
|
||||
// down a dedicated channel
|
||||
func proxy(dst io.Writer, src io.Reader, errCh chan error) { |
||||
_, err := io.Copy(dst, src) |
||||
if tcpConn, ok := dst.(closeWriter); ok { |
||||
tcpConn.CloseWrite() |
||||
} |
||||
errCh <- err |
||||
} |
@ -0,0 +1,23 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"net" |
||||
|
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
// NameResolver is used to implement custom name resolution
|
||||
type NameResolver interface { |
||||
Resolve(ctx context.Context, name string) (context.Context, net.IP, error) |
||||
} |
||||
|
||||
// DNSResolver uses the system DNS to resolve host names
|
||||
type DNSResolver struct{} |
||||
|
||||
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { |
||||
addr, err := net.ResolveIPAddr("ip", name) |
||||
if err != nil { |
||||
return ctx, nil, err |
||||
} |
||||
return ctx, addr.IP, err |
||||
} |
@ -0,0 +1,41 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
// RuleSet is used to provide custom rules to allow or prohibit actions
|
||||
type RuleSet interface { |
||||
Allow(ctx context.Context, req *Request) (context.Context, bool) |
||||
} |
||||
|
||||
// PermitAll returns a RuleSet which allows all types of connections
|
||||
func PermitAll() RuleSet { |
||||
return &PermitCommand{true, true, true} |
||||
} |
||||
|
||||
// PermitNone returns a RuleSet which disallows all types of connections
|
||||
func PermitNone() RuleSet { |
||||
return &PermitCommand{false, false, false} |
||||
} |
||||
|
||||
// PermitCommand is an implementation of the RuleSet which
|
||||
// enables filtering supported commands
|
||||
type PermitCommand struct { |
||||
EnableConnect bool |
||||
EnableBind bool |
||||
EnableAssociate bool |
||||
} |
||||
|
||||
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { |
||||
switch req.Command { |
||||
case ConnectCommand: |
||||
return ctx, p.EnableConnect |
||||
case BindCommand: |
||||
return ctx, p.EnableBind |
||||
case AssociateCommand: |
||||
return ctx, p.EnableAssociate |
||||
} |
||||
|
||||
return ctx, false |
||||
} |
@ -0,0 +1,169 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"log" |
||||
"net" |
||||
"os" |
||||
|
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
const ( |
||||
socks5Version = uint8(5) |
||||
) |
||||
|
||||
// Config is used to setup and configure a Server
|
||||
type Config struct { |
||||
// AuthMethods can be provided to implement custom authentication
|
||||
// By default, "auth-less" mode is enabled.
|
||||
// For password-based auth use UserPassAuthenticator.
|
||||
AuthMethods []Authenticator |
||||
|
||||
// If provided, username/password authentication is enabled,
|
||||
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
|
||||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||
Credentials CredentialStore |
||||
|
||||
// Resolver can be provided to do custom name resolution.
|
||||
// Defaults to DNSResolver if not provided.
|
||||
Resolver NameResolver |
||||
|
||||
// Rules is provided to enable custom logic around permitting
|
||||
// various commands. If not provided, PermitAll is used.
|
||||
Rules RuleSet |
||||
|
||||
// Rewriter can be used to transparently rewrite addresses.
|
||||
// This is invoked before the RuleSet is invoked.
|
||||
// Defaults to NoRewrite.
|
||||
Rewriter AddressRewriter |
||||
|
||||
// BindIP is used for bind or udp associate
|
||||
BindIP net.IP |
||||
|
||||
// Logger can be used to provide a custom log target.
|
||||
// Defaults to stdout.
|
||||
Logger *log.Logger |
||||
|
||||
// Optional function for dialing out
|
||||
Dial func(ctx context.Context, network, addr string) (net.Conn, error) |
||||
} |
||||
|
||||
// Server is reponsible for accepting connections and handling
|
||||
// the details of the SOCKS5 protocol
|
||||
type Server struct { |
||||
config *Config |
||||
authMethods map[uint8]Authenticator |
||||
} |
||||
|
||||
// New creates a new Server and potentially returns an error
|
||||
func New(conf *Config) (*Server, error) { |
||||
// Ensure we have at least one authentication method enabled
|
||||
if len(conf.AuthMethods) == 0 { |
||||
if conf.Credentials != nil { |
||||
conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} |
||||
} else { |
||||
conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} |
||||
} |
||||
} |
||||
|
||||
// Ensure we have a DNS resolver
|
||||
if conf.Resolver == nil { |
||||
conf.Resolver = DNSResolver{} |
||||
} |
||||
|
||||
// Ensure we have a rule set
|
||||
if conf.Rules == nil { |
||||
conf.Rules = PermitAll() |
||||
} |
||||
|
||||
// Ensure we have a log target
|
||||
if conf.Logger == nil { |
||||
conf.Logger = log.New(os.Stdout, "", log.LstdFlags) |
||||
} |
||||
|
||||
server := &Server{ |
||||
config: conf, |
||||
} |
||||
|
||||
server.authMethods = make(map[uint8]Authenticator) |
||||
|
||||
for _, a := range conf.AuthMethods { |
||||
server.authMethods[a.GetCode()] = a |
||||
} |
||||
|
||||
return server, nil |
||||
} |
||||
|
||||
// ListenAndServe is used to create a listener and serve on it
|
||||
func (s *Server) ListenAndServe(network, addr string) error { |
||||
l, err := net.Listen(network, addr) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.Serve(l) |
||||
} |
||||
|
||||
// Serve is used to serve connections from a listener
|
||||
func (s *Server) Serve(l net.Listener) error { |
||||
for { |
||||
conn, err := l.Accept() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
go s.ServeConn(conn) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ServeConn is used to serve a single connection.
|
||||
func (s *Server) ServeConn(conn net.Conn) error { |
||||
defer conn.Close() |
||||
bufConn := bufio.NewReader(conn) |
||||
|
||||
// Read the version byte
|
||||
version := []byte{0} |
||||
if _, err := bufConn.Read(version); err != nil { |
||||
s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err) |
||||
return err |
||||
} |
||||
|
||||
// Ensure we are compatible
|
||||
if version[0] != socks5Version { |
||||
err := fmt.Errorf("Unsupported SOCKS version: %v", version) |
||||
s.config.Logger.Printf("[ERR] socks: %v", err) |
||||
return err |
||||
} |
||||
|
||||
// Authenticate the connection
|
||||
authContext, err := s.authenticate(conn, bufConn) |
||||
if err != nil { |
||||
err = fmt.Errorf("Failed to authenticate: %v", err) |
||||
s.config.Logger.Printf("[ERR] socks: %v", err) |
||||
return err |
||||
} |
||||
|
||||
request, err := NewRequest(bufConn) |
||||
if err != nil { |
||||
if err == unrecognizedAddrType { |
||||
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
} |
||||
return fmt.Errorf("Failed to read destination address: %v", err) |
||||
} |
||||
request.AuthContext = authContext |
||||
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { |
||||
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} |
||||
} |
||||
|
||||
// Process the client request
|
||||
if err := s.handleRequest(request, conn); err != nil { |
||||
err = fmt.Errorf("Failed to handle request: %v", err) |
||||
s.config.Logger.Printf("[ERR] socks: %v", err) |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,156 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package context defines the Context type, which carries deadlines,
|
||||
// cancelation signals, and other request-scoped values across API boundaries
|
||||
// and between processes.
|
||||
//
|
||||
// Incoming requests to a server should create a Context, and outgoing calls to
|
||||
// servers should accept a Context. The chain of function calls between must
|
||||
// propagate the Context, optionally replacing it with a modified copy created
|
||||
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
|
||||
//
|
||||
// Programs that use Contexts should follow these rules to keep interfaces
|
||||
// consistent across packages and enable static analysis tools to check context
|
||||
// propagation:
|
||||
//
|
||||
// Do not store Contexts inside a struct type; instead, pass a Context
|
||||
// explicitly to each function that needs it. The Context should be the first
|
||||
// parameter, typically named ctx:
|
||||
//
|
||||
// func DoSomething(ctx context.Context, arg Arg) error {
|
||||
// // ... use ctx ...
|
||||
// }
|
||||
//
|
||||
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
|
||||
// if you are unsure about which Context to use.
|
||||
//
|
||||
// Use context Values only for request-scoped data that transits processes and
|
||||
// APIs, not for passing optional parameters to functions.
|
||||
//
|
||||
// The same Context may be passed to functions running in different goroutines;
|
||||
// Contexts are safe for simultaneous use by multiple goroutines.
|
||||
//
|
||||
// See http://blog.golang.org/context for example code for a server that uses
|
||||
// Contexts.
|
||||
package context |
||||
|
||||
import "time" |
||||
|
||||
// A Context carries a deadline, a cancelation signal, and other values across
|
||||
// API boundaries.
|
||||
//
|
||||
// Context's methods may be called by multiple goroutines simultaneously.
|
||||
type Context interface { |
||||
// Deadline returns the time when work done on behalf of this context
|
||||
// should be canceled. Deadline returns ok==false when no deadline is
|
||||
// set. Successive calls to Deadline return the same results.
|
||||
Deadline() (deadline time.Time, ok bool) |
||||
|
||||
// Done returns a channel that's closed when work done on behalf of this
|
||||
// context should be canceled. Done may return nil if this context can
|
||||
// never be canceled. Successive calls to Done return the same value.
|
||||
//
|
||||
// WithCancel arranges for Done to be closed when cancel is called;
|
||||
// WithDeadline arranges for Done to be closed when the deadline
|
||||
// expires; WithTimeout arranges for Done to be closed when the timeout
|
||||
// elapses.
|
||||
//
|
||||
// Done is provided for use in select statements:
|
||||
//
|
||||
// // Stream generates values with DoSomething and sends them to out
|
||||
// // until DoSomething returns an error or ctx.Done is closed.
|
||||
// func Stream(ctx context.Context, out chan<- Value) error {
|
||||
// for {
|
||||
// v, err := DoSomething(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return ctx.Err()
|
||||
// case out <- v:
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// See http://blog.golang.org/pipelines for more examples of how to use
|
||||
// a Done channel for cancelation.
|
||||
Done() <-chan struct{} |
||||
|
||||
// Err returns a non-nil error value after Done is closed. Err returns
|
||||
// Canceled if the context was canceled or DeadlineExceeded if the
|
||||
// context's deadline passed. No other values for Err are defined.
|
||||
// After Done is closed, successive calls to Err return the same value.
|
||||
Err() error |
||||
|
||||
// Value returns the value associated with this context for key, or nil
|
||||
// if no value is associated with key. Successive calls to Value with
|
||||
// the same key returns the same result.
|
||||
//
|
||||
// Use context values only for request-scoped data that transits
|
||||
// processes and API boundaries, not for passing optional parameters to
|
||||
// functions.
|
||||
//
|
||||
// A key identifies a specific value in a Context. Functions that wish
|
||||
// to store values in Context typically allocate a key in a global
|
||||
// variable then use that key as the argument to context.WithValue and
|
||||
// Context.Value. A key can be any type that supports equality;
|
||||
// packages should define keys as an unexported type to avoid
|
||||
// collisions.
|
||||
//
|
||||
// Packages that define a Context key should provide type-safe accessors
|
||||
// for the values stores using that key:
|
||||
//
|
||||
// // Package user defines a User type that's stored in Contexts.
|
||||
// package user
|
||||
//
|
||||
// import "golang.org/x/net/context"
|
||||
//
|
||||
// // User is the type of value stored in the Contexts.
|
||||
// type User struct {...}
|
||||
//
|
||||
// // key is an unexported type for keys defined in this package.
|
||||
// // This prevents collisions with keys defined in other packages.
|
||||
// type key int
|
||||
//
|
||||
// // userKey is the key for user.User values in Contexts. It is
|
||||
// // unexported; clients use user.NewContext and user.FromContext
|
||||
// // instead of using this key directly.
|
||||
// var userKey key = 0
|
||||
//
|
||||
// // NewContext returns a new Context that carries value u.
|
||||
// func NewContext(ctx context.Context, u *User) context.Context {
|
||||
// return context.WithValue(ctx, userKey, u)
|
||||
// }
|
||||
//
|
||||
// // FromContext returns the User value stored in ctx, if any.
|
||||
// func FromContext(ctx context.Context) (*User, bool) {
|
||||
// u, ok := ctx.Value(userKey).(*User)
|
||||
// return u, ok
|
||||
// }
|
||||
Value(key interface{}) interface{} |
||||
} |
||||
|
||||
// Background returns a non-nil, empty Context. It is never canceled, has no
|
||||
// values, and has no deadline. It is typically used by the main function,
|
||||
// initialization, and tests, and as the top-level Context for incoming
|
||||
// requests.
|
||||
func Background() Context { |
||||
return background |
||||
} |
||||
|
||||
// TODO returns a non-nil, empty Context. Code should use context.TODO when
|
||||
// it's unclear which Context to use or it is not yet available (because the
|
||||
// surrounding function has not yet been extended to accept a Context
|
||||
// parameter). TODO is recognized by static analysis tools that determine
|
||||
// whether Contexts are propagated correctly in a program.
|
||||
func TODO() Context { |
||||
return todo |
||||
} |
||||
|
||||
// A CancelFunc tells an operation to abandon its work.
|
||||
// A CancelFunc does not wait for the work to stop.
|
||||
// After the first call, subsequent calls to a CancelFunc do nothing.
|
||||
type CancelFunc func() |
@ -0,0 +1,72 @@
|
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.7
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"context" // standard library's context, as of Go 1.7
|
||||
"time" |
||||
) |
||||
|
||||
var ( |
||||
todo = context.TODO() |
||||
background = context.Background() |
||||
) |
||||
|
||||
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||
var Canceled = context.Canceled |
||||
|
||||
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||
// deadline passes.
|
||||
var DeadlineExceeded = context.DeadlineExceeded |
||||
|
||||
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||
// context's Done channel is closed when the returned cancel function is called
|
||||
// or when the parent context's Done channel is closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
||||
ctx, f := context.WithCancel(parent) |
||||
return ctx, CancelFunc(f) |
||||
} |
||||
|
||||
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||
// context's Done channel is closed when the deadline expires, when the returned
|
||||
// cancel function is called, or when the parent context's Done channel is
|
||||
// closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
||||
ctx, f := context.WithDeadline(parent, deadline) |
||||
return ctx, CancelFunc(f) |
||||
} |
||||
|
||||
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete:
|
||||
//
|
||||
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||
// return slowOperation(ctx)
|
||||
// }
|
||||
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
||||
return WithDeadline(parent, time.Now().Add(timeout)) |
||||
} |
||||
|
||||
// WithValue returns a copy of parent in which the value associated with key is
|
||||
// val.
|
||||
//
|
||||
// Use context Values only for request-scoped data that transits processes and
|
||||
// APIs, not for passing optional parameters to functions.
|
||||
func WithValue(parent Context, key interface{}, val interface{}) Context { |
||||
return context.WithValue(parent, key, val) |
||||
} |
@ -0,0 +1,300 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.7
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
|
||||
// struct{}, since vars of this type must have distinct addresses.
|
||||
type emptyCtx int |
||||
|
||||
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { |
||||
return |
||||
} |
||||
|
||||
func (*emptyCtx) Done() <-chan struct{} { |
||||
return nil |
||||
} |
||||
|
||||
func (*emptyCtx) Err() error { |
||||
return nil |
||||
} |
||||
|
||||
func (*emptyCtx) Value(key interface{}) interface{} { |
||||
return nil |
||||
} |
||||
|
||||
func (e *emptyCtx) String() string { |
||||
switch e { |
||||
case background: |
||||
return "context.Background" |
||||
case todo: |
||||
return "context.TODO" |
||||
} |
||||
return "unknown empty Context" |
||||
} |
||||
|
||||
var ( |
||||
background = new(emptyCtx) |
||||
todo = new(emptyCtx) |
||||
) |
||||
|
||||
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||
var Canceled = errors.New("context canceled") |
||||
|
||||
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||
// deadline passes.
|
||||
var DeadlineExceeded = errors.New("context deadline exceeded") |
||||
|
||||
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||
// context's Done channel is closed when the returned cancel function is called
|
||||
// or when the parent context's Done channel is closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
||||
c := newCancelCtx(parent) |
||||
propagateCancel(parent, c) |
||||
return c, func() { c.cancel(true, Canceled) } |
||||
} |
||||
|
||||
// newCancelCtx returns an initialized cancelCtx.
|
||||
func newCancelCtx(parent Context) *cancelCtx { |
||||
return &cancelCtx{ |
||||
Context: parent, |
||||
done: make(chan struct{}), |
||||
} |
||||
} |
||||
|
||||
// propagateCancel arranges for child to be canceled when parent is.
|
||||
func propagateCancel(parent Context, child canceler) { |
||||
if parent.Done() == nil { |
||||
return // parent is never canceled
|
||||
} |
||||
if p, ok := parentCancelCtx(parent); ok { |
||||
p.mu.Lock() |
||||
if p.err != nil { |
||||
// parent has already been canceled
|
||||
child.cancel(false, p.err) |
||||
} else { |
||||
if p.children == nil { |
||||
p.children = make(map[canceler]bool) |
||||
} |
||||
p.children[child] = true |
||||
} |
||||
p.mu.Unlock() |
||||
} else { |
||||
go func() { |
||||
select { |
||||
case <-parent.Done(): |
||||
child.cancel(false, parent.Err()) |
||||
case <-child.Done(): |
||||
} |
||||
}() |
||||
} |
||||
} |
||||
|
||||
// parentCancelCtx follows a chain of parent references until it finds a
|
||||
// *cancelCtx. This function understands how each of the concrete types in this
|
||||
// package represents its parent.
|
||||
func parentCancelCtx(parent Context) (*cancelCtx, bool) { |
||||
for { |
||||
switch c := parent.(type) { |
||||
case *cancelCtx: |
||||
return c, true |
||||
case *timerCtx: |
||||
return c.cancelCtx, true |
||||
case *valueCtx: |
||||
parent = c.Context |
||||
default: |
||||
return nil, false |
||||
} |
||||
} |
||||
} |
||||
|
||||
// removeChild removes a context from its parent.
|
||||
func removeChild(parent Context, child canceler) { |
||||
p, ok := parentCancelCtx(parent) |
||||
if !ok { |
||||
return |
||||
} |
||||
p.mu.Lock() |
||||
if p.children != nil { |
||||
delete(p.children, child) |
||||
} |
||||
p.mu.Unlock() |
||||
} |
||||
|
||||
// A canceler is a context type that can be canceled directly. The
|
||||
// implementations are *cancelCtx and *timerCtx.
|
||||
type canceler interface { |
||||
cancel(removeFromParent bool, err error) |
||||
Done() <-chan struct{} |
||||
} |
||||
|
||||
// A cancelCtx can be canceled. When canceled, it also cancels any children
|
||||
// that implement canceler.
|
||||
type cancelCtx struct { |
||||
Context |
||||
|
||||
done chan struct{} // closed by the first cancel call.
|
||||
|
||||
mu sync.Mutex |
||||
children map[canceler]bool // set to nil by the first cancel call
|
||||
err error // set to non-nil by the first cancel call
|
||||
} |
||||
|
||||
func (c *cancelCtx) Done() <-chan struct{} { |
||||
return c.done |
||||
} |
||||
|
||||
func (c *cancelCtx) Err() error { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
return c.err |
||||
} |
||||
|
||||
func (c *cancelCtx) String() string { |
||||
return fmt.Sprintf("%v.WithCancel", c.Context) |
||||
} |
||||
|
||||
// cancel closes c.done, cancels each of c's children, and, if
|
||||
// removeFromParent is true, removes c from its parent's children.
|
||||
func (c *cancelCtx) cancel(removeFromParent bool, err error) { |
||||
if err == nil { |
||||
panic("context: internal error: missing cancel error") |
||||
} |
||||
c.mu.Lock() |
||||
if c.err != nil { |
||||
c.mu.Unlock() |
||||
return // already canceled
|
||||
} |
||||
c.err = err |
||||
close(c.done) |
||||
for child := range c.children { |
||||
// NOTE: acquiring the child's lock while holding parent's lock.
|
||||
child.cancel(false, err) |
||||
} |
||||
c.children = nil |
||||
c.mu.Unlock() |
||||
|
||||
if removeFromParent { |
||||
removeChild(c.Context, c) |
||||
} |
||||
} |
||||
|
||||
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||
// context's Done channel is closed when the deadline expires, when the returned
|
||||
// cancel function is called, or when the parent context's Done channel is
|
||||
// closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
||||
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { |
||||
// The current deadline is already sooner than the new one.
|
||||
return WithCancel(parent) |
||||
} |
||||
c := &timerCtx{ |
||||
cancelCtx: newCancelCtx(parent), |
||||
deadline: deadline, |
||||
} |
||||
propagateCancel(parent, c) |
||||
d := deadline.Sub(time.Now()) |
||||
if d <= 0 { |
||||
c.cancel(true, DeadlineExceeded) // deadline has already passed
|
||||
return c, func() { c.cancel(true, Canceled) } |
||||
} |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.err == nil { |
||||
c.timer = time.AfterFunc(d, func() { |
||||
c.cancel(true, DeadlineExceeded) |
||||
}) |
||||
} |
||||
return c, func() { c.cancel(true, Canceled) } |
||||
} |
||||
|
||||
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
|
||||
// implement Done and Err. It implements cancel by stopping its timer then
|
||||
// delegating to cancelCtx.cancel.
|
||||
type timerCtx struct { |
||||
*cancelCtx |
||||
timer *time.Timer // Under cancelCtx.mu.
|
||||
|
||||
deadline time.Time |
||||
} |
||||
|
||||
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { |
||||
return c.deadline, true |
||||
} |
||||
|
||||
func (c *timerCtx) String() string { |
||||
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) |
||||
} |
||||
|
||||
func (c *timerCtx) cancel(removeFromParent bool, err error) { |
||||
c.cancelCtx.cancel(false, err) |
||||
if removeFromParent { |
||||
// Remove this timerCtx from its parent cancelCtx's children.
|
||||
removeChild(c.cancelCtx.Context, c) |
||||
} |
||||
c.mu.Lock() |
||||
if c.timer != nil { |
||||
c.timer.Stop() |
||||
c.timer = nil |
||||
} |
||||
c.mu.Unlock() |
||||
} |
||||
|
||||
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete:
|
||||
//
|
||||
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||
// return slowOperation(ctx)
|
||||
// }
|
||||
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
||||
return WithDeadline(parent, time.Now().Add(timeout)) |
||||
} |
||||
|
||||
// WithValue returns a copy of parent in which the value associated with key is
|
||||
// val.
|
||||
//
|
||||
// Use context Values only for request-scoped data that transits processes and
|
||||
// APIs, not for passing optional parameters to functions.
|
||||
func WithValue(parent Context, key interface{}, val interface{}) Context { |
||||
return &valueCtx{parent, key, val} |
||||
} |
||||
|
||||
// A valueCtx carries a key-value pair. It implements Value for that key and
|
||||
// delegates all other calls to the embedded Context.
|
||||
type valueCtx struct { |
||||
Context |
||||
key, val interface{} |
||||
} |
||||
|
||||
func (c *valueCtx) String() string { |
||||
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) |
||||
} |
||||
|
||||
func (c *valueCtx) Value(key interface{}) interface{} { |
||||
if c.key == key { |
||||
return c.val |
||||
} |
||||
return c.Context.Value(key) |
||||
} |
Loading…
Reference in new issue