mirror of https://github.com/XTLS/Xray-core
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
263 lines
7.5 KiB
263 lines
7.5 KiB
/* |
|
|
|
Some of the code is copied from https://github.com/octeep/wireproxy, license below. |
|
|
|
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me> |
|
|
|
Permission to use, copy, modify, and distribute this software for any |
|
purpose with or without fee is hereby granted, provided that the above |
|
copyright notice and this permission notice appear in all copies. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
|
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
|
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR |
|
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
|
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN |
|
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF |
|
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
|
|
|
*/ |
|
|
|
package wireguard |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"fmt" |
|
"net/netip" |
|
"strings" |
|
|
|
"github.com/sagernet/wireguard-go/device" |
|
"github.com/xtls/xray-core/common" |
|
"github.com/xtls/xray-core/common/buf" |
|
"github.com/xtls/xray-core/common/log" |
|
"github.com/xtls/xray-core/common/net" |
|
"github.com/xtls/xray-core/common/protocol" |
|
"github.com/xtls/xray-core/common/session" |
|
"github.com/xtls/xray-core/common/signal" |
|
"github.com/xtls/xray-core/common/task" |
|
"github.com/xtls/xray-core/core" |
|
"github.com/xtls/xray-core/features/dns" |
|
"github.com/xtls/xray-core/features/policy" |
|
"github.com/xtls/xray-core/transport" |
|
"github.com/xtls/xray-core/transport/internet" |
|
) |
|
|
|
// Handler is an outbound connection that silently swallow the entire payload. |
|
type Handler struct { |
|
conf *DeviceConfig |
|
net *Net |
|
bind *netBindClient |
|
policyManager policy.Manager |
|
dns dns.Client |
|
// cached configuration |
|
ipc string |
|
endpoints []netip.Addr |
|
} |
|
|
|
// New creates a new wireguard handler. |
|
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { |
|
v := core.MustFromContext(ctx) |
|
|
|
endpoints, err := parseEndpoints(conf) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return &Handler{ |
|
conf: conf, |
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), |
|
dns: v.GetFeature(dns.ClientType()).(dns.Client), |
|
ipc: createIPCRequest(conf), |
|
endpoints: endpoints, |
|
}, nil |
|
} |
|
|
|
// Process implements OutboundHandler.Dispatch(). |
|
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { |
|
if h.bind == nil || h.bind.dialer != dialer || h.net == nil { |
|
log.Record(&log.GeneralMessage{ |
|
Severity: log.Severity_Info, |
|
Content: "switching dialer", |
|
}) |
|
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer |
|
bind := &netBindClient{ |
|
dialer: dialer, |
|
workers: int(h.conf.NumWorkers), |
|
dns: h.dns, |
|
} |
|
|
|
net, err := h.makeVirtualTun(bind) |
|
if err != nil { |
|
bind.Close() |
|
return newError("failed to create virtual tun interface").Base(err) |
|
} |
|
|
|
h.net = net |
|
if h.bind != nil { |
|
h.bind.Close() |
|
} |
|
h.bind = bind |
|
} |
|
|
|
outbound := session.OutboundFromContext(ctx) |
|
if outbound == nil || !outbound.Target.IsValid() { |
|
return newError("target not specified") |
|
} |
|
// Destination of the inner request. |
|
destination := outbound.Target |
|
command := protocol.RequestCommandTCP |
|
if destination.Network == net.Network_UDP { |
|
command = protocol.RequestCommandUDP |
|
} |
|
|
|
// resolve dns |
|
addr := destination.Address |
|
if addr.Family().IsDomain() { |
|
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ |
|
IPv4Enable: h.net.HasV4(), |
|
IPv6Enable: h.net.HasV6(), |
|
}) |
|
if err != nil { |
|
return newError("failed to lookup DNS").Base(err) |
|
} else if len(ips) == 0 { |
|
return dns.ErrEmptyResponse |
|
} |
|
addr = net.IPAddress(ips[0]) |
|
} |
|
|
|
p := h.policyManager.ForLevel(0) |
|
|
|
ctx, cancel := context.WithCancel(ctx) |
|
timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) |
|
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) |
|
|
|
var requestFunc func() error |
|
var responseFunc func() error |
|
|
|
if command == protocol.RequestCommandTCP { |
|
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) |
|
if err != nil { |
|
return newError("failed to create TCP connection").Base(err) |
|
} |
|
|
|
requestFunc = func() error { |
|
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) |
|
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) |
|
} |
|
responseFunc = func() error { |
|
defer timer.SetTimeout(p.Timeouts.UplinkOnly) |
|
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) |
|
} |
|
} else if command == protocol.RequestCommandUDP { |
|
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) |
|
if err != nil { |
|
return newError("failed to create UDP connection").Base(err) |
|
} |
|
|
|
requestFunc = func() error { |
|
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) |
|
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) |
|
} |
|
responseFunc = func() error { |
|
defer timer.SetTimeout(p.Timeouts.UplinkOnly) |
|
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) |
|
} |
|
} |
|
|
|
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) |
|
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { |
|
return newError("connection ends").Base(err) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// serialize the config into an IPC request |
|
func createIPCRequest(conf *DeviceConfig) string { |
|
var request bytes.Buffer |
|
|
|
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) |
|
|
|
for _, peer := range conf.Peers { |
|
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", |
|
peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey)) |
|
|
|
for _, ip := range peer.AllowedIps { |
|
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) |
|
} |
|
} |
|
|
|
return request.String()[:request.Len()] |
|
} |
|
|
|
// convert endpoint string to netip.Addr |
|
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { |
|
endpoints := make([]netip.Addr, len(conf.Endpoint)) |
|
for i, str := range conf.Endpoint { |
|
var addr netip.Addr |
|
if strings.Contains(str, "/") { |
|
prefix, err := netip.ParsePrefix(str) |
|
if err != nil { |
|
return nil, err |
|
} |
|
addr = prefix.Addr() |
|
if prefix.Bits() != addr.BitLen() { |
|
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") |
|
} |
|
} else { |
|
var err error |
|
addr, err = netip.ParseAddr(str) |
|
if err != nil { |
|
return nil, err |
|
} |
|
} |
|
endpoints[i] = addr |
|
} |
|
|
|
return endpoints, nil |
|
} |
|
|
|
// creates a tun interface on netstack given a configuration |
|
func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) { |
|
tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu)) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
bind.dnsOption.IPv4Enable = tnet.HasV4() |
|
bind.dnsOption.IPv6Enable = tnet.HasV6() |
|
|
|
// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */) |
|
dev := device.NewDevice(tun, bind, &device.Logger{ |
|
Verbosef: func(format string, args ...any) { |
|
log.Record(&log.GeneralMessage{ |
|
Severity: log.Severity_Debug, |
|
Content: fmt.Sprintf(format, args...), |
|
}) |
|
}, |
|
Errorf: func(format string, args ...any) { |
|
log.Record(&log.GeneralMessage{ |
|
Severity: log.Severity_Error, |
|
Content: fmt.Sprintf(format, args...), |
|
}) |
|
}, |
|
}, int(h.conf.NumWorkers)) |
|
err = dev.IpcSet(h.ipc) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
err = dev.Up() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return tnet, nil |
|
} |
|
|
|
func init() { |
|
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { |
|
return New(ctx, config.(*DeviceConfig)) |
|
})) |
|
}
|
|
|