From 9957c64b4aa31f5af85dc2e5e03fec232e88ca82 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 21 Feb 2019 13:43:48 +0100 Subject: [PATCH] correctly propagate dns errors all the way through. the internal dns system can correctly handle the cases where: 1) domain has no A or AAAA records 2) domain doesn't exist fixes #1565 --- app/dns/server.go | 6 + app/dns/server_test.go | 23 ++++ app/dns/udpns.go | 195 ++++++++++++++++++++++---------- features/dns/client.go | 22 ++++ features/dns/localdns/client.go | 9 ++ proxy/dns/dns.go | 9 +- proxy/dns/dns_test.go | 66 ++++++++--- 7 files changed, 252 insertions(+), 78 deletions(-) diff --git a/app/dns/server.go b/app/dns/server.go index 351ed7ef..4ed1ca61 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -226,6 +226,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err if len(ips) > 0 { return ips, nil } + if err == dns.ErrEmptyResponse { + return nil, err + } if err != nil { newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog() lastErr = err @@ -238,6 +241,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err if len(ips) > 0 { return ips, nil } + if err == dns.ErrEmptyResponse { + return nil, err + } if err != nil { newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog() lastErr = err diff --git a/app/dns/server_test.go b/app/dns/server_test.go index 90e2aa2d..a7e1e7b9 100644 --- a/app/dns/server_test.go +++ b/app/dns/server_test.go @@ -60,6 +60,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") common.Must(err) ans.Answer = append(ans.Answer, rr) + } else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA { + ans.MsgHdr.Rcode = dns.RcodeNameError } } w.WriteMsg(ans) @@ -186,6 +188,27 @@ func TestUDPServer(t *testing.T) { } } + { + _, err := client.LookupIP("notexist.google.com") + if err == nil { + t.Fatal("nil error") + } + if r := feature_dns.RCodeFromError(err); r != uint16(dns.RcodeNameError) { + t.Fatal("expected NameError, but got ", r) + } + } + + { + clientv6 := client.(feature_dns.IPv6Lookup) + ips, err := clientv6.LookupIPv6("ipv4only.google.com") + if err != feature_dns.ErrEmptyResponse { + t.Fatal("error: ", err) + } + if len(ips) != 0 { + t.Fatal("ips: ", ips) + } + } + dnsServer.Shutdown() { diff --git a/app/dns/udpns.go b/app/dns/udpns.go index e9c70231..0e2967e5 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -5,36 +5,60 @@ package dns import ( "context" "encoding/binary" + fmt "fmt" "sync" "sync/atomic" "time" "golang.org/x/net/dns/dnsmessage" "v2ray.com/core/common" + "v2ray.com/core/common/errors" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol/dns" udp_proto "v2ray.com/core/common/protocol/udp" "v2ray.com/core/common/session" "v2ray.com/core/common/signal/pubsub" "v2ray.com/core/common/task" + dns_feature "v2ray.com/core/features/dns" "v2ray.com/core/features/routing" "v2ray.com/core/transport/internet/udp" ) +type record struct { + A *IPRecord + AAAA *IPRecord +} + type IPRecord struct { - IP net.Address + IP []net.Address Expire time.Time + RCode dnsmessage.RCode +} + +func (r *IPRecord) getIPs() ([]net.Address, error) { + if r == nil || r.Expire.Before(time.Now()) { + return nil, errRecordNotFound + } + if r.RCode != dnsmessage.RCodeSuccess { + return nil, dns_feature.RCodeError(r.RCode) + } + return r.IP, nil } type pendingRequest struct { - domain string - expire time.Time + domain string + expire time.Time + recType dnsmessage.Type } +var ( + errRecordNotFound = errors.New("record not found") +) + type ClassicNameServer struct { sync.RWMutex address net.Destination - ips map[string][]IPRecord + ips map[string]record requests map[uint16]pendingRequest pub *pubsub.Service udpServer *udp.Dispatcher @@ -46,7 +70,7 @@ type ClassicNameServer struct { func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer { s := &ClassicNameServer{ address: address, - ips: make(map[string][]IPRecord), + ips: make(map[string]record), requests: make(map[uint16]pendingRequest), clientIP: clientIP, pub: pubsub.NewService(), @@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error { return newError("nothing to do. stopping...") } - for domain, ips := range s.ips { - newIPs := make([]IPRecord, 0, len(ips)) - for _, ip := range ips { - if ip.Expire.After(now) { - newIPs = append(newIPs, ip) - } + for domain, record := range s.ips { + if record.A != nil && record.A.Expire.Before(now) { + record.A = nil } - if len(newIPs) == 0 { + if record.AAAA != nil && record.AAAA.Expire.Before(now) { + record.AAAA = nil + } + + if record.A == nil && record.AAAA == nil { delete(s.ips, domain) - } else if len(newIPs) < len(ips) { - s.ips[domain] = newIPs + } else { + s.ips[domain] = record } } if len(s.ips) == 0 { - s.ips = make(map[string][]IPRecord) + s.ips = make(map[string]record) } for id, req := range s.requests { @@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot } domain := req.domain - ips := make([]IPRecord, 0, 16) + recType := req.recType now := time.Now() + ipRecord := &IPRecord{ + RCode: header.RCode, + Expire: now.Add(time.Second * 600), + } + for { header, err := parser.AnswerHeader() if err != nil { @@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot if ttl == 0 { ttl = 600 } + expire := now.Add(time.Duration(ttl) * time.Second) + if ipRecord.Expire.After(expire) { + ipRecord.Expire = expire + } + + if header.Type != recType { + continue + } + switch header.Type { case dnsmessage.TypeA: ans, err := parser.AResource() @@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() break } - ips = append(ips, IPRecord{ - IP: net.IPAddress(ans.A[:]), - Expire: now.Add(time.Duration(ttl) * time.Second), - }) + ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) case dnsmessage.TypeAAAA: ans, err := parser.AAAAResource() if err != nil { newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() break } - ips = append(ips, IPRecord{ - IP: net.IPAddress(ans.AAAA[:]), - Expire: now.Add(time.Duration(ttl) * time.Second), - }) + ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) default: if err := parser.SkipAnswer(); err != nil { newError("failed to skip answer").Base(err).WriteToLog() @@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot } } - if len(domain) > 0 && len(ips) > 0 { - s.updateIP(domain, ips) + var rec record + switch recType { + case dnsmessage.TypeA: + rec.A = ipRecord + case dnsmessage.TypeAAAA: + rec.AAAA = ipRecord + } + + if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) { + s.updateIP(domain, rec) + } +} + +func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { + if newRec == nil { + return false + } + if baseRec == nil { + return true } + return baseRec.Expire.Before(newRec.Expire) } -func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { +func (s *ClassicNameServer) updateIP(domain string, newRec record) { s.Lock() newError("updating IP records for domain:", domain).AtDebug().WriteToLog() - now := time.Now() - eips := s.ips[domain] - for _, ip := range eips { - if ip.Expire.After(now) { - ips = append(ips, ip) - } + rec := s.ips[domain] + + updated := false + if isNewer(rec.A, newRec.A) { + rec.A = newRec.A + updated = true + } + if isNewer(rec.AAAA, newRec.AAAA) { + rec.AAAA = newRec.AAAA + updated = true + } + + if updated { + s.ips[domain] = rec + s.pub.Publish(domain, nil) } - s.ips[domain] = ips - s.pub.Publish(domain, nil) s.Unlock() common.Must(s.cleanup.Start()) @@ -244,14 +302,15 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource { return opt } -func (s *ClassicNameServer) addPendingRequest(domain string) uint16 { +func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 { id := uint16(atomic.AddUint32(&s.reqID, 1)) s.Lock() defer s.Unlock() s.requests[id] = pendingRequest{ - domain: domain, - expire: time.Now().Add(time.Second * 8), + domain: domain, + expire: time.Now().Add(time.Second * 8), + recType: recType, } return id @@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess if option.IPv4Enable { msg := new(dnsmessage.Message) - msg.Header.ID = s.addPendingRequest(domain) + msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA) msg.Header.RecursionDesired = true msg.Questions = []dnsmessage.Question{qA} if opt := s.getMsgOptions(); opt != nil { @@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess if option.IPv6Enable { msg := new(dnsmessage.Message) - msg.Header.ID = s.addPendingRequest(domain) + msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA) msg.Header.RecursionDesired = true msg.Questions = []dnsmessage.Question{qAAAA} if opt := s.getMsgOptions(); opt != nil { @@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option } } -func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP { +func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) { s.RLock() - records, found := s.ips[domain] + record, found := s.ips[domain] s.RUnlock() - if found && len(records) > 0 { - var ips []net.Address - now := time.Now() - for _, rec := range records { - if rec.Expire.After(now) { - ips = append(ips, rec.IP) - } + if !found { + return nil, errRecordNotFound + } + + var ips []net.Address + var lastErr error + if option.IPv4Enable { + a, err := record.A.getIPs() + if err != nil { + lastErr = err } - return toNetIP(filterIP(ips, option)) + ips = append(ips, a...) } - return nil + + if option.IPv6Enable { + aaaa, err := record.AAAA.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, aaaa...) + } + + fmt.Println("IPs for ", domain, ": ", ips) + + if len(ips) > 0 { + return toNetIP(ips), nil + } + + if lastErr != nil { + return nil, lastErr + } + + return nil, dns_feature.ErrEmptyResponse } func Fqdn(domain string) string { @@ -341,9 +422,9 @@ func Fqdn(domain string) string { func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { fqdn := Fqdn(domain) - ips := s.findIPsForDomain(fqdn, option) - if len(ips) > 0 { - return ips, nil + ips, err := s.findIPsForDomain(fqdn, option) + if err != errRecordNotFound { + return ips, err } sub := s.pub.Subscribe(fqdn) @@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I s.sendQuery(ctx, fqdn, option) for { - ips := s.findIPsForDomain(fqdn, option) - if len(ips) > 0 { - return ips, nil + ips, err := s.findIPsForDomain(fqdn, option) + if err != errRecordNotFound { + return ips, err } select { diff --git a/features/dns/client.go b/features/dns/client.go index fcb55e53..76f63591 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -1,7 +1,9 @@ package dns import ( + "v2ray.com/core/common/errors" "v2ray.com/core/common/net" + "v2ray.com/core/common/serial" "v2ray.com/core/features" ) @@ -35,3 +37,23 @@ type IPv6Lookup interface { func ClientType() interface{} { return (*Client)(nil) } + +// ErrEmptyResponse indicates that DNS query succeeded but no answer was returned. +var ErrEmptyResponse = errors.New("empty response") + +type RCodeError uint16 + +func (e RCodeError) Error() string { + return serial.Concat("rcode: ", uint16(e)) +} + +func RCodeFromError(err error) uint16 { + if err == nil { + return 0 + } + cause := errors.Cause(err) + if r, ok := cause.(RCodeError); ok { + return uint16(r) + } + return 0 +} diff --git a/features/dns/localdns/client.go b/features/dns/localdns/client.go index 563de379..02571907 100644 --- a/features/dns/localdns/client.go +++ b/features/dns/localdns/client.go @@ -32,6 +32,9 @@ func (*Client) LookupIP(host string) ([]net.IP, error) { parsedIPs = append(parsedIPs, parsed.IP()) } } + if len(parsedIPs) == 0 { + return nil, dns.ErrEmptyResponse + } return parsedIPs, nil } @@ -47,6 +50,9 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) { ipv4 = append(ipv4, ip) } } + if len(ipv4) == 0 { + return nil, dns.ErrEmptyResponse + } return ipv4, nil } @@ -62,6 +68,9 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) { ipv6 = append(ipv6, ip) } } + if len(ipv6) == 0 { + return nil, dns.ErrEmptyResponse + } return ipv6, nil } diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index b872e557..4117c0c3 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -218,20 +218,17 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, ips, err = h.ipv6Lookup.LookupIPv6(domain) } - if err != nil { + rcode := dns.RCodeFromError(err) + if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { newError("ip query").Base(err).WriteToLog() return } - if len(ips) == 0 { - return - } - b := buf.New() rawBytes := b.Extend(buf.Size) builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ ID: id, - RCode: dnsmessage.RCodeSuccess, + RCode: dnsmessage.RCode(rcode), RecursionAvailable: true, RecursionDesired: true, Response: true, diff --git a/proxy/dns/dns_test.go b/proxy/dns/dns_test.go index 9e94b3c2..da917596 100644 --- a/proxy/dns/dns_test.go +++ b/proxy/dns/dns_test.go @@ -63,6 +63,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") common.Must(err) ans.Answer = append(ans.Answer, rr) + } else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA { + ans.MsgHdr.Rcode = dns.RcodeNameError } } w.WriteMsg(ans) @@ -128,26 +130,60 @@ func TestUDPDNSTunnel(t *testing.T) { common.Must(v.Start()) defer v.Close() - m1 := new(dns.Msg) - m1.Id = dns.Id() - m1.RecursionDesired = true - m1.Question = make([]dns.Question, 1) - m1.Question[0] = dns.Question{"google.com.", dns.TypeA, dns.ClassINET} + { + m1 := new(dns.Msg) + m1.Id = dns.Id() + m1.RecursionDesired = true + m1.Question = make([]dns.Question, 1) + m1.Question[0] = dns.Question{"google.com.", dns.TypeA, dns.ClassINET} - c := new(dns.Client) - in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) - common.Must(err) + c := new(dns.Client) + in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) + common.Must(err) - if len(in.Answer) != 1 { - t.Fatal("len(answer): ", len(in.Answer)) + if len(in.Answer) != 1 { + t.Fatal("len(answer): ", len(in.Answer)) + } + + rr, ok := in.Answer[0].(*dns.A) + if !ok { + t.Fatal("not A record") + } + if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" { + t.Error(r) + } } - rr, ok := in.Answer[0].(*dns.A) - if !ok { - t.Fatal("not A record") + { + m1 := new(dns.Msg) + m1.Id = dns.Id() + m1.RecursionDesired = true + m1.Question = make([]dns.Question, 1) + m1.Question[0] = dns.Question{"ipv4only.google.com.", dns.TypeAAAA, dns.ClassINET} + + c := new(dns.Client) + in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) + common.Must(err) + + if len(in.Answer) != 0 { + t.Fatal("len(answer): ", len(in.Answer)) + } } - if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" { - t.Error(r) + + { + m1 := new(dns.Msg) + m1.Id = dns.Id() + m1.RecursionDesired = true + m1.Question = make([]dns.Question, 1) + m1.Question[0] = dns.Question{"notexist.google.com.", dns.TypeAAAA, dns.ClassINET} + + c := new(dns.Client) + in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) + common.Must(err) + + if in.Rcode != dns.RcodeNameError { + t.Error("expected NameError, but got ", in.Rcode) + } } }