mirror of https://github.com/fatedier/frp
fatedier
8 years ago
14 changed files with 1392 additions and 0 deletions
@ -0,0 +1,22 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects) |
||||
*.o |
||||
*.a |
||||
*.so |
||||
|
||||
# Folders |
||||
_obj |
||||
_test |
||||
|
||||
# Architecture specific extensions/prefixes |
||||
*.[568vq] |
||||
[568vq].out |
||||
|
||||
*.cgo1.go |
||||
*.cgo2.c |
||||
_cgo_defun.c |
||||
_cgo_gotypes.go |
||||
_cgo_export.* |
||||
|
||||
_testmain.go |
||||
|
||||
*.exe |
@ -0,0 +1,4 @@
|
||||
language: go |
||||
go: |
||||
- 1.1 |
||||
- tip |
@ -0,0 +1,20 @@
|
||||
The MIT License (MIT) |
||||
|
||||
Copyright (c) 2014 Armon Dadgar |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of |
||||
this software and associated documentation files (the "Software"), to deal in |
||||
the Software without restriction, including without limitation the rights to |
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
||||
the Software, and to permit persons to whom the Software is furnished to do so, |
||||
subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice shall be included in all |
||||
copies or substantial portions of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@ -0,0 +1,45 @@
|
||||
go-socks5 [![Build Status](https://travis-ci.org/armon/go-socks5.png)](https://travis-ci.org/armon/go-socks5) |
||||
========= |
||||
|
||||
Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS). |
||||
SOCKS (Secure Sockets) is used to route traffic between a client and server through |
||||
an intermediate proxy layer. This can be used to bypass firewalls or NATs. |
||||
|
||||
Feature |
||||
======= |
||||
|
||||
The package has the following features: |
||||
* "No Auth" mode |
||||
* User/Password authentication |
||||
* Support for the CONNECT command |
||||
* Rules to do granular filtering of commands |
||||
* Custom DNS resolution |
||||
* Unit tests |
||||
|
||||
TODO |
||||
==== |
||||
|
||||
The package still needs the following: |
||||
* Support for the BIND command |
||||
* Support for the ASSOCIATE command |
||||
|
||||
|
||||
Example |
||||
======= |
||||
|
||||
Below is a simple example of usage |
||||
|
||||
```go |
||||
// Create a SOCKS5 server |
||||
conf := &socks5.Config{} |
||||
server, err := socks5.New(conf) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
// Create SOCKS5 proxy on localhost port 8000 |
||||
if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil { |
||||
panic(err) |
||||
} |
||||
``` |
||||
|
@ -0,0 +1,151 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
const ( |
||||
NoAuth = uint8(0) |
||||
noAcceptable = uint8(255) |
||||
UserPassAuth = uint8(2) |
||||
userAuthVersion = uint8(1) |
||||
authSuccess = uint8(0) |
||||
authFailure = uint8(1) |
||||
) |
||||
|
||||
var ( |
||||
UserAuthFailed = fmt.Errorf("User authentication failed") |
||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") |
||||
) |
||||
|
||||
// A Request encapsulates authentication state provided
|
||||
// during negotiation
|
||||
type AuthContext struct { |
||||
// Provided auth method
|
||||
Method uint8 |
||||
// Payload provided during negotiation.
|
||||
// Keys depend on the used auth method.
|
||||
// For UserPassauth contains Username
|
||||
Payload map[string]string |
||||
} |
||||
|
||||
type Authenticator interface { |
||||
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) |
||||
GetCode() uint8 |
||||
} |
||||
|
||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||
type NoAuthAuthenticator struct{} |
||||
|
||||
func (a NoAuthAuthenticator) GetCode() uint8 { |
||||
return NoAuth |
||||
} |
||||
|
||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { |
||||
_, err := writer.Write([]byte{socks5Version, NoAuth}) |
||||
return &AuthContext{NoAuth, nil}, err |
||||
} |
||||
|
||||
// UserPassAuthenticator is used to handle username/password based
|
||||
// authentication
|
||||
type UserPassAuthenticator struct { |
||||
Credentials CredentialStore |
||||
} |
||||
|
||||
func (a UserPassAuthenticator) GetCode() uint8 { |
||||
return UserPassAuth |
||||
} |
||||
|
||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { |
||||
// Tell the client to use user/pass auth
|
||||
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Get the version and username length
|
||||
header := []byte{0, 0} |
||||
if _, err := io.ReadAtLeast(reader, header, 2); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != userAuthVersion { |
||||
return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) |
||||
} |
||||
|
||||
// Get the user name
|
||||
userLen := int(header[1]) |
||||
user := make([]byte, userLen) |
||||
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Get the password length
|
||||
if _, err := reader.Read(header[:1]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Get the password
|
||||
passLen := int(header[0]) |
||||
pass := make([]byte, passLen) |
||||
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Verify the password
|
||||
if a.Credentials.Valid(string(user), string(pass)) { |
||||
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { |
||||
return nil, err |
||||
} |
||||
} else { |
||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { |
||||
return nil, err |
||||
} |
||||
return nil, UserAuthFailed |
||||
} |
||||
|
||||
// Done
|
||||
return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil |
||||
} |
||||
|
||||
// authenticate is used to handle connection authentication
|
||||
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { |
||||
// Get the methods
|
||||
methods, err := readMethods(bufConn) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("Failed to get auth methods: %v", err) |
||||
} |
||||
|
||||
// Select a usable method
|
||||
for _, method := range methods { |
||||
cator, found := s.authMethods[method] |
||||
if found { |
||||
return cator.Authenticate(bufConn, conn) |
||||
} |
||||
} |
||||
|
||||
// No usable method found
|
||||
return nil, noAcceptableAuth(conn) |
||||
} |
||||
|
||||
// noAcceptableAuth is used to handle when we have no eligible
|
||||
// authentication mechanism
|
||||
func noAcceptableAuth(conn io.Writer) error { |
||||
conn.Write([]byte{socks5Version, noAcceptable}) |
||||
return NoSupportedAuth |
||||
} |
||||
|
||||
// readMethods is used to read the number of methods
|
||||
// and proceeding auth methods
|
||||
func readMethods(r io.Reader) ([]byte, error) { |
||||
header := []byte{0} |
||||
if _, err := r.Read(header); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
numMethods := int(header[0]) |
||||
methods := make([]byte, numMethods) |
||||
_, err := io.ReadAtLeast(r, methods, numMethods) |
||||
return methods, err |
||||
} |
@ -0,0 +1,17 @@
|
||||
package socks5 |
||||
|
||||
// CredentialStore is used to support user/pass authentication
|
||||
type CredentialStore interface { |
||||
Valid(user, password string) bool |
||||
} |
||||
|
||||
// StaticCredentials enables using a map directly as a credential store
|
||||
type StaticCredentials map[string]string |
||||
|
||||
func (s StaticCredentials) Valid(user, password string) bool { |
||||
pass, ok := s[user] |
||||
if !ok { |
||||
return false |
||||
} |
||||
return password == pass |
||||
} |
@ -0,0 +1,364 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
const ( |
||||
ConnectCommand = uint8(1) |
||||
BindCommand = uint8(2) |
||||
AssociateCommand = uint8(3) |
||||
ipv4Address = uint8(1) |
||||
fqdnAddress = uint8(3) |
||||
ipv6Address = uint8(4) |
||||
) |
||||
|
||||
const ( |
||||
successReply uint8 = iota |
||||
serverFailure |
||||
ruleFailure |
||||
networkUnreachable |
||||
hostUnreachable |
||||
connectionRefused |
||||
ttlExpired |
||||
commandNotSupported |
||||
addrTypeNotSupported |
||||
) |
||||
|
||||
var ( |
||||
unrecognizedAddrType = fmt.Errorf("Unrecognized address type") |
||||
) |
||||
|
||||
// AddressRewriter is used to rewrite a destination transparently
|
||||
type AddressRewriter interface { |
||||
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) |
||||
} |
||||
|
||||
// AddrSpec is used to return the target AddrSpec
|
||||
// which may be specified as IPv4, IPv6, or a FQDN
|
||||
type AddrSpec struct { |
||||
FQDN string |
||||
IP net.IP |
||||
Port int |
||||
} |
||||
|
||||
func (a *AddrSpec) String() string { |
||||
if a.FQDN != "" { |
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) |
||||
} |
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port) |
||||
} |
||||
|
||||
// Address returns a string suitable to dial; prefer returning IP-based
|
||||
// address, fallback to FQDN
|
||||
func (a AddrSpec) Address() string { |
||||
if 0 != len(a.IP) { |
||||
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) |
||||
} |
||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) |
||||
} |
||||
|
||||
// A Request represents request received by a server
|
||||
type Request struct { |
||||
// Protocol version
|
||||
Version uint8 |
||||
// Requested command
|
||||
Command uint8 |
||||
// AuthContext provided during negotiation
|
||||
AuthContext *AuthContext |
||||
// AddrSpec of the the network that sent the request
|
||||
RemoteAddr *AddrSpec |
||||
// AddrSpec of the desired destination
|
||||
DestAddr *AddrSpec |
||||
// AddrSpec of the actual destination (might be affected by rewrite)
|
||||
realDestAddr *AddrSpec |
||||
bufConn io.Reader |
||||
} |
||||
|
||||
type conn interface { |
||||
Write([]byte) (int, error) |
||||
RemoteAddr() net.Addr |
||||
} |
||||
|
||||
// NewRequest creates a new Request from the tcp connection
|
||||
func NewRequest(bufConn io.Reader) (*Request, error) { |
||||
// Read the version byte
|
||||
header := []byte{0, 0, 0} |
||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { |
||||
return nil, fmt.Errorf("Failed to get command version: %v", err) |
||||
} |
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != socks5Version { |
||||
return nil, fmt.Errorf("Unsupported command version: %v", header[0]) |
||||
} |
||||
|
||||
// Read in the destination address
|
||||
dest, err := readAddrSpec(bufConn) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
request := &Request{ |
||||
Version: socks5Version, |
||||
Command: header[1], |
||||
DestAddr: dest, |
||||
bufConn: bufConn, |
||||
} |
||||
|
||||
return request, nil |
||||
} |
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(req *Request, conn conn) error { |
||||
ctx := context.Background() |
||||
|
||||
// Resolve the address if we have a FQDN
|
||||
dest := req.DestAddr |
||||
if dest.FQDN != "" { |
||||
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) |
||||
if err != nil { |
||||
if err := sendReply(conn, hostUnreachable, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) |
||||
} |
||||
ctx = ctx_ |
||||
dest.IP = addr |
||||
} |
||||
|
||||
// Apply any address rewrites
|
||||
req.realDestAddr = req.DestAddr |
||||
if s.config.Rewriter != nil { |
||||
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) |
||||
} |
||||
|
||||
// Switch on the command
|
||||
switch req.Command { |
||||
case ConnectCommand: |
||||
return s.handleConnect(ctx, conn, req) |
||||
case BindCommand: |
||||
return s.handleBind(ctx, conn, req) |
||||
case AssociateCommand: |
||||
return s.handleAssociate(ctx, conn, req) |
||||
default: |
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Unsupported command: %v", req.Command) |
||||
} |
||||
} |
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { |
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { |
||||
if err := sendReply(conn, ruleFailure, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) |
||||
} else { |
||||
ctx = ctx_ |
||||
} |
||||
|
||||
// Attempt to connect
|
||||
dial := s.config.Dial |
||||
if dial == nil { |
||||
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { |
||||
return net.Dial(net_, addr) |
||||
} |
||||
} |
||||
target, err := dial(ctx, "tcp", req.realDestAddr.Address()) |
||||
if err != nil { |
||||
msg := err.Error() |
||||
resp := hostUnreachable |
||||
if strings.Contains(msg, "refused") { |
||||
resp = connectionRefused |
||||
} else if strings.Contains(msg, "network is unreachable") { |
||||
resp = networkUnreachable |
||||
} |
||||
if err := sendReply(conn, resp, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) |
||||
} |
||||
defer target.Close() |
||||
|
||||
// Send success
|
||||
local := target.LocalAddr().(*net.TCPAddr) |
||||
bind := AddrSpec{IP: local.IP, Port: local.Port} |
||||
if err := sendReply(conn, successReply, &bind); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
|
||||
// Start proxying
|
||||
errCh := make(chan error, 2) |
||||
go proxy(target, req.bufConn, errCh) |
||||
go proxy(conn, target, errCh) |
||||
|
||||
// Wait
|
||||
for i := 0; i < 2; i++ { |
||||
e := <-errCh |
||||
if e != nil { |
||||
// return from this function closes target (and conn).
|
||||
return e |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { |
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { |
||||
if err := sendReply(conn, ruleFailure, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) |
||||
} else { |
||||
ctx = ctx_ |
||||
} |
||||
|
||||
// TODO: Support bind
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { |
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { |
||||
if err := sendReply(conn, ruleFailure, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) |
||||
} else { |
||||
ctx = ctx_ |
||||
} |
||||
|
||||
// TODO: Support associate
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// readAddrSpec is used to read AddrSpec.
|
||||
// Expects an address type byte, follwed by the address and port
|
||||
func readAddrSpec(r io.Reader) (*AddrSpec, error) { |
||||
d := &AddrSpec{} |
||||
|
||||
// Get the address type
|
||||
addrType := []byte{0} |
||||
if _, err := r.Read(addrType); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Handle on a per type basis
|
||||
switch addrType[0] { |
||||
case ipv4Address: |
||||
addr := make([]byte, 4) |
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { |
||||
return nil, err |
||||
} |
||||
d.IP = net.IP(addr) |
||||
|
||||
case ipv6Address: |
||||
addr := make([]byte, 16) |
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { |
||||
return nil, err |
||||
} |
||||
d.IP = net.IP(addr) |
||||
|
||||
case fqdnAddress: |
||||
if _, err := r.Read(addrType); err != nil { |
||||
return nil, err |
||||
} |
||||
addrLen := int(addrType[0]) |
||||
fqdn := make([]byte, addrLen) |
||||
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { |
||||
return nil, err |
||||
} |
||||
d.FQDN = string(fqdn) |
||||
|
||||
default: |
||||
return nil, unrecognizedAddrType |
||||
} |
||||
|
||||
// Read the port
|
||||
port := []byte{0, 0} |
||||
if _, err := io.ReadAtLeast(r, port, 2); err != nil { |
||||
return nil, err |
||||
} |
||||
d.Port = (int(port[0]) << 8) | int(port[1]) |
||||
|
||||
return d, nil |
||||
} |
||||
|
||||
// sendReply is used to send a reply message
|
||||
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { |
||||
// Format the address
|
||||
var addrType uint8 |
||||
var addrBody []byte |
||||
var addrPort uint16 |
||||
switch { |
||||
case addr == nil: |
||||
addrType = ipv4Address |
||||
addrBody = []byte{0, 0, 0, 0} |
||||
addrPort = 0 |
||||
|
||||
case addr.FQDN != "": |
||||
addrType = fqdnAddress |
||||
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) |
||||
addrPort = uint16(addr.Port) |
||||
|
||||
case addr.IP.To4() != nil: |
||||
addrType = ipv4Address |
||||
addrBody = []byte(addr.IP.To4()) |
||||
addrPort = uint16(addr.Port) |
||||
|
||||
case addr.IP.To16() != nil: |
||||
addrType = ipv6Address |
||||
addrBody = []byte(addr.IP.To16()) |
||||
addrPort = uint16(addr.Port) |
||||
|
||||
default: |
||||
return fmt.Errorf("Failed to format address: %v", addr) |
||||
} |
||||
|
||||
// Format the message
|
||||
msg := make([]byte, 6+len(addrBody)) |
||||
msg[0] = socks5Version |
||||
msg[1] = resp |
||||
msg[2] = 0 // Reserved
|
||||
msg[3] = addrType |
||||
copy(msg[4:], addrBody) |
||||
msg[4+len(addrBody)] = byte(addrPort >> 8) |
||||
msg[4+len(addrBody)+1] = byte(addrPort & 0xff) |
||||
|
||||
// Send the message
|
||||
_, err := w.Write(msg) |
||||
return err |
||||
} |
||||
|
||||
type closeWriter interface { |
||||
CloseWrite() error |
||||
} |
||||
|
||||
// proxy is used to suffle data from src to destination, and sends errors
|
||||
// down a dedicated channel
|
||||
func proxy(dst io.Writer, src io.Reader, errCh chan error) { |
||||
_, err := io.Copy(dst, src) |
||||
if tcpConn, ok := dst.(closeWriter); ok { |
||||
tcpConn.CloseWrite() |
||||
} |
||||
errCh <- err |
||||
} |
@ -0,0 +1,23 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"net" |
||||
|
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
// NameResolver is used to implement custom name resolution
|
||||
type NameResolver interface { |
||||
Resolve(ctx context.Context, name string) (context.Context, net.IP, error) |
||||
} |
||||
|
||||
// DNSResolver uses the system DNS to resolve host names
|
||||
type DNSResolver struct{} |
||||
|
||||
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { |
||||
addr, err := net.ResolveIPAddr("ip", name) |
||||
if err != nil { |
||||
return ctx, nil, err |
||||
} |
||||
return ctx, addr.IP, err |
||||
} |
@ -0,0 +1,41 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
// RuleSet is used to provide custom rules to allow or prohibit actions
|
||||
type RuleSet interface { |
||||
Allow(ctx context.Context, req *Request) (context.Context, bool) |
||||
} |
||||
|
||||
// PermitAll returns a RuleSet which allows all types of connections
|
||||
func PermitAll() RuleSet { |
||||
return &PermitCommand{true, true, true} |
||||
} |
||||
|
||||
// PermitNone returns a RuleSet which disallows all types of connections
|
||||
func PermitNone() RuleSet { |
||||
return &PermitCommand{false, false, false} |
||||
} |
||||
|
||||
// PermitCommand is an implementation of the RuleSet which
|
||||
// enables filtering supported commands
|
||||
type PermitCommand struct { |
||||
EnableConnect bool |
||||
EnableBind bool |
||||
EnableAssociate bool |
||||
} |
||||
|
||||
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { |
||||
switch req.Command { |
||||
case ConnectCommand: |
||||
return ctx, p.EnableConnect |
||||
case BindCommand: |
||||
return ctx, p.EnableBind |
||||
case AssociateCommand: |
||||
return ctx, p.EnableAssociate |
||||
} |
||||
|
||||
return ctx, false |
||||
} |
@ -0,0 +1,169 @@
|
||||
package socks5 |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"log" |
||||
"net" |
||||
"os" |
||||
|
||||
"golang.org/x/net/context" |
||||
) |
||||
|
||||
const ( |
||||
socks5Version = uint8(5) |
||||
) |
||||
|
||||
// Config is used to setup and configure a Server
|
||||
type Config struct { |
||||
// AuthMethods can be provided to implement custom authentication
|
||||
// By default, "auth-less" mode is enabled.
|
||||
// For password-based auth use UserPassAuthenticator.
|
||||
AuthMethods []Authenticator |
||||
|
||||
// If provided, username/password authentication is enabled,
|
||||
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
|
||||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||
Credentials CredentialStore |
||||
|
||||
// Resolver can be provided to do custom name resolution.
|
||||
// Defaults to DNSResolver if not provided.
|
||||
Resolver NameResolver |
||||
|
||||
// Rules is provided to enable custom logic around permitting
|
||||
// various commands. If not provided, PermitAll is used.
|
||||
Rules RuleSet |
||||
|
||||
// Rewriter can be used to transparently rewrite addresses.
|
||||
// This is invoked before the RuleSet is invoked.
|
||||
// Defaults to NoRewrite.
|
||||
Rewriter AddressRewriter |
||||
|
||||
// BindIP is used for bind or udp associate
|
||||
BindIP net.IP |
||||
|
||||
// Logger can be used to provide a custom log target.
|
||||
// Defaults to stdout.
|
||||
Logger *log.Logger |
||||
|
||||
// Optional function for dialing out
|
||||
Dial func(ctx context.Context, network, addr string) (net.Conn, error) |
||||
} |
||||
|
||||
// Server is reponsible for accepting connections and handling
|
||||
// the details of the SOCKS5 protocol
|
||||
type Server struct { |
||||
config *Config |
||||
authMethods map[uint8]Authenticator |
||||
} |
||||
|
||||
// New creates a new Server and potentially returns an error
|
||||
func New(conf *Config) (*Server, error) { |
||||
// Ensure we have at least one authentication method enabled
|
||||
if len(conf.AuthMethods) == 0 { |
||||
if conf.Credentials != nil { |
||||
conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} |
||||
} else { |
||||
conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} |
||||
} |
||||
} |
||||
|
||||
// Ensure we have a DNS resolver
|
||||
if conf.Resolver == nil { |
||||
conf.Resolver = DNSResolver{} |
||||
} |
||||
|
||||
// Ensure we have a rule set
|
||||
if conf.Rules == nil { |
||||
conf.Rules = PermitAll() |
||||
} |
||||
|
||||
// Ensure we have a log target
|
||||
if conf.Logger == nil { |
||||
conf.Logger = log.New(os.Stdout, "", log.LstdFlags) |
||||
} |
||||
|
||||
server := &Server{ |
||||
config: conf, |
||||
} |
||||
|
||||
server.authMethods = make(map[uint8]Authenticator) |
||||
|
||||
for _, a := range conf.AuthMethods { |
||||
server.authMethods[a.GetCode()] = a |
||||
} |
||||
|
||||
return server, nil |
||||
} |
||||
|
||||
// ListenAndServe is used to create a listener and serve on it
|
||||
func (s *Server) ListenAndServe(network, addr string) error { |
||||
l, err := net.Listen(network, addr) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.Serve(l) |
||||
} |
||||
|
||||
// Serve is used to serve connections from a listener
|
||||
func (s *Server) Serve(l net.Listener) error { |
||||
for { |
||||
conn, err := l.Accept() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
go s.ServeConn(conn) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ServeConn is used to serve a single connection.
|
||||
func (s *Server) ServeConn(conn net.Conn) error { |
||||
defer conn.Close() |
||||
bufConn := bufio.NewReader(conn) |
||||
|
||||
// Read the version byte
|
||||
version := []byte{0} |
||||
if _, err := bufConn.Read(version); err != nil { |
||||
s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err) |
||||
return err |
||||
} |
||||
|
||||
// Ensure we are compatible
|
||||
if version[0] != socks5Version { |
||||
err := fmt.Errorf("Unsupported SOCKS version: %v", version) |
||||
s.config.Logger.Printf("[ERR] socks: %v", err) |
||||
return err |
||||
} |
||||
|
||||
// Authenticate the connection
|
||||
authContext, err := s.authenticate(conn, bufConn) |
||||
if err != nil { |
||||
err = fmt.Errorf("Failed to authenticate: %v", err) |
||||
s.config.Logger.Printf("[ERR] socks: %v", err) |
||||
return err |
||||
} |
||||
|
||||
request, err := NewRequest(bufConn) |
||||
if err != nil { |
||||
if err == unrecognizedAddrType { |
||||
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { |
||||
return fmt.Errorf("Failed to send reply: %v", err) |
||||
} |
||||
} |
||||
return fmt.Errorf("Failed to read destination address: %v", err) |
||||
} |
||||
request.AuthContext = authContext |
||||
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { |
||||
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} |
||||
} |
||||
|
||||
// Process the client request
|
||||
if err := s.handleRequest(request, conn); err != nil { |
||||
err = fmt.Errorf("Failed to handle request: %v", err) |
||||
s.config.Logger.Printf("[ERR] socks: %v", err) |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,156 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package context defines the Context type, which carries deadlines,
|
||||
// cancelation signals, and other request-scoped values across API boundaries
|
||||
// and between processes.
|
||||
//
|
||||
// Incoming requests to a server should create a Context, and outgoing calls to
|
||||
// servers should accept a Context. The chain of function calls between must
|
||||
// propagate the Context, optionally replacing it with a modified copy created
|
||||
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
|
||||
//
|
||||
// Programs that use Contexts should follow these rules to keep interfaces
|
||||
// consistent across packages and enable static analysis tools to check context
|
||||
// propagation:
|
||||
//
|
||||
// Do not store Contexts inside a struct type; instead, pass a Context
|
||||
// explicitly to each function that needs it. The Context should be the first
|
||||
// parameter, typically named ctx:
|
||||
//
|
||||
// func DoSomething(ctx context.Context, arg Arg) error {
|
||||
// // ... use ctx ...
|
||||
// }
|
||||
//
|
||||
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
|
||||
// if you are unsure about which Context to use.
|
||||
//
|
||||
// Use context Values only for request-scoped data that transits processes and
|
||||
// APIs, not for passing optional parameters to functions.
|
||||
//
|
||||
// The same Context may be passed to functions running in different goroutines;
|
||||
// Contexts are safe for simultaneous use by multiple goroutines.
|
||||
//
|
||||
// See http://blog.golang.org/context for example code for a server that uses
|
||||
// Contexts.
|
||||
package context |
||||
|
||||
import "time" |
||||
|
||||
// A Context carries a deadline, a cancelation signal, and other values across
|
||||
// API boundaries.
|
||||
//
|
||||
// Context's methods may be called by multiple goroutines simultaneously.
|
||||
type Context interface { |
||||
// Deadline returns the time when work done on behalf of this context
|
||||
// should be canceled. Deadline returns ok==false when no deadline is
|
||||
// set. Successive calls to Deadline return the same results.
|
||||
Deadline() (deadline time.Time, ok bool) |
||||
|
||||
// Done returns a channel that's closed when work done on behalf of this
|
||||
// context should be canceled. Done may return nil if this context can
|
||||
// never be canceled. Successive calls to Done return the same value.
|
||||
//
|
||||
// WithCancel arranges for Done to be closed when cancel is called;
|
||||
// WithDeadline arranges for Done to be closed when the deadline
|
||||
// expires; WithTimeout arranges for Done to be closed when the timeout
|
||||
// elapses.
|
||||
//
|
||||
// Done is provided for use in select statements:
|
||||
//
|
||||
// // Stream generates values with DoSomething and sends them to out
|
||||
// // until DoSomething returns an error or ctx.Done is closed.
|
||||
// func Stream(ctx context.Context, out chan<- Value) error {
|
||||
// for {
|
||||
// v, err := DoSomething(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return ctx.Err()
|
||||
// case out <- v:
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// See http://blog.golang.org/pipelines for more examples of how to use
|
||||
// a Done channel for cancelation.
|
||||
Done() <-chan struct{} |
||||
|
||||
// Err returns a non-nil error value after Done is closed. Err returns
|
||||
// Canceled if the context was canceled or DeadlineExceeded if the
|
||||
// context's deadline passed. No other values for Err are defined.
|
||||
// After Done is closed, successive calls to Err return the same value.
|
||||
Err() error |
||||
|
||||
// Value returns the value associated with this context for key, or nil
|
||||
// if no value is associated with key. Successive calls to Value with
|
||||
// the same key returns the same result.
|
||||
//
|
||||
// Use context values only for request-scoped data that transits
|
||||
// processes and API boundaries, not for passing optional parameters to
|
||||
// functions.
|
||||
//
|
||||
// A key identifies a specific value in a Context. Functions that wish
|
||||
// to store values in Context typically allocate a key in a global
|
||||
// variable then use that key as the argument to context.WithValue and
|
||||
// Context.Value. A key can be any type that supports equality;
|
||||
// packages should define keys as an unexported type to avoid
|
||||
// collisions.
|
||||
//
|
||||
// Packages that define a Context key should provide type-safe accessors
|
||||
// for the values stores using that key:
|
||||
//
|
||||
// // Package user defines a User type that's stored in Contexts.
|
||||
// package user
|
||||
//
|
||||
// import "golang.org/x/net/context"
|
||||
//
|
||||
// // User is the type of value stored in the Contexts.
|
||||
// type User struct {...}
|
||||
//
|
||||
// // key is an unexported type for keys defined in this package.
|
||||
// // This prevents collisions with keys defined in other packages.
|
||||
// type key int
|
||||
//
|
||||
// // userKey is the key for user.User values in Contexts. It is
|
||||
// // unexported; clients use user.NewContext and user.FromContext
|
||||
// // instead of using this key directly.
|
||||
// var userKey key = 0
|
||||
//
|
||||
// // NewContext returns a new Context that carries value u.
|
||||
// func NewContext(ctx context.Context, u *User) context.Context {
|
||||
// return context.WithValue(ctx, userKey, u)
|
||||
// }
|
||||
//
|
||||
// // FromContext returns the User value stored in ctx, if any.
|
||||
// func FromContext(ctx context.Context) (*User, bool) {
|
||||
// u, ok := ctx.Value(userKey).(*User)
|
||||
// return u, ok
|
||||
// }
|
||||
Value(key interface{}) interface{} |
||||
} |
||||
|
||||
// Background returns a non-nil, empty Context. It is never canceled, has no
|
||||
// values, and has no deadline. It is typically used by the main function,
|
||||
// initialization, and tests, and as the top-level Context for incoming
|
||||
// requests.
|
||||
func Background() Context { |
||||
return background |
||||
} |
||||
|
||||
// TODO returns a non-nil, empty Context. Code should use context.TODO when
|
||||
// it's unclear which Context to use or it is not yet available (because the
|
||||
// surrounding function has not yet been extended to accept a Context
|
||||
// parameter). TODO is recognized by static analysis tools that determine
|
||||
// whether Contexts are propagated correctly in a program.
|
||||
func TODO() Context { |
||||
return todo |
||||
} |
||||
|
||||
// A CancelFunc tells an operation to abandon its work.
|
||||
// A CancelFunc does not wait for the work to stop.
|
||||
// After the first call, subsequent calls to a CancelFunc do nothing.
|
||||
type CancelFunc func() |
@ -0,0 +1,72 @@
|
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.7
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"context" // standard library's context, as of Go 1.7
|
||||
"time" |
||||
) |
||||
|
||||
var ( |
||||
todo = context.TODO() |
||||
background = context.Background() |
||||
) |
||||
|
||||
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||
var Canceled = context.Canceled |
||||
|
||||
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||
// deadline passes.
|
||||
var DeadlineExceeded = context.DeadlineExceeded |
||||
|
||||
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||
// context's Done channel is closed when the returned cancel function is called
|
||||
// or when the parent context's Done channel is closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
||||
ctx, f := context.WithCancel(parent) |
||||
return ctx, CancelFunc(f) |
||||
} |
||||
|
||||
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||
// context's Done channel is closed when the deadline expires, when the returned
|
||||
// cancel function is called, or when the parent context's Done channel is
|
||||
// closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
||||
ctx, f := context.WithDeadline(parent, deadline) |
||||
return ctx, CancelFunc(f) |
||||
} |
||||
|
||||
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete:
|
||||
//
|
||||
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||
// return slowOperation(ctx)
|
||||
// }
|
||||
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
||||
return WithDeadline(parent, time.Now().Add(timeout)) |
||||
} |
||||
|
||||
// WithValue returns a copy of parent in which the value associated with key is
|
||||
// val.
|
||||
//
|
||||
// Use context Values only for request-scoped data that transits processes and
|
||||
// APIs, not for passing optional parameters to functions.
|
||||
func WithValue(parent Context, key interface{}, val interface{}) Context { |
||||
return context.WithValue(parent, key, val) |
||||
} |
@ -0,0 +1,300 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.7
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
|
||||
// struct{}, since vars of this type must have distinct addresses.
|
||||
type emptyCtx int |
||||
|
||||
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { |
||||
return |
||||
} |
||||
|
||||
func (*emptyCtx) Done() <-chan struct{} { |
||||
return nil |
||||
} |
||||
|
||||
func (*emptyCtx) Err() error { |
||||
return nil |
||||
} |
||||
|
||||
func (*emptyCtx) Value(key interface{}) interface{} { |
||||
return nil |
||||
} |
||||
|
||||
func (e *emptyCtx) String() string { |
||||
switch e { |
||||
case background: |
||||
return "context.Background" |
||||
case todo: |
||||
return "context.TODO" |
||||
} |
||||
return "unknown empty Context" |
||||
} |
||||
|
||||
var ( |
||||
background = new(emptyCtx) |
||||
todo = new(emptyCtx) |
||||
) |
||||
|
||||
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||
var Canceled = errors.New("context canceled") |
||||
|
||||
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||
// deadline passes.
|
||||
var DeadlineExceeded = errors.New("context deadline exceeded") |
||||
|
||||
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||
// context's Done channel is closed when the returned cancel function is called
|
||||
// or when the parent context's Done channel is closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
||||
c := newCancelCtx(parent) |
||||
propagateCancel(parent, c) |
||||
return c, func() { c.cancel(true, Canceled) } |
||||
} |
||||
|
||||
// newCancelCtx returns an initialized cancelCtx.
|
||||
func newCancelCtx(parent Context) *cancelCtx { |
||||
return &cancelCtx{ |
||||
Context: parent, |
||||
done: make(chan struct{}), |
||||
} |
||||
} |
||||
|
||||
// propagateCancel arranges for child to be canceled when parent is.
|
||||
func propagateCancel(parent Context, child canceler) { |
||||
if parent.Done() == nil { |
||||
return // parent is never canceled
|
||||
} |
||||
if p, ok := parentCancelCtx(parent); ok { |
||||
p.mu.Lock() |
||||
if p.err != nil { |
||||
// parent has already been canceled
|
||||
child.cancel(false, p.err) |
||||
} else { |
||||
if p.children == nil { |
||||
p.children = make(map[canceler]bool) |
||||
} |
||||
p.children[child] = true |
||||
} |
||||
p.mu.Unlock() |
||||
} else { |
||||
go func() { |
||||
select { |
||||
case <-parent.Done(): |
||||
child.cancel(false, parent.Err()) |
||||
case <-child.Done(): |
||||
} |
||||
}() |
||||
} |
||||
} |
||||
|
||||
// parentCancelCtx follows a chain of parent references until it finds a
|
||||
// *cancelCtx. This function understands how each of the concrete types in this
|
||||
// package represents its parent.
|
||||
func parentCancelCtx(parent Context) (*cancelCtx, bool) { |
||||
for { |
||||
switch c := parent.(type) { |
||||
case *cancelCtx: |
||||
return c, true |
||||
case *timerCtx: |
||||
return c.cancelCtx, true |
||||
case *valueCtx: |
||||
parent = c.Context |
||||
default: |
||||
return nil, false |
||||
} |
||||
} |
||||
} |
||||
|
||||
// removeChild removes a context from its parent.
|
||||
func removeChild(parent Context, child canceler) { |
||||
p, ok := parentCancelCtx(parent) |
||||
if !ok { |
||||
return |
||||
} |
||||
p.mu.Lock() |
||||
if p.children != nil { |
||||
delete(p.children, child) |
||||
} |
||||
p.mu.Unlock() |
||||
} |
||||
|
||||
// A canceler is a context type that can be canceled directly. The
|
||||
// implementations are *cancelCtx and *timerCtx.
|
||||
type canceler interface { |
||||
cancel(removeFromParent bool, err error) |
||||
Done() <-chan struct{} |
||||
} |
||||
|
||||
// A cancelCtx can be canceled. When canceled, it also cancels any children
|
||||
// that implement canceler.
|
||||
type cancelCtx struct { |
||||
Context |
||||
|
||||
done chan struct{} // closed by the first cancel call.
|
||||
|
||||
mu sync.Mutex |
||||
children map[canceler]bool // set to nil by the first cancel call
|
||||
err error // set to non-nil by the first cancel call
|
||||
} |
||||
|
||||
func (c *cancelCtx) Done() <-chan struct{} { |
||||
return c.done |
||||
} |
||||
|
||||
func (c *cancelCtx) Err() error { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
return c.err |
||||
} |
||||
|
||||
func (c *cancelCtx) String() string { |
||||
return fmt.Sprintf("%v.WithCancel", c.Context) |
||||
} |
||||
|
||||
// cancel closes c.done, cancels each of c's children, and, if
|
||||
// removeFromParent is true, removes c from its parent's children.
|
||||
func (c *cancelCtx) cancel(removeFromParent bool, err error) { |
||||
if err == nil { |
||||
panic("context: internal error: missing cancel error") |
||||
} |
||||
c.mu.Lock() |
||||
if c.err != nil { |
||||
c.mu.Unlock() |
||||
return // already canceled
|
||||
} |
||||
c.err = err |
||||
close(c.done) |
||||
for child := range c.children { |
||||
// NOTE: acquiring the child's lock while holding parent's lock.
|
||||
child.cancel(false, err) |
||||
} |
||||
c.children = nil |
||||
c.mu.Unlock() |
||||
|
||||
if removeFromParent { |
||||
removeChild(c.Context, c) |
||||
} |
||||
} |
||||
|
||||
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||
// context's Done channel is closed when the deadline expires, when the returned
|
||||
// cancel function is called, or when the parent context's Done channel is
|
||||
// closed, whichever happens first.
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete.
|
||||
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
||||
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { |
||||
// The current deadline is already sooner than the new one.
|
||||
return WithCancel(parent) |
||||
} |
||||
c := &timerCtx{ |
||||
cancelCtx: newCancelCtx(parent), |
||||
deadline: deadline, |
||||
} |
||||
propagateCancel(parent, c) |
||||
d := deadline.Sub(time.Now()) |
||||
if d <= 0 { |
||||
c.cancel(true, DeadlineExceeded) // deadline has already passed
|
||||
return c, func() { c.cancel(true, Canceled) } |
||||
} |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.err == nil { |
||||
c.timer = time.AfterFunc(d, func() { |
||||
c.cancel(true, DeadlineExceeded) |
||||
}) |
||||
} |
||||
return c, func() { c.cancel(true, Canceled) } |
||||
} |
||||
|
||||
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
|
||||
// implement Done and Err. It implements cancel by stopping its timer then
|
||||
// delegating to cancelCtx.cancel.
|
||||
type timerCtx struct { |
||||
*cancelCtx |
||||
timer *time.Timer // Under cancelCtx.mu.
|
||||
|
||||
deadline time.Time |
||||
} |
||||
|
||||
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { |
||||
return c.deadline, true |
||||
} |
||||
|
||||
func (c *timerCtx) String() string { |
||||
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) |
||||
} |
||||
|
||||
func (c *timerCtx) cancel(removeFromParent bool, err error) { |
||||
c.cancelCtx.cancel(false, err) |
||||
if removeFromParent { |
||||
// Remove this timerCtx from its parent cancelCtx's children.
|
||||
removeChild(c.cancelCtx.Context, c) |
||||
} |
||||
c.mu.Lock() |
||||
if c.timer != nil { |
||||
c.timer.Stop() |
||||
c.timer = nil |
||||
} |
||||
c.mu.Unlock() |
||||
} |
||||
|
||||
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||
//
|
||||
// Canceling this context releases resources associated with it, so code should
|
||||
// call cancel as soon as the operations running in this Context complete:
|
||||
//
|
||||
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||
// return slowOperation(ctx)
|
||||
// }
|
||||
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
||||
return WithDeadline(parent, time.Now().Add(timeout)) |
||||
} |
||||
|
||||
// WithValue returns a copy of parent in which the value associated with key is
|
||||
// val.
|
||||
//
|
||||
// Use context Values only for request-scoped data that transits processes and
|
||||
// APIs, not for passing optional parameters to functions.
|
||||
func WithValue(parent Context, key interface{}, val interface{}) Context { |
||||
return &valueCtx{parent, key, val} |
||||
} |
||||
|
||||
// A valueCtx carries a key-value pair. It implements Value for that key and
|
||||
// delegates all other calls to the embedded Context.
|
||||
type valueCtx struct { |
||||
Context |
||||
key, val interface{} |
||||
} |
||||
|
||||
func (c *valueCtx) String() string { |
||||
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) |
||||
} |
||||
|
||||
func (c *valueCtx) Value(key interface{}) interface{} { |
||||
if c.key == key { |
||||
return c.val |
||||
} |
||||
return c.Context.Value(key) |
||||
} |
Loading…
Reference in new issue