mirror of https://github.com/XTLS/Xray-core
Browse Source
* feat: wireguard inbound * feat(command): generate wireguard compatible keypair * feat(wireguard): connection idle timeout * fix(wireguard): close endpoint after connection closed * fix(wireguard): resolve conflicts * feat(wireguard): set cubic as default cc algorithm in gVisor TUN * chore(wireguard): resolve conflict * chore(wireguard): remove redurant code * chore(wireguard): remove redurant code * feat: rework server for gvisor tun * feat: keep user-space tun as an option * fix: exclude android from native tun build * feat: auto kernel tun * fix: build * fix: regulate function name & fix testpull/2734/head
hax0r31337
1 year ago
committed by
GitHub
17 changed files with 1048 additions and 499 deletions
@ -0,0 +1,255 @@
|
||||
/* |
||||
|
||||
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
||||
|
||||
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me> |
||||
|
||||
Permission to use, copy, modify, and distribute this software for any |
||||
purpose with or without fee is hereby granted, provided that the above |
||||
copyright notice and this permission notice appear in all copies. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR |
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN |
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF |
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
||||
|
||||
*/ |
||||
|
||||
package wireguard |
||||
|
||||
import ( |
||||
"context" |
||||
"net/netip" |
||||
"sync" |
||||
|
||||
"github.com/xtls/xray-core/common" |
||||
"github.com/xtls/xray-core/common/buf" |
||||
"github.com/xtls/xray-core/common/dice" |
||||
"github.com/xtls/xray-core/common/log" |
||||
"github.com/xtls/xray-core/common/net" |
||||
"github.com/xtls/xray-core/common/protocol" |
||||
"github.com/xtls/xray-core/common/session" |
||||
"github.com/xtls/xray-core/common/signal" |
||||
"github.com/xtls/xray-core/common/task" |
||||
"github.com/xtls/xray-core/core" |
||||
"github.com/xtls/xray-core/features/dns" |
||||
"github.com/xtls/xray-core/features/policy" |
||||
"github.com/xtls/xray-core/transport" |
||||
"github.com/xtls/xray-core/transport/internet" |
||||
) |
||||
|
||||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct { |
||||
conf *DeviceConfig |
||||
net Tunnel |
||||
bind *netBindClient |
||||
policyManager policy.Manager |
||||
dns dns.Client |
||||
// cached configuration
|
||||
ipc string |
||||
endpoints []netip.Addr |
||||
hasIPv4, hasIPv6 bool |
||||
wgLock sync.Mutex |
||||
} |
||||
|
||||
// New creates a new wireguard handler.
|
||||
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { |
||||
v := core.MustFromContext(ctx) |
||||
|
||||
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client) |
||||
return &Handler{ |
||||
conf: conf, |
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), |
||||
dns: d, |
||||
ipc: createIPCRequest(conf), |
||||
endpoints: endpoints, |
||||
hasIPv4: hasIPv4, |
||||
hasIPv6: hasIPv6, |
||||
}, nil |
||||
} |
||||
|
||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { |
||||
h.wgLock.Lock() |
||||
defer h.wgLock.Unlock() |
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil { |
||||
return nil |
||||
} |
||||
|
||||
log.Record(&log.GeneralMessage{ |
||||
Severity: log.Severity_Info, |
||||
Content: "switching dialer", |
||||
}) |
||||
|
||||
if h.net != nil { |
||||
_ = h.net.Close() |
||||
h.net = nil |
||||
} |
||||
if h.bind != nil { |
||||
_ = h.bind.Close() |
||||
h.bind = nil |
||||
} |
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{ |
||||
netBind: netBind{ |
||||
dns: h.dns, |
||||
dnsOption: dns.IPOption{ |
||||
IPv4Enable: h.hasIPv4, |
||||
IPv6Enable: h.hasIPv6, |
||||
}, |
||||
workers: int(h.conf.NumWorkers), |
||||
}, |
||||
dialer: dialer, |
||||
reserved: h.conf.Reserved, |
||||
} |
||||
defer func() { |
||||
if err != nil { |
||||
_ = bind.Close() |
||||
} |
||||
}() |
||||
|
||||
h.net, err = h.makeVirtualTun(bind) |
||||
if err != nil { |
||||
return newError("failed to create virtual tun interface").Base(err) |
||||
} |
||||
h.bind = bind |
||||
return nil |
||||
} |
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { |
||||
outbound := session.OutboundFromContext(ctx) |
||||
if outbound == nil || !outbound.Target.IsValid() { |
||||
return newError("target not specified") |
||||
} |
||||
outbound.Name = "wireguard" |
||||
inbound := session.InboundFromContext(ctx) |
||||
if inbound != nil { |
||||
inbound.SetCanSpliceCopy(3) |
||||
} |
||||
|
||||
if err := h.processWireGuard(dialer); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Destination of the inner request.
|
||||
destination := outbound.Target |
||||
command := protocol.RequestCommandTCP |
||||
if destination.Network == net.Network_UDP { |
||||
command = protocol.RequestCommandUDP |
||||
} |
||||
|
||||
// resolve dns
|
||||
addr := destination.Address |
||||
if addr.Family().IsDomain() { |
||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ |
||||
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), |
||||
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), |
||||
}) |
||||
{ // Resolve fallback
|
||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { |
||||
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ |
||||
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), |
||||
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), |
||||
}) |
||||
} |
||||
} |
||||
if err != nil { |
||||
return newError("failed to lookup DNS").Base(err) |
||||
} else if len(ips) == 0 { |
||||
return dns.ErrEmptyResponse |
||||
} |
||||
addr = net.IPAddress(ips[dice.Roll(len(ips))]) |
||||
} |
||||
|
||||
var newCtx context.Context |
||||
var newCancel context.CancelFunc |
||||
if session.TimeoutOnlyFromContext(ctx) { |
||||
newCtx, newCancel = context.WithCancel(context.Background()) |
||||
} |
||||
|
||||
p := h.policyManager.ForLevel(0) |
||||
|
||||
ctx, cancel := context.WithCancel(ctx) |
||||
timer := signal.CancelAfterInactivity(ctx, func() { |
||||
cancel() |
||||
if newCancel != nil { |
||||
newCancel() |
||||
} |
||||
}, p.Timeouts.ConnectionIdle) |
||||
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) |
||||
|
||||
var requestFunc func() error |
||||
var responseFunc func() error |
||||
|
||||
if command == protocol.RequestCommandTCP { |
||||
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) |
||||
if err != nil { |
||||
return newError("failed to create TCP connection").Base(err) |
||||
} |
||||
defer conn.Close() |
||||
|
||||
requestFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) |
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) |
||||
} |
||||
responseFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly) |
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) |
||||
} |
||||
} else if command == protocol.RequestCommandUDP { |
||||
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) |
||||
if err != nil { |
||||
return newError("failed to create UDP connection").Base(err) |
||||
} |
||||
defer conn.Close() |
||||
|
||||
requestFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) |
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) |
||||
} |
||||
responseFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly) |
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) |
||||
} |
||||
} |
||||
|
||||
if newCtx != nil { |
||||
ctx = newCtx |
||||
} |
||||
|
||||
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) |
||||
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { |
||||
common.Interrupt(link.Reader) |
||||
common.Interrupt(link.Writer) |
||||
return newError("connection ends").Base(err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { |
||||
t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
bind.dnsOption.IPv4Enable = h.hasIPv4 |
||||
bind.dnsOption.IPv6Enable = h.hasIPv6 |
||||
|
||||
if err = t.BuildDevice(h.ipc, bind); err != nil { |
||||
_ = t.Close() |
||||
return nil, err |
||||
} |
||||
return t, nil |
||||
} |
@ -0,0 +1,230 @@
|
||||
/* SPDX-License-Identifier: MIT |
||||
* |
||||
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. |
||||
*/ |
||||
|
||||
package gvisortun |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/netip" |
||||
"os" |
||||
"syscall" |
||||
|
||||
"golang.zx2c4.com/wireguard/tun" |
||||
"gvisor.dev/gvisor/pkg/buffer" |
||||
"gvisor.dev/gvisor/pkg/tcpip" |
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" |
||||
"gvisor.dev/gvisor/pkg/tcpip/header" |
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel" |
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" |
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6" |
||||
"gvisor.dev/gvisor/pkg/tcpip/stack" |
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp" |
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" |
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" |
||||
) |
||||
|
||||
type netTun struct { |
||||
ep *channel.Endpoint |
||||
stack *stack.Stack |
||||
events chan tun.Event |
||||
incomingPacket chan *buffer.View |
||||
mtu int |
||||
hasV4, hasV6 bool |
||||
} |
||||
|
||||
type Net netTun |
||||
|
||||
func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) { |
||||
opts := stack.Options{ |
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, |
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, |
||||
HandleLocal: !promiscuousMode, |
||||
} |
||||
dev := &netTun{ |
||||
ep: channel.New(1024, uint32(mtu), ""), |
||||
stack: stack.New(opts), |
||||
events: make(chan tun.Event, 1), |
||||
incomingPacket: make(chan *buffer.View), |
||||
mtu: mtu, |
||||
} |
||||
dev.ep.AddNotify(dev) |
||||
tcpipErr := dev.stack.CreateNIC(1, dev.ep) |
||||
if tcpipErr != nil { |
||||
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr) |
||||
} |
||||
for _, ip := range localAddresses { |
||||
var protoNumber tcpip.NetworkProtocolNumber |
||||
if ip.Is4() { |
||||
protoNumber = ipv4.ProtocolNumber |
||||
} else if ip.Is6() { |
||||
protoNumber = ipv6.ProtocolNumber |
||||
} |
||||
protoAddr := tcpip.ProtocolAddress{ |
||||
Protocol: protoNumber, |
||||
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), |
||||
} |
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) |
||||
if tcpipErr != nil { |
||||
return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) |
||||
} |
||||
if ip.Is4() { |
||||
dev.hasV4 = true |
||||
} else if ip.Is6() { |
||||
dev.hasV6 = true |
||||
} |
||||
} |
||||
if dev.hasV4 { |
||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) |
||||
} |
||||
if dev.hasV6 { |
||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) |
||||
} |
||||
if promiscuousMode { |
||||
// enable promiscuous mode to handle all packets processed by netstack
|
||||
dev.stack.SetPromiscuousMode(1, true) |
||||
dev.stack.SetSpoofing(1, true) |
||||
} |
||||
|
||||
opt := tcpip.CongestionControlOption("cubic") |
||||
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { |
||||
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) |
||||
} |
||||
|
||||
dev.events <- tun.EventUp |
||||
return dev, (*Net)(dev), dev.stack, nil |
||||
} |
||||
|
||||
// BatchSize implements tun.Device
|
||||
func (tun *netTun) BatchSize() int { |
||||
return 1 |
||||
} |
||||
|
||||
// Name implements tun.Device
|
||||
func (tun *netTun) Name() (string, error) { |
||||
return "go", nil |
||||
} |
||||
|
||||
// File implements tun.Device
|
||||
func (tun *netTun) File() *os.File { |
||||
return nil |
||||
} |
||||
|
||||
// Events implements tun.Device
|
||||
func (tun *netTun) Events() <-chan tun.Event { |
||||
return tun.events |
||||
} |
||||
|
||||
// Read implements tun.Device
|
||||
|
||||
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { |
||||
view, ok := <-tun.incomingPacket |
||||
if !ok { |
||||
return 0, os.ErrClosed |
||||
} |
||||
|
||||
n, err := view.Read(buf[0][offset:]) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
sizes[0] = n |
||||
return 1, nil |
||||
} |
||||
|
||||
// Write implements tun.Device
|
||||
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { |
||||
for _, buf := range buf { |
||||
packet := buf[offset:] |
||||
if len(packet) == 0 { |
||||
continue |
||||
} |
||||
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) |
||||
switch packet[0] >> 4 { |
||||
case 4: |
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) |
||||
case 6: |
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) |
||||
default: |
||||
return 0, syscall.EAFNOSUPPORT |
||||
} |
||||
} |
||||
return len(buf), nil |
||||
} |
||||
|
||||
// WriteNotify implements channel.Notification
|
||||
func (tun *netTun) WriteNotify() { |
||||
pkt := tun.ep.Read() |
||||
if pkt.IsNil() { |
||||
return |
||||
} |
||||
|
||||
view := pkt.ToView() |
||||
pkt.DecRef() |
||||
|
||||
tun.incomingPacket <- view |
||||
} |
||||
|
||||
// Flush implements tun.Device
|
||||
func (tun *netTun) Flush() error { |
||||
return nil |
||||
} |
||||
|
||||
// Close implements tun.Device
|
||||
func (tun *netTun) Close() error { |
||||
tun.stack.RemoveNIC(1) |
||||
|
||||
if tun.events != nil { |
||||
close(tun.events) |
||||
} |
||||
|
||||
tun.ep.Close() |
||||
|
||||
if tun.incomingPacket != nil { |
||||
close(tun.incomingPacket) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// MTU implements tun.Device
|
||||
func (tun *netTun) MTU() (int, error) { |
||||
return tun.mtu, nil |
||||
} |
||||
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { |
||||
var protoNumber tcpip.NetworkProtocolNumber |
||||
if endpoint.Addr().Is4() { |
||||
protoNumber = ipv4.ProtocolNumber |
||||
} else { |
||||
protoNumber = ipv6.ProtocolNumber |
||||
} |
||||
return tcpip.FullAddress{ |
||||
NIC: 1, |
||||
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), |
||||
Port: endpoint.Port(), |
||||
}, protoNumber |
||||
} |
||||
|
||||
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { |
||||
fa, pn := convertToFullAddr(addr) |
||||
return gonet.DialContextTCP(ctx, net.stack, fa, pn) |
||||
} |
||||
|
||||
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { |
||||
var lfa, rfa *tcpip.FullAddress |
||||
var pn tcpip.NetworkProtocolNumber |
||||
if laddr.IsValid() || laddr.Port() > 0 { |
||||
var addr tcpip.FullAddress |
||||
addr, pn = convertToFullAddr(laddr) |
||||
lfa = &addr |
||||
} |
||||
if raddr.IsValid() || raddr.Port() > 0 { |
||||
var addr tcpip.FullAddress |
||||
addr, pn = convertToFullAddr(raddr) |
||||
rfa = &addr |
||||
} |
||||
return gonet.DialUDP(net.stack, lfa, rfa, pn) |
||||
} |
@ -0,0 +1,181 @@
|
||||
package wireguard |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"io" |
||||
|
||||
"github.com/xtls/xray-core/common" |
||||
"github.com/xtls/xray-core/common/buf" |
||||
"github.com/xtls/xray-core/common/log" |
||||
"github.com/xtls/xray-core/common/net" |
||||
"github.com/xtls/xray-core/common/session" |
||||
"github.com/xtls/xray-core/common/signal" |
||||
"github.com/xtls/xray-core/common/task" |
||||
"github.com/xtls/xray-core/core" |
||||
"github.com/xtls/xray-core/features/dns" |
||||
"github.com/xtls/xray-core/features/policy" |
||||
"github.com/xtls/xray-core/features/routing" |
||||
"github.com/xtls/xray-core/transport/internet/stat" |
||||
) |
||||
|
||||
var nullDestination = net.TCPDestination(net.AnyIP, 0) |
||||
|
||||
type Server struct { |
||||
bindServer *netBindServer |
||||
|
||||
info routingInfo |
||||
policyManager policy.Manager |
||||
} |
||||
|
||||
type routingInfo struct { |
||||
ctx context.Context |
||||
dispatcher routing.Dispatcher |
||||
inboundTag *session.Inbound |
||||
outboundTag *session.Outbound |
||||
contentTag *session.Content |
||||
} |
||||
|
||||
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { |
||||
v := core.MustFromContext(ctx) |
||||
|
||||
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
server := &Server{ |
||||
bindServer: &netBindServer{ |
||||
netBind: netBind{ |
||||
dns: v.GetFeature(dns.ClientType()).(dns.Client), |
||||
dnsOption: dns.IPOption{ |
||||
IPv4Enable: hasIPv4, |
||||
IPv6Enable: hasIPv6, |
||||
}, |
||||
}, |
||||
}, |
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), |
||||
} |
||||
|
||||
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil { |
||||
_ = tun.Close() |
||||
return nil, err |
||||
} |
||||
|
||||
return server, nil |
||||
} |
||||
|
||||
// Network implements proxy.Inbound.
|
||||
func (*Server) Network() []net.Network { |
||||
return []net.Network{net.Network_UDP} |
||||
} |
||||
|
||||
// Process implements proxy.Inbound.
|
||||
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { |
||||
s.info = routingInfo{ |
||||
ctx: core.ToBackgroundDetachedContext(ctx), |
||||
dispatcher: dispatcher, |
||||
inboundTag: session.InboundFromContext(ctx), |
||||
outboundTag: session.OutboundFromContext(ctx), |
||||
contentTag: session.ContentFromContext(ctx), |
||||
} |
||||
|
||||
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
nep := ep.(*netEndpoint) |
||||
nep.conn = conn |
||||
|
||||
reader := buf.NewPacketReader(conn) |
||||
for { |
||||
mpayload, err := reader.ReadMultiBuffer() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, payload := range mpayload { |
||||
v, ok := <-s.bindServer.readQueue |
||||
if !ok { |
||||
return nil |
||||
} |
||||
i, err := payload.Read(v.buff) |
||||
|
||||
v.bytes = i |
||||
v.endpoint = nep |
||||
v.err = err |
||||
v.waiter.Done() |
||||
if err != nil && errors.Is(err, io.EOF) { |
||||
nep.conn = nil |
||||
return nil |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { |
||||
if s.info.dispatcher == nil { |
||||
newError("unexpected: dispatcher == nil").AtError().WriteToLog() |
||||
return |
||||
} |
||||
defer conn.Close() |
||||
|
||||
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) |
||||
plcy := s.policyManager.ForLevel(0) |
||||
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) |
||||
|
||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ |
||||
From: nullDestination, |
||||
To: dest, |
||||
Status: log.AccessAccepted, |
||||
Reason: "", |
||||
}) |
||||
|
||||
if s.info.inboundTag != nil { |
||||
ctx = session.ContextWithInbound(ctx, s.info.inboundTag) |
||||
} |
||||
if s.info.outboundTag != nil { |
||||
ctx = session.ContextWithOutbound(ctx, s.info.outboundTag) |
||||
} |
||||
if s.info.contentTag != nil { |
||||
ctx = session.ContextWithContent(ctx, s.info.contentTag) |
||||
} |
||||
|
||||
link, err := s.info.dispatcher.Dispatch(ctx, dest) |
||||
if err != nil { |
||||
newError("dispatch connection").Base(err).AtError().WriteToLog() |
||||
} |
||||
defer cancel() |
||||
|
||||
requestDone := func() error { |
||||
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) |
||||
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil { |
||||
return newError("failed to transport all TCP request").Base(err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
responseDone := func() error { |
||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) |
||||
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil { |
||||
return newError("failed to transport all TCP response").Base(err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) |
||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil { |
||||
common.Interrupt(link.Reader) |
||||
common.Interrupt(link.Writer) |
||||
newError("connection ends").Base(err).AtDebug().WriteToLog() |
||||
return |
||||
} |
||||
} |
@ -1,42 +1,16 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || android
|
||||
|
||||
package wireguard |
||||
|
||||
import ( |
||||
"context" |
||||
"net" |
||||
"errors" |
||||
"net/netip" |
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack" |
||||
) |
||||
|
||||
var _ Tunnel = (*gvisorNet)(nil) |
||||
|
||||
type gvisorNet struct { |
||||
tunnel |
||||
net *netstack.Net |
||||
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) { |
||||
return nil, errors.New("not implemented") |
||||
} |
||||
|
||||
func (g *gvisorNet) Close() error { |
||||
return g.tunnel.Close() |
||||
} |
||||
|
||||
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( |
||||
net.Conn, error, |
||||
) { |
||||
return g.net.DialContextTCPAddrPort(ctx, addr) |
||||
} |
||||
|
||||
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { |
||||
return g.net.DialUDPAddrPort(laddr, raddr) |
||||
} |
||||
|
||||
func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) { |
||||
out := &gvisorNet{} |
||||
tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
out.tun, out.net = tun, n |
||||
return out, nil |
||||
func KernelTunSupported() bool { |
||||
return false |
||||
} |
||||
|
@ -1,326 +1,111 @@
|
||||
/* |
||||
|
||||
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
||||
|
||||
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me> |
||||
|
||||
Permission to use, copy, modify, and distribute this software for any |
||||
purpose with or without fee is hereby granted, provided that the above |
||||
copyright notice and this permission notice appear in all copies. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR |
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN |
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF |
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
||||
|
||||
*/ |
||||
|
||||
package wireguard |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"fmt" |
||||
stdnet "net" |
||||
"net/netip" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/xtls/xray-core/common" |
||||
"github.com/xtls/xray-core/common/buf" |
||||
"github.com/xtls/xray-core/common/dice" |
||||
"github.com/xtls/xray-core/common/log" |
||||
"github.com/xtls/xray-core/common/net" |
||||
"github.com/xtls/xray-core/common/protocol" |
||||
"github.com/xtls/xray-core/common/session" |
||||
"github.com/xtls/xray-core/common/signal" |
||||
"github.com/xtls/xray-core/common/task" |
||||
"github.com/xtls/xray-core/core" |
||||
"github.com/xtls/xray-core/features/dns" |
||||
"github.com/xtls/xray-core/features/policy" |
||||
"github.com/xtls/xray-core/transport" |
||||
"github.com/xtls/xray-core/transport/internet" |
||||
"golang.zx2c4.com/wireguard/device" |
||||
) |
||||
|
||||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct { |
||||
conf *DeviceConfig |
||||
net Tunnel |
||||
bind *netBindClient |
||||
policyManager policy.Manager |
||||
dns dns.Client |
||||
// cached configuration
|
||||
ipc string |
||||
endpoints []netip.Addr |
||||
hasIPv4, hasIPv6 bool |
||||
wgLock sync.Mutex |
||||
} |
||||
|
||||
// New creates a new wireguard handler.
|
||||
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { |
||||
v := core.MustFromContext(ctx) |
||||
|
||||
endpoints, err := parseEndpoints(conf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hasIPv4, hasIPv6 := false, false |
||||
for _, e := range endpoints { |
||||
if e.Is4() { |
||||
hasIPv4 = true |
||||
} |
||||
if e.Is6() { |
||||
hasIPv6 = true |
||||
} |
||||
} |
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client) |
||||
return &Handler{ |
||||
conf: conf, |
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), |
||||
dns: d, |
||||
ipc: createIPCRequest(conf, d, hasIPv6), |
||||
endpoints: endpoints, |
||||
hasIPv4: hasIPv4, |
||||
hasIPv6: hasIPv6, |
||||
}, nil |
||||
} |
||||
|
||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { |
||||
h.wgLock.Lock() |
||||
defer h.wgLock.Unlock() |
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil { |
||||
return nil |
||||
} |
||||
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
|
||||
|
||||
var wgLogger = &device.Logger{ |
||||
Verbosef: func(format string, args ...any) { |
||||
log.Record(&log.GeneralMessage{ |
||||
Severity: log.Severity_Info, |
||||
Content: "switching dialer", |
||||
}) |
||||
|
||||
if h.net != nil { |
||||
_ = h.net.Close() |
||||
h.net = nil |
||||
} |
||||
if h.bind != nil { |
||||
_ = h.bind.Close() |
||||
h.bind = nil |
||||
} |
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
bind := &netBindClient{ |
||||
dialer: dialer, |
||||
workers: int(h.conf.NumWorkers), |
||||
dns: h.dns, |
||||
reserved: h.conf.Reserved, |
||||
} |
||||
defer func() { |
||||
if err != nil { |
||||
_ = bind.Close() |
||||
} |
||||
}() |
||||
|
||||
h.net, err = h.makeVirtualTun(bind) |
||||
if err != nil { |
||||
return newError("failed to create virtual tun interface").Base(err) |
||||
} |
||||
h.bind = bind |
||||
return nil |
||||
} |
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { |
||||
outbound := session.OutboundFromContext(ctx) |
||||
if outbound == nil || !outbound.Target.IsValid() { |
||||
return newError("target not specified") |
||||
} |
||||
outbound.Name = "wireguard" |
||||
inbound := session.InboundFromContext(ctx) |
||||
if inbound != nil { |
||||
inbound.SetCanSpliceCopy(3) |
||||
} |
||||
|
||||
if err := h.processWireGuard(dialer); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Destination of the inner request.
|
||||
destination := outbound.Target |
||||
command := protocol.RequestCommandTCP |
||||
if destination.Network == net.Network_UDP { |
||||
command = protocol.RequestCommandUDP |
||||
} |
||||
|
||||
// resolve dns
|
||||
addr := destination.Address |
||||
if addr.Family().IsDomain() { |
||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ |
||||
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), |
||||
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), |
||||
Severity: log.Severity_Debug, |
||||
Content: fmt.Sprintf(format, args...), |
||||
}) |
||||
{ // Resolve fallback
|
||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { |
||||
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ |
||||
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), |
||||
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), |
||||
}, |
||||
Errorf: func(format string, args ...any) { |
||||
log.Record(&log.GeneralMessage{ |
||||
Severity: log.Severity_Error, |
||||
Content: fmt.Sprintf(format, args...), |
||||
}) |
||||
} |
||||
} |
||||
if err != nil { |
||||
return newError("failed to lookup DNS").Base(err) |
||||
} else if len(ips) == 0 { |
||||
return dns.ErrEmptyResponse |
||||
} |
||||
addr = net.IPAddress(ips[dice.Roll(len(ips))]) |
||||
}, |
||||
} |
||||
|
||||
var newCtx context.Context |
||||
var newCancel context.CancelFunc |
||||
if session.TimeoutOnlyFromContext(ctx) { |
||||
newCtx, newCancel = context.WithCancel(context.Background()) |
||||
func init() { |
||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { |
||||
deviceConfig := config.(*DeviceConfig) |
||||
if deviceConfig.IsClient { |
||||
return New(ctx, deviceConfig) |
||||
} else { |
||||
return NewServer(ctx, deviceConfig) |
||||
} |
||||
|
||||
p := h.policyManager.ForLevel(0) |
||||
|
||||
ctx, cancel := context.WithCancel(ctx) |
||||
timer := signal.CancelAfterInactivity(ctx, func() { |
||||
cancel() |
||||
if newCancel != nil { |
||||
newCancel() |
||||
})) |
||||
} |
||||
}, p.Timeouts.ConnectionIdle) |
||||
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) |
||||
|
||||
var requestFunc func() error |
||||
var responseFunc func() error |
||||
// convert endpoint string to netip.Addr
|
||||
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) { |
||||
var hasIPv4, hasIPv6 bool |
||||
|
||||
if command == protocol.RequestCommandTCP { |
||||
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) |
||||
endpoints := make([]netip.Addr, len(conf.Endpoint)) |
||||
for i, str := range conf.Endpoint { |
||||
var addr netip.Addr |
||||
if strings.Contains(str, "/") { |
||||
prefix, err := netip.ParsePrefix(str) |
||||
if err != nil { |
||||
return newError("failed to create TCP connection").Base(err) |
||||
return nil, false, false, err |
||||
} |
||||
defer conn.Close() |
||||
|
||||
requestFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) |
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) |
||||
} |
||||
responseFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly) |
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) |
||||
addr = prefix.Addr() |
||||
if prefix.Bits() != addr.BitLen() { |
||||
return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") |
||||
} |
||||
} else if command == protocol.RequestCommandUDP { |
||||
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) |
||||
} else { |
||||
var err error |
||||
addr, err = netip.ParseAddr(str) |
||||
if err != nil { |
||||
return newError("failed to create UDP connection").Base(err) |
||||
} |
||||
defer conn.Close() |
||||
|
||||
requestFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) |
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) |
||||
} |
||||
responseFunc = func() error { |
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly) |
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) |
||||
return nil, false, false, err |
||||
} |
||||
} |
||||
endpoints[i] = addr |
||||
|
||||
if newCtx != nil { |
||||
ctx = newCtx |
||||
if addr.Is4() { |
||||
hasIPv4 = true |
||||
} else if addr.Is6() { |
||||
hasIPv6 = true |
||||
} |
||||
|
||||
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) |
||||
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { |
||||
common.Interrupt(link.Reader) |
||||
common.Interrupt(link.Writer) |
||||
return newError("connection ends").Base(err) |
||||
} |
||||
|
||||
return nil |
||||
return endpoints, hasIPv4, hasIPv6, nil |
||||
} |
||||
|
||||
// serialize the config into an IPC request
|
||||
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string { |
||||
var request bytes.Buffer |
||||
func createIPCRequest(conf *DeviceConfig) string { |
||||
var request strings.Builder |
||||
|
||||
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) |
||||
|
||||
for _, peer := range conf.Peers { |
||||
endpoint := peer.Endpoint |
||||
host, port, err := net.SplitHostPort(endpoint) |
||||
if resolveEndPointToV4 && err == nil { |
||||
_, err = netip.ParseAddr(host) |
||||
if err != nil { |
||||
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false}) |
||||
if err == nil && len(ipList) > 0 { |
||||
endpoint = stdnet.JoinHostPort(ipList[0].String(), port) |
||||
} |
||||
} |
||||
} |
||||
|
||||
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", |
||||
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey)) |
||||
|
||||
for _, ip := range peer.AllowedIps { |
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) |
||||
} |
||||
if !conf.IsClient { |
||||
// placeholder, we'll handle actual port listening on Xray
|
||||
request.WriteString("listen_port=1337\n") |
||||
} |
||||
|
||||
return request.String()[:request.Len()] |
||||
for _, peer := range conf.Peers { |
||||
if peer.PublicKey != "" { |
||||
request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey)) |
||||
} |
||||
|
||||
// convert endpoint string to netip.Addr
|
||||
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { |
||||
endpoints := make([]netip.Addr, len(conf.Endpoint)) |
||||
for i, str := range conf.Endpoint { |
||||
var addr netip.Addr |
||||
if strings.Contains(str, "/") { |
||||
prefix, err := netip.ParsePrefix(str) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
addr = prefix.Addr() |
||||
if prefix.Bits() != addr.BitLen() { |
||||
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") |
||||
} |
||||
} else { |
||||
var err error |
||||
addr, err = netip.ParseAddr(str) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
endpoints[i] = addr |
||||
if peer.PreSharedKey != "" { |
||||
request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey)) |
||||
} |
||||
|
||||
return endpoints, nil |
||||
if peer.Endpoint != "" { |
||||
request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint)) |
||||
} |
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { |
||||
t, err := CreateTun(h.endpoints, int(h.conf.Mtu)) |
||||
if err != nil { |
||||
return nil, err |
||||
for _, ip := range peer.AllowedIps { |
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) |
||||
} |
||||
|
||||
bind.dnsOption.IPv4Enable = h.hasIPv4 |
||||
bind.dnsOption.IPv6Enable = h.hasIPv6 |
||||
|
||||
if err = t.BuildDevice(h.ipc, bind); err != nil { |
||||
_ = t.Close() |
||||
return nil, err |
||||
if peer.KeepAlive != 0 { |
||||
request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) |
||||
} |
||||
return t, nil |
||||
} |
||||
|
||||
func init() { |
||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { |
||||
return New(ctx, config.(*DeviceConfig)) |
||||
})) |
||||
return request.String()[:request.Len()] |
||||
} |
||||
|
Loading…
Reference in new issue