mirror of https://github.com/XTLS/Xray-core
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
380 lines
8.5 KiB
380 lines
8.5 KiB
package dns |
|
|
|
import ( |
|
"context" |
|
"io" |
|
"sync" |
|
"time" |
|
|
|
"github.com/xtls/xray-core/common" |
|
"github.com/xtls/xray-core/common/buf" |
|
"github.com/xtls/xray-core/common/errors" |
|
"github.com/xtls/xray-core/common/net" |
|
dns_proto "github.com/xtls/xray-core/common/protocol/dns" |
|
"github.com/xtls/xray-core/common/session" |
|
"github.com/xtls/xray-core/common/signal" |
|
"github.com/xtls/xray-core/common/task" |
|
"github.com/xtls/xray-core/core" |
|
"github.com/xtls/xray-core/features/dns" |
|
"github.com/xtls/xray-core/features/policy" |
|
"github.com/xtls/xray-core/transport" |
|
"github.com/xtls/xray-core/transport/internet" |
|
"github.com/xtls/xray-core/transport/internet/stat" |
|
"golang.org/x/net/dns/dnsmessage" |
|
) |
|
|
|
func init() { |
|
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { |
|
h := new(Handler) |
|
if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { |
|
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { |
|
h.fdns = fdns |
|
}) |
|
return h.Init(config.(*Config), dnsClient, policyManager) |
|
}); err != nil { |
|
return nil, err |
|
} |
|
return h, nil |
|
})) |
|
} |
|
|
|
type ownLinkVerifier interface { |
|
IsOwnLink(ctx context.Context) bool |
|
} |
|
|
|
type Handler struct { |
|
client dns.Client |
|
fdns dns.FakeDNSEngine |
|
ownLinkVerifier ownLinkVerifier |
|
server net.Destination |
|
timeout time.Duration |
|
nonIPQuery string |
|
blockTypes []int32 |
|
} |
|
|
|
func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { |
|
h.client = dnsClient |
|
h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle |
|
|
|
if v, ok := dnsClient.(ownLinkVerifier); ok { |
|
h.ownLinkVerifier = v |
|
} |
|
|
|
if config.Server != nil { |
|
h.server = config.Server.AsDestination() |
|
} |
|
h.nonIPQuery = config.Non_IPQuery |
|
h.blockTypes = config.BlockTypes |
|
return nil |
|
} |
|
|
|
func (h *Handler) isOwnLink(ctx context.Context) bool { |
|
return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) |
|
} |
|
|
|
func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { |
|
var parser dnsmessage.Parser |
|
header, err := parser.Start(b) |
|
if err != nil { |
|
errors.LogInfoInner(context.Background(), err, "parser start") |
|
return |
|
} |
|
|
|
id = header.ID |
|
q, err := parser.Question() |
|
if err != nil { |
|
errors.LogInfoInner(context.Background(), err, "question") |
|
return |
|
} |
|
domain = q.Name.String() |
|
qType = q.Type |
|
if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { |
|
return |
|
} |
|
|
|
r = true |
|
return |
|
} |
|
|
|
// Process implements proxy.Outbound. |
|
func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { |
|
outbounds := session.OutboundsFromContext(ctx) |
|
ob := outbounds[len(outbounds)-1] |
|
if !ob.Target.IsValid() { |
|
return errors.New("invalid outbound") |
|
} |
|
ob.Name = "dns" |
|
|
|
srcNetwork := ob.Target.Network |
|
|
|
dest := ob.Target |
|
if h.server.Network != net.Network_Unknown { |
|
dest.Network = h.server.Network |
|
} |
|
if h.server.Address != nil { |
|
dest.Address = h.server.Address |
|
} |
|
if h.server.Port != 0 { |
|
dest.Port = h.server.Port |
|
} |
|
|
|
errors.LogInfo(ctx, "handling DNS traffic to ", dest) |
|
|
|
conn := &outboundConn{ |
|
dialer: func() (stat.Connection, error) { |
|
return d.Dial(ctx, dest) |
|
}, |
|
connReady: make(chan struct{}, 1), |
|
} |
|
|
|
var reader dns_proto.MessageReader |
|
var writer dns_proto.MessageWriter |
|
if srcNetwork == net.Network_TCP { |
|
reader = dns_proto.NewTCPReader(link.Reader) |
|
writer = &dns_proto.TCPWriter{ |
|
Writer: link.Writer, |
|
} |
|
} else { |
|
reader = &dns_proto.UDPReader{ |
|
Reader: link.Reader, |
|
} |
|
writer = &dns_proto.UDPWriter{ |
|
Writer: link.Writer, |
|
} |
|
} |
|
|
|
var connReader dns_proto.MessageReader |
|
var connWriter dns_proto.MessageWriter |
|
if dest.Network == net.Network_TCP { |
|
connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) |
|
connWriter = &dns_proto.TCPWriter{ |
|
Writer: buf.NewWriter(conn), |
|
} |
|
} else { |
|
connReader = &dns_proto.UDPReader{ |
|
Reader: buf.NewPacketReader(conn), |
|
} |
|
connWriter = &dns_proto.UDPWriter{ |
|
Writer: buf.NewWriter(conn), |
|
} |
|
} |
|
|
|
if session.TimeoutOnlyFromContext(ctx) { |
|
ctx, _ = context.WithCancel(context.Background()) |
|
} |
|
|
|
ctx, cancel := context.WithCancel(ctx) |
|
timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) |
|
|
|
request := func() error { |
|
defer conn.Close() |
|
|
|
for { |
|
b, err := reader.ReadMessage() |
|
if err == io.EOF { |
|
return nil |
|
} |
|
|
|
if err != nil { |
|
return err |
|
} |
|
|
|
timer.Update() |
|
|
|
if !h.isOwnLink(ctx) { |
|
isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) |
|
if len(h.blockTypes) > 0 { |
|
for _, blocktype := range h.blockTypes { |
|
if blocktype == int32(qType) { |
|
errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain) |
|
return nil |
|
} |
|
} |
|
} |
|
if isIPQuery { |
|
go h.handleIPQuery(id, qType, domain, writer) |
|
} |
|
if isIPQuery || h.nonIPQuery == "drop" { |
|
b.Release() |
|
continue |
|
} |
|
} |
|
|
|
if err := connWriter.WriteMessage(b); err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
|
|
response := func() error { |
|
for { |
|
b, err := connReader.ReadMessage() |
|
if err == io.EOF { |
|
return nil |
|
} |
|
|
|
if err != nil { |
|
return err |
|
} |
|
|
|
timer.Update() |
|
|
|
if err := writer.WriteMessage(b); err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
|
|
if err := task.Run(ctx, request, response); err != nil { |
|
return errors.New("connection ends").Base(err) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { |
|
var ips []net.IP |
|
var err error |
|
|
|
var ttl uint32 = 600 |
|
|
|
switch qType { |
|
case dnsmessage.TypeA: |
|
ips, err = h.client.LookupIP(domain, dns.IPOption{ |
|
IPv4Enable: true, |
|
IPv6Enable: false, |
|
FakeEnable: true, |
|
}) |
|
case dnsmessage.TypeAAAA: |
|
ips, err = h.client.LookupIP(domain, dns.IPOption{ |
|
IPv4Enable: false, |
|
IPv6Enable: true, |
|
FakeEnable: true, |
|
}) |
|
} |
|
|
|
rcode := dns.RCodeFromError(err) |
|
if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) { |
|
errors.LogInfoInner(context.Background(), err, "ip query") |
|
return |
|
} |
|
|
|
if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) { |
|
ttl = 1 |
|
} |
|
|
|
switch qType { |
|
case dnsmessage.TypeA: |
|
for i, ip := range ips { |
|
ips[i] = ip.To4() |
|
} |
|
case dnsmessage.TypeAAAA: |
|
for i, ip := range ips { |
|
ips[i] = ip.To16() |
|
} |
|
} |
|
|
|
b := buf.New() |
|
rawBytes := b.Extend(buf.Size) |
|
builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ |
|
ID: id, |
|
RCode: dnsmessage.RCode(rcode), |
|
RecursionAvailable: true, |
|
RecursionDesired: true, |
|
Response: true, |
|
Authoritative: true, |
|
}) |
|
builder.EnableCompression() |
|
common.Must(builder.StartQuestions()) |
|
common.Must(builder.Question(dnsmessage.Question{ |
|
Name: dnsmessage.MustNewName(domain), |
|
Class: dnsmessage.ClassINET, |
|
Type: qType, |
|
})) |
|
common.Must(builder.StartAnswers()) |
|
|
|
rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} |
|
for _, ip := range ips { |
|
if len(ip) == net.IPv4len { |
|
var r dnsmessage.AResource |
|
copy(r.A[:], ip) |
|
common.Must(builder.AResource(rHeader, r)) |
|
} else { |
|
var r dnsmessage.AAAAResource |
|
copy(r.AAAA[:], ip) |
|
common.Must(builder.AAAAResource(rHeader, r)) |
|
} |
|
} |
|
msgBytes, err := builder.Finish() |
|
if err != nil { |
|
errors.LogInfoInner(context.Background(), err, "pack message") |
|
b.Release() |
|
return |
|
} |
|
b.Resize(0, int32(len(msgBytes))) |
|
|
|
if err := writer.WriteMessage(b); err != nil { |
|
errors.LogInfoInner(context.Background(), err, "write IP answer") |
|
} |
|
} |
|
|
|
type outboundConn struct { |
|
access sync.Mutex |
|
dialer func() (stat.Connection, error) |
|
|
|
conn net.Conn |
|
connReady chan struct{} |
|
} |
|
|
|
func (c *outboundConn) dial() error { |
|
conn, err := c.dialer() |
|
if err != nil { |
|
return err |
|
} |
|
c.conn = conn |
|
c.connReady <- struct{}{} |
|
return nil |
|
} |
|
|
|
func (c *outboundConn) Write(b []byte) (int, error) { |
|
c.access.Lock() |
|
|
|
if c.conn == nil { |
|
if err := c.dial(); err != nil { |
|
c.access.Unlock() |
|
errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection") |
|
return len(b), nil |
|
} |
|
} |
|
|
|
c.access.Unlock() |
|
|
|
return c.conn.Write(b) |
|
} |
|
|
|
func (c *outboundConn) Read(b []byte) (int, error) { |
|
var conn net.Conn |
|
c.access.Lock() |
|
conn = c.conn |
|
c.access.Unlock() |
|
|
|
if conn == nil { |
|
_, open := <-c.connReady |
|
if !open { |
|
return 0, io.EOF |
|
} |
|
conn = c.conn |
|
} |
|
|
|
return conn.Read(b) |
|
} |
|
|
|
func (c *outboundConn) Close() error { |
|
c.access.Lock() |
|
close(c.connReady) |
|
if c.conn != nil { |
|
c.conn.Close() |
|
} |
|
c.access.Unlock() |
|
return nil |
|
}
|
|
|