Browse Source

WireGuard Inbound (User-space WireGuard server) (#2477)

* feat: wireguard inbound

* feat(command): generate wireguard compatible keypair

* feat(wireguard): connection idle timeout

* fix(wireguard): close endpoint after connection closed

* fix(wireguard): resolve conflicts

* feat(wireguard): set cubic as default cc algorithm in gVisor TUN

* chore(wireguard): resolve conflict

* chore(wireguard): remove redurant code

* chore(wireguard): remove redurant code

* feat: rework server for gvisor tun

* feat: keep user-space tun as an option

* fix: exclude android from native tun build

* feat: auto kernel tun

* fix: build

* fix: regulate function name & fix test
pull/2734/head
hax0r31337 1 year ago committed by GitHub
parent
commit
0ac7da2fc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      go.mod
  2. 3
      go.sum
  3. 61
      infra/conf/wireguard.go
  4. 7
      infra/conf/wireguard_test.go
  5. 3
      infra/conf/xray.go
  6. 15
      main/commands/all/x25519.go
  7. 170
      proxy/wireguard/bind.go
  8. 255
      proxy/wireguard/client.go
  9. 7
      proxy/wireguard/config.go
  10. 58
      proxy/wireguard/config.pb.go
  11. 4
      proxy/wireguard/config.proto
  12. 230
      proxy/wireguard/gvisortun/tun.go
  13. 181
      proxy/wireguard/server.go
  14. 100
      proxy/wireguard/tun.go
  15. 38
      proxy/wireguard/tun_default.go
  16. 16
      proxy/wireguard/tun_linux.go
  17. 333
      proxy/wireguard/wireguard.go

4
go.mod

@ -27,6 +27,7 @@ require (
golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb
google.golang.org/grpc v1.59.0
google.golang.org/protobuf v1.31.0
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b
h12.io/socks v1.0.3
lukechampine.com/blake3 v1.2.1
)
@ -48,7 +49,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
go.uber.org/mock v0.3.0 // indirect
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
golang.org/x/mod v0.14.0 // indirect
@ -59,5 +60,4 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b // indirect
)

3
go.sum

@ -168,9 +168,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U=
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI=
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

61
infra/conf/wireguard.go

@ -13,7 +13,7 @@ type WireGuardPeerConfig struct {
PublicKey string `json:"publicKey"`
PreSharedKey string `json:"preSharedKey"`
Endpoint string `json:"endpoint"`
KeepAlive int `json:"keepAlive"`
KeepAlive uint32 `json:"keepAlive"`
AllowedIPs []string `json:"allowedIPs,omitempty"`
}
@ -21,23 +21,23 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
var err error
config := new(wireguard.PeerConfig)
if c.PublicKey != "" {
config.PublicKey, err = parseWireGuardKey(c.PublicKey)
if err != nil {
return nil, err
}
}
if c.PreSharedKey != "" {
config.PreSharedKey, err = parseWireGuardKey(c.PreSharedKey)
if err != nil {
return nil, err
}
} else {
config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000"
}
config.Endpoint = c.Endpoint
// default 0
config.KeepAlive = int32(c.KeepAlive)
config.KeepAlive = c.KeepAlive
if c.AllowedIPs == nil {
config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
} else {
@ -48,11 +48,14 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
}
type WireGuardConfig struct {
IsClient bool `json:""`
KernelMode *bool `json:"kernelMode"`
SecretKey string `json:"secretKey"`
Address []string `json:"address"`
Peers []*WireGuardPeerConfig `json:"peers"`
MTU int `json:"mtu"`
NumWorkers int `json:"workers"`
MTU int32 `json:"mtu"`
NumWorkers int32 `json:"workers"`
Reserved []byte `json:"reserved"`
DomainStrategy string `json:"domainStrategy"`
}
@ -87,11 +90,11 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
if c.MTU == 0 {
config.Mtu = 1420
} else {
config.Mtu = int32(c.MTU)
config.Mtu = c.MTU
}
// these a fallback code exists in github.com/nanoda0523/wireguard-go code,
// these a fallback code exists in wireguard-go code,
// we don't need to process fallback manually
config.NumWorkers = int32(c.NumWorkers)
config.NumWorkers = c.NumWorkers
if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
return nil, newError(`"reserved" should be empty or 3 bytes`)
@ -113,22 +116,42 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
return nil, newError("unsupported domain strategy: ", c.DomainStrategy)
}
config.IsClient = c.IsClient
if c.KernelMode != nil {
config.KernelMode = *c.KernelMode
if config.KernelMode && !wireguard.KernelTunSupported() {
newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog()
}
} else {
config.KernelMode = wireguard.KernelTunSupported()
if config.KernelMode {
newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog()
}
}
return config, nil
}
func parseWireGuardKey(str string) (string, error) {
if len(str) != 64 {
// may in base64 form
dat, err := base64.StdEncoding.DecodeString(str)
if err != nil {
return "", err
var err error
if len(str)%2 == 0 {
_, err = hex.DecodeString(str)
if err == nil {
return str, nil
}
if len(dat) != 32 {
return "", newError("key should be 32 bytes: " + str)
}
return hex.EncodeToString(dat), err
var dat []byte
str = strings.TrimSuffix(str, "=")
if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') {
dat, err = base64.RawStdEncoding.DecodeString(str)
} else {
// already hex form
return str, nil
dat, err = base64.RawURLEncoding.DecodeString(str)
}
if err == nil {
return hex.EncodeToString(dat), nil
}
return "", newError("failed to deserialize key").Base(err)
}

7
infra/conf/wireguard_test.go

@ -7,7 +7,7 @@ import (
"github.com/xtls/xray-core/proxy/wireguard"
)
func TestWireGuardOutbound(t *testing.T) {
func TestWireGuardConfig(t *testing.T) {
creator := func() Buildable {
return new(WireGuardConfig)
}
@ -25,7 +25,8 @@ func TestWireGuardOutbound(t *testing.T) {
],
"mtu": 1300,
"workers": 2,
"domainStrategy": "ForceIPv6v4"
"domainStrategy": "ForceIPv6v4",
"kernelMode": false
}`,
Parser: loadJSON(creator),
Output: &wireguard.DeviceConfig{
@ -36,7 +37,6 @@ func TestWireGuardOutbound(t *testing.T) {
{
// also can read from hex form directly
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000",
Endpoint: "127.0.0.1:1234",
KeepAlive: 0,
AllowedIps: []string{"0.0.0.0/0", "::0/0"},
@ -45,6 +45,7 @@ func TestWireGuardOutbound(t *testing.T) {
Mtu: 1300,
NumWorkers: 2,
DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
KernelMode: false,
},
},
})

3
infra/conf/xray.go

@ -24,6 +24,7 @@ var (
"vless": func() interface{} { return new(VLessInboundConfig) },
"vmess": func() interface{} { return new(VMessInboundConfig) },
"trojan": func() interface{} { return new(TrojanServerConfig) },
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: false} },
}, "protocol", "settings")
outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
@ -37,7 +38,7 @@ var (
"vmess": func() interface{} { return new(VMessOutboundConfig) },
"trojan": func() interface{} { return new(TrojanClientConfig) },
"dns": func() interface{} { return new(DNSOutboundConfig) },
"wireguard": func() interface{} { return new(WireGuardConfig) },
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: true} },
}, "protocol", "settings")
ctllog = log.New(os.Stderr, "xctl> ", 0)

15
main/commands/all/x25519.go

@ -10,7 +10,7 @@ import (
)
var cmdX25519 = &base.Command{
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"]`,
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"] [--std-encoding]`,
Short: `Generate key pair for x25519 key exchange`,
Long: `
Generate key pair for x25519 key exchange.
@ -18,6 +18,7 @@ Generate key pair for x25519 key exchange.
Random: {{.Exec}} x25519
From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)"
For Std Encoding: {{.Exec}} x25519 --std-encoding
`,
}
@ -26,12 +27,14 @@ func init() {
}
var input_base64 = cmdX25519.Flag.String("i", "", "")
var input_stdEncoding = cmdX25519.Flag.Bool("std-encoding", false, "")
func executeX25519(cmd *base.Command, args []string) {
var output string
var err error
var privateKey []byte
var publicKey []byte
var encoding *base64.Encoding
if len(*input_base64) > 0 {
privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64)
if err != nil {
@ -63,9 +66,15 @@ func executeX25519(cmd *base.Command, args []string) {
goto out
}
if *input_stdEncoding {
encoding = base64.StdEncoding
} else {
encoding = base64.RawURLEncoding
}
output = fmt.Sprintf("Private key: %v\nPublic key: %v",
base64.RawURLEncoding.EncodeToString(privateKey),
base64.RawURLEncoding.EncodeToString(publicKey))
encoding.EncodeToString(privateKey),
encoding.EncodeToString(publicKey))
out:
fmt.Println(output)
}

170
proxy/wireguard/bind.go

@ -27,48 +27,45 @@ type netReadInfo struct {
err error
}
type netBindClient struct {
workers int
dialer internet.Dialer
// reduce duplicated code
type netBind struct {
dns dns.Client
dnsOption dns.IPOption
reserved []byte
workers int
readQueue chan *netReadInfo
}
func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
ipStr, port, _, err := splitAddrPort(s)
// SetMark implements conn.Bind
func (bind *netBind) SetMark(mark uint32) error {
return nil
}
// ParseEndpoint implements conn.Bind
func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
ipStr, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
portNum, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
var addr net.IP
if IsDomainName(ipStr) {
ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption)
addr := xnet.ParseAddress(ipStr)
if addr.Family() == xnet.AddressFamilyDomain {
ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, dns.ErrEmptyResponse
}
addr = ips[0]
} else {
addr = net.ParseIP(ipStr)
}
if addr == nil {
return nil, errors.New("failed to parse ip: " + ipStr)
}
var ip xnet.Address
if p4 := addr.To4(); len(p4) == net.IPv4len {
ip = xnet.IPAddress(p4[:])
} else {
ip = xnet.IPAddress(addr[:])
addr = xnet.IPAddress(ips[0])
}
dst := xnet.Destination{
Address: ip,
Port: xnet.Port(port),
Address: addr,
Port: xnet.Port(portNum),
Network: xnet.Network_UDP,
}
@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
}, nil
}
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
// BatchSize implements conn.Bind
func (bind *netBind) BatchSize() int {
return 1
}
// Open implements conn.Bind
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo)
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error
return arr, uint16(uport), nil
}
func (bind *netBindClient) Close() error {
// Close implements conn.Bind
func (bind *netBind) Close() error {
if bind.readQueue != nil {
close(bind.readQueue)
}
return nil
}
type netBindClient struct {
netBind
dialer internet.Dialer
reserved []byte
}
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
if err != nil {
@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
return nil
}
func (bind *netBindClient) SetMark(mark uint32) error {
return nil
type netBindServer struct {
netBind
}
func (bind *netBindClient) BatchSize() int {
return 1
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
var err error
nend, ok := endpoint.(*netEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
if nend.conn == nil {
return newError("connection not open yet")
}
for _, buff := range buff {
if _, err = nend.conn.Write(buff); err != nil {
return err
}
}
return err
}
type netEndpoint struct {
@ -193,7 +221,7 @@ type netEndpoint struct {
func (netEndpoint) ClearSrc() {}
func (e netEndpoint) DstIP() netip.Addr {
return toNetIpAddr(e.dst.Address)
return netip.Addr{}
}
func (e netEndpoint) SrcIP() netip.Addr {
@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr {
return netip.AddrFrom16(arr)
}
}
func stringsLastIndexByte(s string, b byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b {
return i
}
}
return -1
}
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
i := stringsLastIndexByte(s, ':')
if i == -1 {
return "", 0, false, errors.New("not an ip:port")
}
ip = s[:i]
portStr := s[i+1:]
if len(ip) == 0 {
return "", 0, false, errors.New("no IP")
}
if len(portStr) == 0 {
return "", 0, false, errors.New("no port")
}
port64, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
}
port = uint16(port64)
if ip[0] == '[' {
if len(ip) < 2 || ip[len(ip)-1] != ']' {
return "", 0, false, errors.New("missing ]")
}
ip = ip[1 : len(ip)-1]
v6 = true
}
return ip, port, v6, nil
}
func IsDomainName(s string) bool {
l := len(s)
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}
last := byte('.')
nonNumeric := false
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
nonNumeric = true
partlen++
case '0' <= c && c <= '9':
partlen++
case c == '-':
if last == '.' {
return false
}
partlen++
nonNumeric = true
case c == '.':
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}
return nonNumeric
}

255
proxy/wireguard/client.go

@ -0,0 +1,255 @@
/*
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package wireguard
import (
"context"
"net/netip"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
)
// Handler is an outbound connection that silently swallow the entire payload.
type Handler struct {
conf *DeviceConfig
net Tunnel
bind *netBindClient
policyManager policy.Manager
dns dns.Client
// cached configuration
ipc string
endpoints []netip.Addr
hasIPv4, hasIPv6 bool
wgLock sync.Mutex
}
// New creates a new wireguard handler.
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
v := core.MustFromContext(ctx)
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
d := v.GetFeature(dns.ClientType()).(dns.Client)
return &Handler{
conf: conf,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
dns: d,
ipc: createIPCRequest(conf),
endpoints: endpoints,
hasIPv4: hasIPv4,
hasIPv6: hasIPv6,
}, nil
}
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
h.wgLock.Lock()
defer h.wgLock.Unlock()
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
return nil
}
log.Record(&log.GeneralMessage{
Severity: log.Severity_Info,
Content: "switching dialer",
})
if h.net != nil {
_ = h.net.Close()
h.net = nil
}
if h.bind != nil {
_ = h.bind.Close()
h.bind = nil
}
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
netBind: netBind{
dns: h.dns,
dnsOption: dns.IPOption{
IPv4Enable: h.hasIPv4,
IPv6Enable: h.hasIPv6,
},
workers: int(h.conf.NumWorkers),
},
dialer: dialer,
reserved: h.conf.Reserved,
}
defer func() {
if err != nil {
_ = bind.Close()
}
}()
h.net, err = h.makeVirtualTun(bind)
if err != nil {
return newError("failed to create virtual tun interface").Base(err)
}
h.bind = bind
return nil
}
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "wireguard"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
if err := h.processWireGuard(dialer); err != nil {
return err
}
// Destination of the inner request.
destination := outbound.Target
command := protocol.RequestCommandTCP
if destination.Network == net.Network_UDP {
command = protocol.RequestCommandUDP
}
// resolve dns
addr := destination.Address
if addr.Family().IsDomain() {
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
})
{ // Resolve fallback
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
})
}
}
if err != nil {
return newError("failed to lookup DNS").Base(err)
} else if len(ips) == 0 {
return dns.ErrEmptyResponse
}
addr = net.IPAddress(ips[dice.Roll(len(ips))])
}
var newCtx context.Context
var newCancel context.CancelFunc
if session.TimeoutOnlyFromContext(ctx) {
newCtx, newCancel = context.WithCancel(context.Background())
}
p := h.policyManager.ForLevel(0)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, func() {
cancel()
if newCancel != nil {
newCancel()
}
}, p.Timeouts.ConnectionIdle)
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
var requestFunc func() error
var responseFunc func() error
if command == protocol.RequestCommandTCP {
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
if err != nil {
return newError("failed to create TCP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
} else if command == protocol.RequestCommandUDP {
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
if err != nil {
return newError("failed to create UDP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
}
if newCtx != nil {
ctx = newCtx
}
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}
return nil
}
// creates a tun interface on netstack given a configuration
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
if err != nil {
return nil, err
}
bind.dnsOption.IPv4Enable = h.hasIPv4
bind.dnsOption.IPv6Enable = h.hasIPv6
if err = t.BuildDevice(h.ipc, bind); err != nil {
_ = t.Close()
return nil, err
}
return t, nil
}

7
proxy/wireguard/config.go

@ -23,3 +23,10 @@ func (c *DeviceConfig) fallbackIP4() bool {
func (c *DeviceConfig) fallbackIP6() bool {
return c.DomainStrategy == DeviceConfig_FORCE_IP46
}
func (c *DeviceConfig) createTun() tunCreator {
if c.KernelMode {
return createKernelTun
}
return createGVisorTun
}

58
proxy/wireguard/config.pb.go

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.31.0
// protoc v4.23.1
// protoc-gen-go v1.28.1
// protoc v4.25.0
// source: proxy/wireguard/config.proto
package wireguard
@ -83,7 +83,7 @@ type PeerConfig struct {
PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
KeepAlive uint32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
}
@ -140,7 +140,7 @@ func (x *PeerConfig) GetEndpoint() string {
return ""
}
func (x *PeerConfig) GetKeepAlive() int32 {
func (x *PeerConfig) GetKeepAlive() uint32 {
if x != nil {
return x.KeepAlive
}
@ -166,6 +166,8 @@ type DeviceConfig struct {
NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
IsClient bool `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"`
KernelMode bool `protobuf:"varint,9,opt,name=kernel_mode,json=kernelMode,proto3" json:"kernel_mode,omitempty"`
}
func (x *DeviceConfig) Reset() {
@ -249,6 +251,20 @@ func (x *DeviceConfig) GetDomainStrategy() DeviceConfig_DomainStrategy {
return DeviceConfig_FORCE_IP
}
func (x *DeviceConfig) GetIsClient() bool {
if x != nil {
return x.IsClient
}
return false
}
func (x *DeviceConfig) GetKernelMode() bool {
if x != nil {
return x.KernelMode
}
return false
}
var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
var file_proxy_wireguard_config_proto_rawDesc = []byte{
@ -263,10 +279,10 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
0x64, 0x49, 0x70, 0x73, 0x22, 0x8a, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
0x64, 0x49, 0x70, 0x73, 0x22, 0xc8, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
@ -285,19 +301,23 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f,
0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
0x65, 0x67, 0x79, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72,
0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49,
0x50, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34,
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10,
0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10,
0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10,
0x04, 0x42, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72,
0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a,
0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73,
0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79,
0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61,
0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72,
0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x65, 0x67, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x12, 0x1f, 0x0a, 0x0b, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18,
0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4d, 0x6f, 0x64,
0x65, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10,
0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x01,
0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02, 0x12,
0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x03, 0x12,
0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x04, 0x42,
0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78,
0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78,
0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77,
0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e,
0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (

4
proxy/wireguard/config.proto

@ -10,7 +10,7 @@ message PeerConfig {
string public_key = 1;
string pre_shared_key = 2;
string endpoint = 3;
int32 keep_alive = 4;
uint32 keep_alive = 4;
repeated string allowed_ips = 5;
}
@ -29,4 +29,6 @@ message DeviceConfig {
int32 num_workers = 5;
bytes reserved = 6;
DomainStrategy domain_strategy = 7;
bool is_client = 8;
bool kernel_mode = 9;
}

230
proxy/wireguard/gvisortun/tun.go

@ -0,0 +1,230 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
*/
package gvisortun
import (
"context"
"fmt"
"net/netip"
"os"
"syscall"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
incomingPacket chan *buffer.View
mtu int
hasV4, hasV6 bool
}
type Net netTun
func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: !promiscuousMode,
}
dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 1),
incomingPacket: make(chan *buffer.View),
mtu: mtu,
}
dev.ep.AddNotify(dev)
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
var protoNumber tcpip.NetworkProtocolNumber
if ip.Is4() {
protoNumber = ipv4.ProtocolNumber
} else if ip.Is6() {
protoNumber = ipv6.ProtocolNumber
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
if ip.Is4() {
dev.hasV4 = true
} else if ip.Is6() {
dev.hasV6 = true
}
}
if dev.hasV4 {
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
}
if dev.hasV6 {
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
}
if promiscuousMode {
// enable promiscuous mode to handle all packets processed by netstack
dev.stack.SetPromiscuousMode(1, true)
dev.stack.SetSpoofing(1, true)
}
opt := tcpip.CongestionControlOption("cubic")
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
}
dev.events <- tun.EventUp
return dev, (*Net)(dev), dev.stack, nil
}
// BatchSize implements tun.Device
func (tun *netTun) BatchSize() int {
return 1
}
// Name implements tun.Device
func (tun *netTun) Name() (string, error) {
return "go", nil
}
// File implements tun.Device
func (tun *netTun) File() *os.File {
return nil
}
// Events implements tun.Device
func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
// Read implements tun.Device
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
n, err := view.Read(buf[0][offset:])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// Write implements tun.Device
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
for _, buf := range buf {
packet := buf[offset:]
if len(packet) == 0 {
continue
}
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
switch packet[0] >> 4 {
case 4:
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6:
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
default:
return 0, syscall.EAFNOSUPPORT
}
}
return len(buf), nil
}
// WriteNotify implements channel.Notification
func (tun *netTun) WriteNotify() {
pkt := tun.ep.Read()
if pkt.IsNil() {
return
}
view := pkt.ToView()
pkt.DecRef()
tun.incomingPacket <- view
}
// Flush implements tun.Device
func (tun *netTun) Flush() error {
return nil
}
// Close implements tun.Device
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
if tun.events != nil {
close(tun.events)
}
tun.ep.Close()
if tun.incomingPacket != nil {
close(tun.incomingPacket)
}
return nil
}
// MTU implements tun.Device
func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
protoNumber = ipv4.ProtocolNumber
} else {
protoNumber = ipv6.ProtocolNumber
}
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}

181
proxy/wireguard/server.go

@ -0,0 +1,181 @@
package wireguard
import (
"context"
"errors"
"io"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/stat"
)
var nullDestination = net.TCPDestination(net.AnyIP, 0)
type Server struct {
bindServer *netBindServer
info routingInfo
policyManager policy.Manager
}
type routingInfo struct {
ctx context.Context
dispatcher routing.Dispatcher
inboundTag *session.Inbound
outboundTag *session.Outbound
contentTag *session.Content
}
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
v := core.MustFromContext(ctx)
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
server := &Server{
bindServer: &netBindServer{
netBind: netBind{
dns: v.GetFeature(dns.ClientType()).(dns.Client),
dnsOption: dns.IPOption{
IPv4Enable: hasIPv4,
IPv6Enable: hasIPv6,
},
},
},
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
}
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
if err != nil {
return nil, err
}
if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
_ = tun.Close()
return nil, err
}
return server, nil
}
// Network implements proxy.Inbound.
func (*Server) Network() []net.Network {
return []net.Network{net.Network_UDP}
}
// Process implements proxy.Inbound.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
s.info = routingInfo{
ctx: core.ToBackgroundDetachedContext(ctx),
dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
outboundTag: session.OutboundFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
}
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
if err != nil {
return err
}
nep := ep.(*netEndpoint)
nep.conn = conn
reader := buf.NewPacketReader(conn)
for {
mpayload, err := reader.ReadMultiBuffer()
if err != nil {
return err
}
for _, payload := range mpayload {
v, ok := <-s.bindServer.readQueue
if !ok {
return nil
}
i, err := payload.Read(v.buff)
v.bytes = i
v.endpoint = nep
v.err = err
v.waiter.Done()
if err != nil && errors.Is(err, io.EOF) {
nep.conn = nil
return nil
}
}
}
}
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
if s.info.dispatcher == nil {
newError("unexpected: dispatcher == nil").AtError().WriteToLog()
return
}
defer conn.Close()
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
plcy := s.policyManager.ForLevel(0)
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: nullDestination,
To: dest,
Status: log.AccessAccepted,
Reason: "",
})
if s.info.inboundTag != nil {
ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
}
if s.info.outboundTag != nil {
ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
}
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)
}
link, err := s.info.dispatcher.Dispatch(ctx, dest)
if err != nil {
newError("dispatch connection").Base(err).AtError().WriteToLog()
}
defer cancel()
requestDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all TCP request").Base(err)
}
return nil
}
responseDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all TCP response").Base(err)
}
return nil
}
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
newError("connection ends").Base(err).AtDebug().WriteToLog()
return
}
}

100
proxy/wireguard/tun.go

@ -10,14 +10,26 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/xtls/xray-core/common/log"
xnet "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
)
type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
type Tunnel interface {
BuildDevice(ipc string, bind conn.Bind) error
DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
@ -103,3 +115,91 @@ func CalculateInterfaceName(name string) (tunName string) {
tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
return
}
var _ Tunnel = (*gvisorNet)(nil)
type gvisorNet struct {
tunnel
net *gvisortun.Net
}
func (g *gvisorNet) Close() error {
return g.tunnel.Close()
}
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
net.Conn, error,
) {
return g.net.DialContextTCPAddrPort(ctx, addr)
}
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
return g.net.DialUDPAddrPort(laddr, raddr)
}
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
out := &gvisorNet{}
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
if err != nil {
return nil, err
}
if handler != nil {
// handler is only used for promiscuous mode
// capture all packets and send to handler
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
go func(r *tcp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
// Perform a TCP three-way handshake.
ep, err := r.CreateEndpoint(&wq)
if err != nil {
newError(err.String()).AtError().WriteToLog()
r.Complete(true)
return
}
r.Complete(false)
defer ep.Close()
// enable tcp keep-alive to prevent hanging connections
ep.SocketOptions().SetKeepAlive(true)
// local address is actually destination
handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
}(r)
})
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
go func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
newError(err.String()).AtError().WriteToLog()
return
}
defer ep.Close()
// prevents hanging connections and ensure timely release
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: true,
Timeout: 15 * time.Second,
})
handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
}(r)
})
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
}
out.tun, out.net = tun, n
return out, nil
}

38
proxy/wireguard/tun_default.go

@ -1,42 +1,16 @@
//go:build !linux
//go:build !linux || android
package wireguard
import (
"context"
"net"
"errors"
"net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
)
var _ Tunnel = (*gvisorNet)(nil)
type gvisorNet struct {
tunnel
net *netstack.Net
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
return nil, errors.New("not implemented")
}
func (g *gvisorNet) Close() error {
return g.tunnel.Close()
}
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
net.Conn, error,
) {
return g.net.DialContextTCPAddrPort(ctx, addr)
}
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
return g.net.DialUDPAddrPort(laddr, raddr)
}
func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) {
out := &gvisorNet{}
tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu)
if err != nil {
return nil, err
}
out.tun, out.net = tun, n
return out, nil
func KernelTunSupported() bool {
return false
}

16
proxy/wireguard/tun_linux.go

@ -1,3 +1,5 @@
//go:build linux && !android
package wireguard
import (
@ -69,7 +71,11 @@ func (d *deviceNet) Close() (err error) {
return errors.Join(errs...)
}
func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
if handler != nil {
return nil, newError("TODO: support promiscuous mode")
}
var v4, v6 *netip.Addr
for _, prefixes := range localAddresses {
if v4 == nil && prefixes.Is4() {
@ -221,3 +227,11 @@ func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
out.tun = wgt
return out, nil
}
func KernelTunSupported() bool {
// run a superuser permission check to check
// if the current user has the sufficient permission
// to create a tun device.
return unix.Geteuid() == 0 // 0 means root
}

333
proxy/wireguard/wireguard.go

@ -1,326 +1,111 @@
/*
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package wireguard
import (
"bytes"
"context"
"fmt"
stdnet "net"
"net/netip"
"strings"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
"golang.zx2c4.com/wireguard/device"
)
// Handler is an outbound connection that silently swallow the entire payload.
type Handler struct {
conf *DeviceConfig
net Tunnel
bind *netBindClient
policyManager policy.Manager
dns dns.Client
// cached configuration
ipc string
endpoints []netip.Addr
hasIPv4, hasIPv6 bool
wgLock sync.Mutex
}
// New creates a new wireguard handler.
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
v := core.MustFromContext(ctx)
endpoints, err := parseEndpoints(conf)
if err != nil {
return nil, err
}
hasIPv4, hasIPv6 := false, false
for _, e := range endpoints {
if e.Is4() {
hasIPv4 = true
}
if e.Is6() {
hasIPv6 = true
}
}
d := v.GetFeature(dns.ClientType()).(dns.Client)
return &Handler{
conf: conf,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
dns: d,
ipc: createIPCRequest(conf, d, hasIPv6),
endpoints: endpoints,
hasIPv4: hasIPv4,
hasIPv6: hasIPv6,
}, nil
}
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
h.wgLock.Lock()
defer h.wgLock.Unlock()
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
return nil
}
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
var wgLogger = &device.Logger{
Verbosef: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Info,
Content: "switching dialer",
})
if h.net != nil {
_ = h.net.Close()
h.net = nil
}
if h.bind != nil {
_ = h.bind.Close()
h.bind = nil
}
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
bind := &netBindClient{
dialer: dialer,
workers: int(h.conf.NumWorkers),
dns: h.dns,
reserved: h.conf.Reserved,
}
defer func() {
if err != nil {
_ = bind.Close()
}
}()
h.net, err = h.makeVirtualTun(bind)
if err != nil {
return newError("failed to create virtual tun interface").Base(err)
}
h.bind = bind
return nil
}
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "wireguard"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
if err := h.processWireGuard(dialer); err != nil {
return err
}
// Destination of the inner request.
destination := outbound.Target
command := protocol.RequestCommandTCP
if destination.Network == net.Network_UDP {
command = protocol.RequestCommandUDP
}
// resolve dns
addr := destination.Address
if addr.Family().IsDomain() {
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
Severity: log.Severity_Debug,
Content: fmt.Sprintf(format, args...),
})
{ // Resolve fallback
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
},
Errorf: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Error,
Content: fmt.Sprintf(format, args...),
})
}
}
if err != nil {
return newError("failed to lookup DNS").Base(err)
} else if len(ips) == 0 {
return dns.ErrEmptyResponse
}
addr = net.IPAddress(ips[dice.Roll(len(ips))])
},
}
var newCtx context.Context
var newCancel context.CancelFunc
if session.TimeoutOnlyFromContext(ctx) {
newCtx, newCancel = context.WithCancel(context.Background())
func init() {
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
deviceConfig := config.(*DeviceConfig)
if deviceConfig.IsClient {
return New(ctx, deviceConfig)
} else {
return NewServer(ctx, deviceConfig)
}
p := h.policyManager.ForLevel(0)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, func() {
cancel()
if newCancel != nil {
newCancel()
}))
}
}, p.Timeouts.ConnectionIdle)
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
var requestFunc func() error
var responseFunc func() error
// convert endpoint string to netip.Addr
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) {
var hasIPv4, hasIPv6 bool
if command == protocol.RequestCommandTCP {
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
endpoints := make([]netip.Addr, len(conf.Endpoint))
for i, str := range conf.Endpoint {
var addr netip.Addr
if strings.Contains(str, "/") {
prefix, err := netip.ParsePrefix(str)
if err != nil {
return newError("failed to create TCP connection").Base(err)
return nil, false, false, err
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
addr = prefix.Addr()
if prefix.Bits() != addr.BitLen() {
return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
}
} else if command == protocol.RequestCommandUDP {
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
} else {
var err error
addr, err = netip.ParseAddr(str)
if err != nil {
return newError("failed to create UDP connection").Base(err)
}
defer conn.Close()
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
return nil, false, false, err
}
}
endpoints[i] = addr
if newCtx != nil {
ctx = newCtx
if addr.Is4() {
hasIPv4 = true
} else if addr.Is6() {
hasIPv6 = true
}
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}
return nil
return endpoints, hasIPv4, hasIPv6, nil
}
// serialize the config into an IPC request
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
var request bytes.Buffer
func createIPCRequest(conf *DeviceConfig) string {
var request strings.Builder
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
for _, peer := range conf.Peers {
endpoint := peer.Endpoint
host, port, err := net.SplitHostPort(endpoint)
if resolveEndPointToV4 && err == nil {
_, err = netip.ParseAddr(host)
if err != nil {
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
if err == nil && len(ipList) > 0 {
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
}
}
}
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
for _, ip := range peer.AllowedIps {
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
}
if !conf.IsClient {
// placeholder, we'll handle actual port listening on Xray
request.WriteString("listen_port=1337\n")
}
return request.String()[:request.Len()]
for _, peer := range conf.Peers {
if peer.PublicKey != "" {
request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
}
// convert endpoint string to netip.Addr
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
endpoints := make([]netip.Addr, len(conf.Endpoint))
for i, str := range conf.Endpoint {
var addr netip.Addr
if strings.Contains(str, "/") {
prefix, err := netip.ParsePrefix(str)
if err != nil {
return nil, err
}
addr = prefix.Addr()
if prefix.Bits() != addr.BitLen() {
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
}
} else {
var err error
addr, err = netip.ParseAddr(str)
if err != nil {
return nil, err
}
}
endpoints[i] = addr
if peer.PreSharedKey != "" {
request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
}
return endpoints, nil
if peer.Endpoint != "" {
request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint))
}
// creates a tun interface on netstack given a configuration
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
if err != nil {
return nil, err
for _, ip := range peer.AllowedIps {
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
}
bind.dnsOption.IPv4Enable = h.hasIPv4
bind.dnsOption.IPv6Enable = h.hasIPv6
if err = t.BuildDevice(h.ipc, bind); err != nil {
_ = t.Close()
return nil, err
if peer.KeepAlive != 0 {
request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
}
return t, nil
}
func init() {
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return New(ctx, config.(*DeviceConfig))
}))
return request.String()[:request.Len()]
}

Loading…
Cancel
Save