|
|
|
@ -5,36 +5,60 @@ package dns
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"encoding/binary" |
|
|
|
|
fmt "fmt" |
|
|
|
|
"sync" |
|
|
|
|
"sync/atomic" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"golang.org/x/net/dns/dnsmessage" |
|
|
|
|
"v2ray.com/core/common" |
|
|
|
|
"v2ray.com/core/common/errors" |
|
|
|
|
"v2ray.com/core/common/net" |
|
|
|
|
"v2ray.com/core/common/protocol/dns" |
|
|
|
|
udp_proto "v2ray.com/core/common/protocol/udp" |
|
|
|
|
"v2ray.com/core/common/session" |
|
|
|
|
"v2ray.com/core/common/signal/pubsub" |
|
|
|
|
"v2ray.com/core/common/task" |
|
|
|
|
dns_feature "v2ray.com/core/features/dns" |
|
|
|
|
"v2ray.com/core/features/routing" |
|
|
|
|
"v2ray.com/core/transport/internet/udp" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type record struct { |
|
|
|
|
A *IPRecord |
|
|
|
|
AAAA *IPRecord |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type IPRecord struct { |
|
|
|
|
IP net.Address |
|
|
|
|
IP []net.Address |
|
|
|
|
Expire time.Time |
|
|
|
|
RCode dnsmessage.RCode |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (r *IPRecord) getIPs() ([]net.Address, error) { |
|
|
|
|
if r == nil || r.Expire.Before(time.Now()) { |
|
|
|
|
return nil, errRecordNotFound |
|
|
|
|
} |
|
|
|
|
if r.RCode != dnsmessage.RCodeSuccess { |
|
|
|
|
return nil, dns_feature.RCodeError(r.RCode) |
|
|
|
|
} |
|
|
|
|
return r.IP, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type pendingRequest struct { |
|
|
|
|
domain string |
|
|
|
|
expire time.Time |
|
|
|
|
domain string |
|
|
|
|
expire time.Time |
|
|
|
|
recType dnsmessage.Type |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
|
errRecordNotFound = errors.New("record not found") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type ClassicNameServer struct { |
|
|
|
|
sync.RWMutex |
|
|
|
|
address net.Destination |
|
|
|
|
ips map[string][]IPRecord |
|
|
|
|
ips map[string]record |
|
|
|
|
requests map[uint16]pendingRequest |
|
|
|
|
pub *pubsub.Service |
|
|
|
|
udpServer *udp.Dispatcher |
|
|
|
@ -46,7 +70,7 @@ type ClassicNameServer struct {
|
|
|
|
|
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer { |
|
|
|
|
s := &ClassicNameServer{ |
|
|
|
|
address: address, |
|
|
|
|
ips: make(map[string][]IPRecord), |
|
|
|
|
ips: make(map[string]record), |
|
|
|
|
requests: make(map[uint16]pendingRequest), |
|
|
|
|
clientIP: clientIP, |
|
|
|
|
pub: pubsub.NewService(), |
|
|
|
@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error {
|
|
|
|
|
return newError("nothing to do. stopping...") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for domain, ips := range s.ips { |
|
|
|
|
newIPs := make([]IPRecord, 0, len(ips)) |
|
|
|
|
for _, ip := range ips { |
|
|
|
|
if ip.Expire.After(now) { |
|
|
|
|
newIPs = append(newIPs, ip) |
|
|
|
|
} |
|
|
|
|
for domain, record := range s.ips { |
|
|
|
|
if record.A != nil && record.A.Expire.Before(now) { |
|
|
|
|
record.A = nil |
|
|
|
|
} |
|
|
|
|
if len(newIPs) == 0 { |
|
|
|
|
if record.AAAA != nil && record.AAAA.Expire.Before(now) { |
|
|
|
|
record.AAAA = nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if record.A == nil && record.AAAA == nil { |
|
|
|
|
delete(s.ips, domain) |
|
|
|
|
} else if len(newIPs) < len(ips) { |
|
|
|
|
s.ips[domain] = newIPs |
|
|
|
|
} else { |
|
|
|
|
s.ips[domain] = record |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if len(s.ips) == 0 { |
|
|
|
|
s.ips = make(map[string][]IPRecord) |
|
|
|
|
s.ips = make(map[string]record) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for id, req := range s.requests { |
|
|
|
@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
domain := req.domain |
|
|
|
|
ips := make([]IPRecord, 0, 16) |
|
|
|
|
recType := req.recType |
|
|
|
|
|
|
|
|
|
now := time.Now() |
|
|
|
|
ipRecord := &IPRecord{ |
|
|
|
|
RCode: header.RCode, |
|
|
|
|
Expire: now.Add(time.Second * 600), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for { |
|
|
|
|
header, err := parser.AnswerHeader() |
|
|
|
|
if err != nil { |
|
|
|
@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
|
|
if ttl == 0 { |
|
|
|
|
ttl = 600 |
|
|
|
|
} |
|
|
|
|
expire := now.Add(time.Duration(ttl) * time.Second) |
|
|
|
|
if ipRecord.Expire.After(expire) { |
|
|
|
|
ipRecord.Expire = expire |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if header.Type != recType { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
switch header.Type { |
|
|
|
|
case dnsmessage.TypeA: |
|
|
|
|
ans, err := parser.AResource() |
|
|
|
@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
|
|
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
ips = append(ips, IPRecord{ |
|
|
|
|
IP: net.IPAddress(ans.A[:]), |
|
|
|
|
Expire: now.Add(time.Duration(ttl) * time.Second), |
|
|
|
|
}) |
|
|
|
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) |
|
|
|
|
case dnsmessage.TypeAAAA: |
|
|
|
|
ans, err := parser.AAAAResource() |
|
|
|
|
if err != nil { |
|
|
|
|
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
ips = append(ips, IPRecord{ |
|
|
|
|
IP: net.IPAddress(ans.AAAA[:]), |
|
|
|
|
Expire: now.Add(time.Duration(ttl) * time.Second), |
|
|
|
|
}) |
|
|
|
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) |
|
|
|
|
default: |
|
|
|
|
if err := parser.SkipAnswer(); err != nil { |
|
|
|
|
newError("failed to skip answer").Base(err).WriteToLog() |
|
|
|
@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if len(domain) > 0 && len(ips) > 0 { |
|
|
|
|
s.updateIP(domain, ips) |
|
|
|
|
var rec record |
|
|
|
|
switch recType { |
|
|
|
|
case dnsmessage.TypeA: |
|
|
|
|
rec.A = ipRecord |
|
|
|
|
case dnsmessage.TypeAAAA: |
|
|
|
|
rec.AAAA = ipRecord |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) { |
|
|
|
|
s.updateIP(domain, rec) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { |
|
|
|
|
if newRec == nil { |
|
|
|
|
return false |
|
|
|
|
} |
|
|
|
|
if baseRec == nil { |
|
|
|
|
return true |
|
|
|
|
} |
|
|
|
|
return baseRec.Expire.Before(newRec.Expire) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { |
|
|
|
|
func (s *ClassicNameServer) updateIP(domain string, newRec record) { |
|
|
|
|
s.Lock() |
|
|
|
|
|
|
|
|
|
newError("updating IP records for domain:", domain).AtDebug().WriteToLog() |
|
|
|
|
now := time.Now() |
|
|
|
|
eips := s.ips[domain] |
|
|
|
|
for _, ip := range eips { |
|
|
|
|
if ip.Expire.After(now) { |
|
|
|
|
ips = append(ips, ip) |
|
|
|
|
} |
|
|
|
|
rec := s.ips[domain] |
|
|
|
|
|
|
|
|
|
updated := false |
|
|
|
|
if isNewer(rec.A, newRec.A) { |
|
|
|
|
rec.A = newRec.A |
|
|
|
|
updated = true |
|
|
|
|
} |
|
|
|
|
if isNewer(rec.AAAA, newRec.AAAA) { |
|
|
|
|
rec.AAAA = newRec.AAAA |
|
|
|
|
updated = true |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if updated { |
|
|
|
|
s.ips[domain] = rec |
|
|
|
|
s.pub.Publish(domain, nil) |
|
|
|
|
} |
|
|
|
|
s.ips[domain] = ips |
|
|
|
|
s.pub.Publish(domain, nil) |
|
|
|
|
|
|
|
|
|
s.Unlock() |
|
|
|
|
common.Must(s.cleanup.Start()) |
|
|
|
@ -244,14 +302,15 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
|
|
|
|
return opt |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) addPendingRequest(domain string) uint16 { |
|
|
|
|
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 { |
|
|
|
|
id := uint16(atomic.AddUint32(&s.reqID, 1)) |
|
|
|
|
s.Lock() |
|
|
|
|
defer s.Unlock() |
|
|
|
|
|
|
|
|
|
s.requests[id] = pendingRequest{ |
|
|
|
|
domain: domain, |
|
|
|
|
expire: time.Now().Add(time.Second * 8), |
|
|
|
|
domain: domain, |
|
|
|
|
expire: time.Now().Add(time.Second * 8), |
|
|
|
|
recType: recType, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return id |
|
|
|
@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
|
|
|
|
|
|
|
|
|
|
if option.IPv4Enable { |
|
|
|
|
msg := new(dnsmessage.Message) |
|
|
|
|
msg.Header.ID = s.addPendingRequest(domain) |
|
|
|
|
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA) |
|
|
|
|
msg.Header.RecursionDesired = true |
|
|
|
|
msg.Questions = []dnsmessage.Question{qA} |
|
|
|
|
if opt := s.getMsgOptions(); opt != nil { |
|
|
|
@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
|
|
|
|
|
|
|
|
|
|
if option.IPv6Enable { |
|
|
|
|
msg := new(dnsmessage.Message) |
|
|
|
|
msg.Header.ID = s.addPendingRequest(domain) |
|
|
|
|
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA) |
|
|
|
|
msg.Header.RecursionDesired = true |
|
|
|
|
msg.Questions = []dnsmessage.Question{qAAAA} |
|
|
|
|
if opt := s.getMsgOptions(); opt != nil { |
|
|
|
@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP { |
|
|
|
|
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) { |
|
|
|
|
s.RLock() |
|
|
|
|
records, found := s.ips[domain] |
|
|
|
|
record, found := s.ips[domain] |
|
|
|
|
s.RUnlock() |
|
|
|
|
|
|
|
|
|
if found && len(records) > 0 { |
|
|
|
|
var ips []net.Address |
|
|
|
|
now := time.Now() |
|
|
|
|
for _, rec := range records { |
|
|
|
|
if rec.Expire.After(now) { |
|
|
|
|
ips = append(ips, rec.IP) |
|
|
|
|
} |
|
|
|
|
if !found { |
|
|
|
|
return nil, errRecordNotFound |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var ips []net.Address |
|
|
|
|
var lastErr error |
|
|
|
|
if option.IPv4Enable { |
|
|
|
|
a, err := record.A.getIPs() |
|
|
|
|
if err != nil { |
|
|
|
|
lastErr = err |
|
|
|
|
} |
|
|
|
|
return toNetIP(filterIP(ips, option)) |
|
|
|
|
ips = append(ips, a...) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
|
|
|
|
|
if option.IPv6Enable { |
|
|
|
|
aaaa, err := record.AAAA.getIPs() |
|
|
|
|
if err != nil { |
|
|
|
|
lastErr = err |
|
|
|
|
} |
|
|
|
|
ips = append(ips, aaaa...) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fmt.Println("IPs for ", domain, ": ", ips) |
|
|
|
|
|
|
|
|
|
if len(ips) > 0 { |
|
|
|
|
return toNetIP(ips), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if lastErr != nil { |
|
|
|
|
return nil, lastErr |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return nil, dns_feature.ErrEmptyResponse |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func Fqdn(domain string) string { |
|
|
|
@ -341,9 +422,9 @@ func Fqdn(domain string) string {
|
|
|
|
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { |
|
|
|
|
fqdn := Fqdn(domain) |
|
|
|
|
|
|
|
|
|
ips := s.findIPsForDomain(fqdn, option) |
|
|
|
|
if len(ips) > 0 { |
|
|
|
|
return ips, nil |
|
|
|
|
ips, err := s.findIPsForDomain(fqdn, option) |
|
|
|
|
if err != errRecordNotFound { |
|
|
|
|
return ips, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
sub := s.pub.Subscribe(fqdn) |
|
|
|
@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I
|
|
|
|
|
s.sendQuery(ctx, fqdn, option) |
|
|
|
|
|
|
|
|
|
for { |
|
|
|
|
ips := s.findIPsForDomain(fqdn, option) |
|
|
|
|
if len(ips) > 0 { |
|
|
|
|
return ips, nil |
|
|
|
|
ips, err := s.findIPsForDomain(fqdn, option) |
|
|
|
|
if err != errRecordNotFound { |
|
|
|
|
return ips, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
select { |
|
|
|
|