From 2fb77d6911eb31ae8e8ab40943b86564b3c64989 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Tue, 26 Jun 2018 15:16:45 +0200 Subject: [PATCH] consume context in local nameserver. --- app/dns/nameserver.go | 35 +++++++++++++++++++---------------- app/dns/nameserver_test.go | 21 +++++++++++++++++++++ app/dns/udpns.go | 8 ++++++++ common/net/system.go | 2 ++ 4 files changed, 50 insertions(+), 16 deletions(-) create mode 100644 app/dns/nameserver_test.go diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index 43b8c39e..3c2aee64 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -2,31 +2,34 @@ package dns import ( "context" - "time" "v2ray.com/core/common/net" ) -var ( - multiQuestionDNS = map[net.Address]bool{ - net.IPAddress([]byte{8, 8, 8, 8}): true, - net.IPAddress([]byte{8, 8, 4, 4}): true, - net.IPAddress([]byte{9, 9, 9, 9}): true, - } -) - -type ARecord struct { - IPs []net.IP - Expire time.Time -} - type NameServer interface { QueryIP(ctx context.Context, domain string) ([]net.IP, error) } type LocalNameServer struct { + resolver net.Resolver } -func (*LocalNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { - return net.LookupIP(domain) +func (s *LocalNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { + ipAddr, err := s.resolver.LookupIPAddr(ctx, domain) + if err != nil { + return nil, err + } + var ips []net.IP + for _, addr := range ipAddr { + ips = append(ips, addr.IP) + } + return ips, nil +} + +func NewLocalNameServer() *LocalNameServer { + return &LocalNameServer{ + resolver: net.Resolver{ + PreferGo: true, + }, + } } diff --git a/app/dns/nameserver_test.go b/app/dns/nameserver_test.go new file mode 100644 index 00000000..1e416782 --- /dev/null +++ b/app/dns/nameserver_test.go @@ -0,0 +1,21 @@ +package dns_test + +import ( + "context" + "testing" + "time" + + . "v2ray.com/core/app/dns" + . "v2ray.com/ext/assert" +) + +func TestLocalNameServer(t *testing.T) { + assert := With(t) + + s := NewLocalNameServer() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + ips, err := s.QueryIP(ctx, "google.com") + cancel() + assert(err, IsNil) + assert(len(ips), GreaterThan, 0) +} diff --git a/app/dns/udpns.go b/app/dns/udpns.go index 9b025784..e5b5a679 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -16,6 +16,14 @@ import ( "v2ray.com/core/transport/internet/udp" ) +var ( + multiQuestionDNS = map[net.Address]bool{ + net.IPAddress([]byte{8, 8, 8, 8}): true, + net.IPAddress([]byte{8, 8, 4, 4}): true, + net.IPAddress([]byte{9, 9, 9, 9}): true, + } +) + type IPRecord struct { IP net.IP Expire time.Time diff --git a/common/net/system.go b/common/net/system.go index d4811cda..17febb1c 100644 --- a/common/net/system.go +++ b/common/net/system.go @@ -51,3 +51,5 @@ type TCPListener = net.TCPListener type UnixListener = net.UnixListener var ResolveUnixAddr = net.ResolveUnixAddr + +type Resolver = net.Resolver