mirror of https://github.com/XTLS/Xray-core
248 lines
5.8 KiB
Go
248 lines
5.8 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/xtls/xray-core/common"
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"github.com/xtls/xray-core/common/log"
|
|
"github.com/xtls/xray-core/common/net"
|
|
"github.com/xtls/xray-core/common/session"
|
|
"github.com/xtls/xray-core/core"
|
|
dns_feature "github.com/xtls/xray-core/features/dns"
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
)
|
|
|
|
// Fqdn normalizes domain make sure it ends with '.'
|
|
func Fqdn(domain string) string {
|
|
if len(domain) > 0 && strings.HasSuffix(domain, ".") {
|
|
return domain
|
|
}
|
|
return domain + "."
|
|
}
|
|
|
|
type record struct {
|
|
A *IPRecord
|
|
AAAA *IPRecord
|
|
}
|
|
|
|
// IPRecord is a cacheable item for a resolved domain
|
|
type IPRecord struct {
|
|
ReqID uint16
|
|
IP []net.Address
|
|
Expire time.Time
|
|
RCode dnsmessage.RCode
|
|
}
|
|
|
|
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
|
if r == nil || r.Expire.Before(time.Now()) {
|
|
return nil, errRecordNotFound
|
|
}
|
|
if r.RCode != dnsmessage.RCodeSuccess {
|
|
return nil, dns_feature.RCodeError(r.RCode)
|
|
}
|
|
return r.IP, nil
|
|
}
|
|
|
|
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
|
if newRec == nil {
|
|
return false
|
|
}
|
|
if baseRec == nil {
|
|
return true
|
|
}
|
|
return baseRec.Expire.Before(newRec.Expire)
|
|
}
|
|
|
|
var errRecordNotFound = errors.New("record not found")
|
|
|
|
type dnsRequest struct {
|
|
reqType dnsmessage.Type
|
|
domain string
|
|
start time.Time
|
|
expire time.Time
|
|
msg *dnsmessage.Message
|
|
}
|
|
|
|
func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
|
|
if len(clientIP) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var netmask int
|
|
var family uint16
|
|
|
|
if len(clientIP) == 4 {
|
|
family = 1
|
|
netmask = 24 // 24 for IPV4, 96 for IPv6
|
|
} else {
|
|
family = 2
|
|
netmask = 96
|
|
}
|
|
|
|
b := make([]byte, 4)
|
|
binary.BigEndian.PutUint16(b[0:], family)
|
|
b[2] = byte(netmask)
|
|
b[3] = 0
|
|
switch family {
|
|
case 1:
|
|
ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
|
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
b = append(b, ip[:needLength]...)
|
|
case 2:
|
|
ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
|
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
b = append(b, ip[:needLength]...)
|
|
}
|
|
|
|
const EDNS0SUBNET = 0x08
|
|
|
|
opt := new(dnsmessage.Resource)
|
|
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
|
|
|
opt.Body = &dnsmessage.OPTResource{
|
|
Options: []dnsmessage.Option{
|
|
{
|
|
Code: EDNS0SUBNET,
|
|
Data: b,
|
|
},
|
|
},
|
|
}
|
|
|
|
return opt
|
|
}
|
|
|
|
func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
|
|
qA := dnsmessage.Question{
|
|
Name: dnsmessage.MustNewName(domain),
|
|
Type: dnsmessage.TypeA,
|
|
Class: dnsmessage.ClassINET,
|
|
}
|
|
|
|
qAAAA := dnsmessage.Question{
|
|
Name: dnsmessage.MustNewName(domain),
|
|
Type: dnsmessage.TypeAAAA,
|
|
Class: dnsmessage.ClassINET,
|
|
}
|
|
|
|
var reqs []*dnsRequest
|
|
now := time.Now()
|
|
|
|
if option.IPv4Enable {
|
|
msg := new(dnsmessage.Message)
|
|
msg.Header.ID = reqIDGen()
|
|
msg.Header.RecursionDesired = true
|
|
msg.Questions = []dnsmessage.Question{qA}
|
|
if reqOpts != nil {
|
|
msg.Additionals = append(msg.Additionals, *reqOpts)
|
|
}
|
|
reqs = append(reqs, &dnsRequest{
|
|
reqType: dnsmessage.TypeA,
|
|
domain: domain,
|
|
start: now,
|
|
msg: msg,
|
|
})
|
|
}
|
|
|
|
if option.IPv6Enable {
|
|
msg := new(dnsmessage.Message)
|
|
msg.Header.ID = reqIDGen()
|
|
msg.Header.RecursionDesired = true
|
|
msg.Questions = []dnsmessage.Question{qAAAA}
|
|
if reqOpts != nil {
|
|
msg.Additionals = append(msg.Additionals, *reqOpts)
|
|
}
|
|
reqs = append(reqs, &dnsRequest{
|
|
reqType: dnsmessage.TypeAAAA,
|
|
domain: domain,
|
|
start: now,
|
|
msg: msg,
|
|
})
|
|
}
|
|
|
|
return reqs
|
|
}
|
|
|
|
// parseResponse parses DNS answers from the returned payload
|
|
func parseResponse(payload []byte) (*IPRecord, error) {
|
|
var parser dnsmessage.Parser
|
|
h, err := parser.Start(payload)
|
|
if err != nil {
|
|
return nil, errors.New("failed to parse DNS response").Base(err).AtWarning()
|
|
}
|
|
if err := parser.SkipAllQuestions(); err != nil {
|
|
return nil, errors.New("failed to skip questions in DNS response").Base(err).AtWarning()
|
|
}
|
|
|
|
now := time.Now()
|
|
ipRecord := &IPRecord{
|
|
ReqID: h.ID,
|
|
RCode: h.RCode,
|
|
Expire: now.Add(time.Second * 600),
|
|
}
|
|
|
|
L:
|
|
for {
|
|
ah, err := parser.AnswerHeader()
|
|
if err != nil {
|
|
if err != dnsmessage.ErrSectionDone {
|
|
errors.LogInfoInner(context.Background(), err, "failed to parse answer section for domain: ", ah.Name.String())
|
|
}
|
|
break
|
|
}
|
|
|
|
ttl := ah.TTL
|
|
if ttl == 0 {
|
|
ttl = 600
|
|
}
|
|
expire := now.Add(time.Duration(ttl) * time.Second)
|
|
if ipRecord.Expire.After(expire) {
|
|
ipRecord.Expire = expire
|
|
}
|
|
|
|
switch ah.Type {
|
|
case dnsmessage.TypeA:
|
|
ans, err := parser.AResource()
|
|
if err != nil {
|
|
errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
|
|
break L
|
|
}
|
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
|
case dnsmessage.TypeAAAA:
|
|
ans, err := parser.AAAAResource()
|
|
if err != nil {
|
|
errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
|
|
break L
|
|
}
|
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
|
default:
|
|
if err := parser.SkipAnswer(); err != nil {
|
|
errors.LogInfoInner(context.Background(), err, "failed to skip answer")
|
|
break L
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
return ipRecord, nil
|
|
}
|
|
|
|
// toDnsContext create a new background context with parent inbound, session and dns log
|
|
func toDnsContext(ctx context.Context, addr string) context.Context {
|
|
dnsCtx := core.ToBackgroundDetachedContext(ctx)
|
|
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
|
dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
|
|
}
|
|
dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx))
|
|
dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{
|
|
From: "DNS",
|
|
To: addr,
|
|
Status: log.AccessAccepted,
|
|
Reason: "",
|
|
})
|
|
return dnsCtx
|
|
}
|