@ -22,7 +22,9 @@ package wireguard
import (
import (
"context"
"context"
"fmt"
"net/netip"
"net/netip"
"strings"
"sync"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common"
@ -49,7 +51,6 @@ type Handler struct {
policyManager policy . Manager
policyManager policy . Manager
dns dns . Client
dns dns . Client
// cached configuration
// cached configuration
ipc string
endpoints [ ] netip . Addr
endpoints [ ] netip . Addr
hasIPv4 , hasIPv6 bool
hasIPv4 , hasIPv6 bool
wgLock sync . Mutex
wgLock sync . Mutex
@ -69,7 +70,6 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
conf : conf ,
conf : conf ,
policyManager : v . GetFeature ( policy . ManagerType ( ) ) . ( policy . Manager ) ,
policyManager : v . GetFeature ( policy . ManagerType ( ) ) . ( policy . Manager ) ,
dns : d ,
dns : d ,
ipc : createIPCRequest ( conf ) ,
endpoints : endpoints ,
endpoints : endpoints ,
hasIPv4 : hasIPv4 ,
hasIPv4 : hasIPv4 ,
hasIPv6 : hasIPv6 ,
hasIPv6 : hasIPv6 ,
@ -247,9 +247,76 @@ func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
bind . dnsOption . IPv4Enable = h . hasIPv4
bind . dnsOption . IPv4Enable = h . hasIPv4
bind . dnsOption . IPv6Enable = h . hasIPv6
bind . dnsOption . IPv6Enable = h . hasIPv6
if err = t . BuildDevice ( h . ipc , bind ) ; err != nil {
if err = t . BuildDevice ( h . createIPCRequest ( bind , h . conf ) , bind ) ; err != nil {
_ = t . Close ( )
_ = t . Close ( )
return nil , err
return nil , err
}
}
return t , nil
return t , nil
}
}
// serialize the config into an IPC request
func ( h * Handler ) createIPCRequest ( bind * netBindClient , conf * DeviceConfig ) string {
var request strings . Builder
request . WriteString ( fmt . Sprintf ( "private_key=%s\n" , conf . SecretKey ) )
if ! conf . IsClient {
// placeholder, we'll handle actual port listening on Xray
request . WriteString ( "listen_port=1337\n" )
}
for _ , peer := range conf . Peers {
if peer . PublicKey != "" {
request . WriteString ( fmt . Sprintf ( "public_key=%s\n" , peer . PublicKey ) )
}
if peer . PreSharedKey != "" {
request . WriteString ( fmt . Sprintf ( "preshared_key=%s\n" , peer . PreSharedKey ) )
}
split := strings . Split ( peer . Endpoint , ":" )
addr := net . ParseAddress ( split [ 0 ] )
if addr . Family ( ) . IsDomain ( ) {
dialerIp := bind . dialer . DestIpAddress ( )
if dialerIp != nil {
addr = net . ParseAddress ( dialerIp . String ( ) )
newError ( "createIPCRequest use dialer dest ip: " , addr ) . WriteToLog ( )
} else {
ips , err := h . dns . LookupIP ( addr . Domain ( ) , dns . IPOption {
IPv4Enable : h . hasIPv4 && h . conf . preferIP4 ( ) ,
IPv6Enable : h . hasIPv6 && h . conf . preferIP6 ( ) ,
} )
{ // Resolve fallback
if ( len ( ips ) == 0 || err != nil ) && h . conf . hasFallback ( ) {
ips , err = h . dns . LookupIP ( addr . Domain ( ) , dns . IPOption {
IPv4Enable : h . hasIPv4 && h . conf . fallbackIP4 ( ) ,
IPv6Enable : h . hasIPv6 && h . conf . fallbackIP6 ( ) ,
} )
}
}
if err != nil {
newError ( "createIPCRequest failed to lookup DNS" ) . Base ( err ) . WriteToLog ( )
} else if len ( ips ) == 0 {
newError ( "createIPCRequest empty lookup DNS" ) . WriteToLog ( )
} else {
addr = net . IPAddress ( ips [ dice . Roll ( len ( ips ) ) ] )
}
}
}
if peer . Endpoint != "" {
request . WriteString ( fmt . Sprintf ( "endpoint=%s:%s\n" , addr , split [ 1 ] ) )
}
for _ , ip := range peer . AllowedIps {
request . WriteString ( fmt . Sprintf ( "allowed_ip=%s\n" , ip ) )
}
if peer . KeepAlive != 0 {
request . WriteString ( fmt . Sprintf ( "persistent_keepalive_interval=%d\n" , peer . KeepAlive ) )
}
}
return request . String ( ) [ : request . Len ( ) ]
}