From 6ef77246abff1e08df5a6eba305a88120e45feb3 Mon Sep 17 00:00:00 2001 From: vcptr <51714622+vcptr@users.noreply.github.com> Date: Sat, 29 Jun 2019 23:43:30 +0800 Subject: [PATCH] add DOH dns client --- app/dns/dnscommon.go | 235 +++++++++++++++++++++++ app/dns/dnscommon_test.go | 166 ++++++++++++++++ app/dns/dohdns.go | 315 +++++++++++++++++++++++++++++++ app/dns/server.go | 54 +++++- app/dns/udpns.go | 269 ++++---------------------- app/proxyman/outbound/handler.go | 5 +- common/mux/client.go | 3 +- common/session/context.go | 14 ++ infra/conf/v2ray.go | 28 +-- infra/conf/v2ray_test.go | 34 +++- 10 files changed, 875 insertions(+), 248 deletions(-) create mode 100644 app/dns/dnscommon.go create mode 100644 app/dns/dnscommon_test.go create mode 100644 app/dns/dohdns.go diff --git a/app/dns/dnscommon.go b/app/dns/dnscommon.go new file mode 100644 index 00000000..77a0ecda --- /dev/null +++ b/app/dns/dnscommon.go @@ -0,0 +1,235 @@ +// +build !confonly + +package dns + +import ( + "encoding/binary" + "time" + + "golang.org/x/net/dns/dnsmessage" + "v2ray.com/core/common" + "v2ray.com/core/common/errors" + "v2ray.com/core/common/net" + dns_feature "v2ray.com/core/features/dns" +) + +func Fqdn(domain string) string { + if len(domain) > 0 && domain[len(domain)-1] == '.' { + return domain + } + return domain + "." +} + +type record struct { + A *IPRecord + AAAA *IPRecord +} + +type IPRecord struct { + ReqID uint16 + IP []net.Address + Expire time.Time + RCode dnsmessage.RCode +} + +func (r *IPRecord) getIPs() ([]net.Address, error) { + if r == nil || r.Expire.Before(time.Now()) { + return nil, errRecordNotFound + } + if r.RCode != dnsmessage.RCodeSuccess { + return nil, dns_feature.RCodeError(r.RCode) + } + return r.IP, nil +} + +func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { + if newRec == nil { + return false + } + if baseRec == nil { + return true + } + return baseRec.Expire.Before(newRec.Expire) +} + +var ( + errRecordNotFound = errors.New("record not found") +) + +type dnsRequest struct { + reqType dnsmessage.Type + domain string + start time.Time + expire time.Time + msg *dnsmessage.Message +} + +func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource { + if len(clientIP) == 0 { + return nil + } + + var netmask int + var family uint16 + + if len(clientIP) == 4 { + family = 1 + netmask = 24 // 24 for IPV4, 96 for IPv6 + } else { + family = 2 + netmask = 96 + } + + b := make([]byte, 4) + binary.BigEndian.PutUint16(b[0:], family) + b[2] = byte(netmask) + b[3] = 0 + switch family { + case 1: + ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8)) + needLength := (netmask + 8 - 1) / 8 // division rounding up + b = append(b, ip[:needLength]...) + case 2: + ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8)) + needLength := (netmask + 8 - 1) / 8 // division rounding up + b = append(b, ip[:needLength]...) + } + + const EDNS0SUBNET = 0x08 + + opt := new(dnsmessage.Resource) + common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true)) + + opt.Body = &dnsmessage.OPTResource{ + Options: []dnsmessage.Option{ + { + Code: EDNS0SUBNET, + Data: b, + }, + }, + } + + return opt +} + +func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest { + qA := dnsmessage.Question{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + } + + qAAAA := dnsmessage.Question{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + } + + var reqs []*dnsRequest + now := time.Now() + + if option.IPv4Enable { + msg := new(dnsmessage.Message) + msg.Header.ID = reqIDGen() + msg.Header.RecursionDesired = true + msg.Questions = []dnsmessage.Question{qA} + if reqOpts != nil { + msg.Additionals = append(msg.Additionals, *reqOpts) + } + reqs = append(reqs, &dnsRequest{ + reqType: dnsmessage.TypeA, + domain: domain, + start: now, + msg: msg, + }) + } + + if option.IPv6Enable { + msg := new(dnsmessage.Message) + msg.Header.ID = reqIDGen() + msg.Header.RecursionDesired = true + msg.Questions = []dnsmessage.Question{qAAAA} + if reqOpts != nil { + msg.Additionals = append(msg.Additionals, *reqOpts) + } + reqs = append(reqs, &dnsRequest{ + reqType: dnsmessage.TypeAAAA, + domain: domain, + start: now, + msg: msg, + }) + } + + return reqs +} + +// parseResponse parse DNS answers from the returned payload +func parseResponse(payload []byte) (*IPRecord, error) { + var parser dnsmessage.Parser + h, err := parser.Start(payload) + if err != nil { + return nil, newError("failed to parse DNS response").Base(err).AtWarning() + } + if err := parser.SkipAllQuestions(); err != nil { + return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning() + } + + now := time.Now() + var ipRecExpire time.Time + if h.RCode != dnsmessage.RCodeSuccess { + // A default TTL, maybe a negtive cache + ipRecExpire = now.Add(time.Second * 120) + } + + ipRecord := &IPRecord{ + ReqID: h.ID, + RCode: h.RCode, + Expire: ipRecExpire, + } + +L: + for { + ah, err := parser.AnswerHeader() + if err != nil { + if err != dnsmessage.ErrSectionDone { + newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog() + } + break + } + + switch ah.Type { + case dnsmessage.TypeA: + ans, err := parser.AResource() + if err != nil { + newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() + break L + } + ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) + case dnsmessage.TypeAAAA: + ans, err := parser.AAAAResource() + if err != nil { + newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() + break L + } + ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) + default: + if err := parser.SkipAnswer(); err != nil { + newError("failed to skip answer").Base(err).WriteToLog() + break L + } + continue + } + + if ipRecord.Expire.IsZero() { + ttl := ah.TTL + if ttl < 600 { + // at least 10 mins TTL + ipRecord.Expire = now.Add(time.Minute * 10) + } else { + ipRecord.Expire = now.Add(time.Duration(ttl) * time.Second) + } + } + } + + return ipRecord, nil +} diff --git a/app/dns/dnscommon_test.go b/app/dns/dnscommon_test.go new file mode 100644 index 00000000..62a35012 --- /dev/null +++ b/app/dns/dnscommon_test.go @@ -0,0 +1,166 @@ +// +build !confonly + +package dns + +import ( + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/miekg/dns" + "golang.org/x/net/dns/dnsmessage" + "v2ray.com/core/common" + "v2ray.com/core/common/net" + v2net "v2ray.com/core/common/net" +) + +func Test_parseResponse(t *testing.T) { + type args struct { + payload []byte + } + + var p [][]byte + + ans := new(dns.Msg) + ans.Id = 0 + p = append(p, common.Must2(ans.Pack()).([]byte)) + + p = append(p, []byte{}) + + ans = new(dns.Msg) + ans.Id = 1 + ans.Answer = append(ans.Answer, + common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR), + ) + p = append(p, common.Must2(ans.Pack()).([]byte)) + + ans = new(dns.Msg) + ans.Id = 2 + ans.Answer = append(ans.Answer, + common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR), + ) + p = append(p, common.Must2(ans.Pack()).([]byte)) + + tests := []struct { + name string + want *IPRecord + wantErr bool + }{ + {"empty", + &IPRecord{0, []v2net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess}, + false, + }, + {"error", + nil, + true, + }, + {"a record", + &IPRecord{1, []v2net.Address{v2net.ParseAddress("8.8.8.8"), v2net.ParseAddress("8.8.4.4")}, + time.Time{}, dnsmessage.RCodeSuccess}, + false, + }, + {"aaaa record", + &IPRecord{2, []v2net.Address{v2net.ParseAddress("2001::123:8888"), v2net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess}, + false, + }, + } + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseResponse(p[i]) + if (err != nil) != tt.wantErr { + t.Errorf("handleResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if got != nil { + // reset the time + got.Expire = time.Time{} + } + if cmp.Diff(got, tt.want) != "" { + t.Errorf(cmp.Diff(got, tt.want)) + // t.Errorf("handleResponse() = %#v, want %#v", got, tt.want) + } + }) + } +} + +func Test_buildReqMsgs(t *testing.T) { + + stubID := func() uint16 { + return uint16(rand.Uint32()) + } + type args struct { + domain string + option IPOption + reqOpts *dnsmessage.Resource + } + tests := []struct { + name string + args args + want int + }{ + {"dual stack", args{"test.com", IPOption{true, true}, nil}, 2}, + {"ipv4 only", args{"test.com", IPOption{true, false}, nil}, 1}, + {"ipv6 only", args{"test.com", IPOption{false, true}, nil}, 1}, + {"none/error", args{"test.com", IPOption{false, false}, nil}, 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := buildReqMsgs(tt.args.domain, tt.args.option, stubID, tt.args.reqOpts); !(len(got) == tt.want) { + t.Errorf("buildReqMsgs() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_genEDNS0Options(t *testing.T) { + type args struct { + clientIP net.IP + } + tests := []struct { + name string + args args + want *dnsmessage.Resource + }{ + // TODO: Add test cases. + {"ipv4", args{net.ParseIP("4.3.2.1")}, nil}, + {"ipv6", args{net.ParseIP("2001::4321")}, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := genEDNS0Options(tt.args.clientIP); got == nil { + t.Errorf("genEDNS0Options() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFqdn(t *testing.T) { + type args struct { + domain string + } + tests := []struct { + name string + args args + want string + }{ + {"with fqdn", args{"www.v2ray.com."}, "www.v2ray.com."}, + {"without fqdn", args{"www.v2ray.com"}, "www.v2ray.com."}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Fqdn(tt.args.domain); got != tt.want { + t.Errorf("Fqdn() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go new file mode 100644 index 00000000..c75c8abc --- /dev/null +++ b/app/dns/dohdns.go @@ -0,0 +1,315 @@ +// +build !confonly + +package dns + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net/http" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/dns/dnsmessage" + "v2ray.com/core/common" + "v2ray.com/core/common/dice" + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol/dns" + "v2ray.com/core/common/session" + "v2ray.com/core/common/signal/pubsub" + "v2ray.com/core/common/task" + "v2ray.com/core/features/routing" +) + +// DoHNameServer implimented DNS over HTTPS (RFC8484) Wire Format, +// which is compatiable with traditional dns over udp(RFC1035), +// thus most of the DOH implimentation is copied from udpns.go +type DoHNameServer struct { + sync.RWMutex + dispatcher routing.Dispatcher + dohDests []net.Destination + ips map[string]record + pub *pubsub.Service + cleanup *task.Periodic + reqID uint32 + clientIP net.IP + httpClient *http.Client + dohURL string + name string +} + +func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routing.Dispatcher, clientIP net.IP) *DoHNameServer { + + s := NewDoHLocalNameServer(dohHost, clientIP) + s.name = "DOH:" + dohHost + s.dispatcher = dispatcher + s.dohDests = dests + + // Dispatched connection will be closed (interupted) after each request + // This makes DOH inefficient without a keeped-alive connection + // See: core/app/proxyman/outbound/handler.go:113 + // Using mux (https request wrapped in a stream layer) improves the situation. + // Recommand to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on + // a normal network eg. the server side of v2ray + tr := &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + DialContext: s.DialContext, + } + + dispatchedClient := &http.Client{ + Transport: tr, + Timeout: 16 * time.Second, + } + + s.httpClient = dispatchedClient + return s +} + +func NewDoHLocalNameServer(dohHost string, clientIP net.IP) *DoHNameServer { + s := &DoHNameServer{ + httpClient: http.DefaultClient, + ips: make(map[string]record), + clientIP: clientIP, + pub: pubsub.NewService(), + name: "DOHL:" + dohHost, + dohURL: fmt.Sprintf("https://%s/dns-query", dohHost), + } + s.cleanup = &task.Periodic{ + Interval: time.Minute, + Execute: s.Cleanup, + } + return s +} + +func (s *DoHNameServer) Name() string { + return s.name +} + +func (s *DoHNameServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + + dest := s.dohDests[dice.Roll(len(s.dohDests))] + + link, err := s.dispatcher.Dispatch(ctx, dest) + if err != nil { + return nil, err + } + return net.NewConnection( + net.ConnectionInputMulti(link.Writer), + net.ConnectionOutputMulti(link.Reader), + ), nil +} + +func (s *DoHNameServer) Cleanup() error { + now := time.Now() + s.Lock() + defer s.Unlock() + + if len(s.ips) == 0 { + return newError("nothing to do. stopping...") + } + + for domain, record := range s.ips { + if record.A != nil && record.A.Expire.Before(now) { + record.A = nil + } + if record.AAAA != nil && record.AAAA.Expire.Before(now) { + record.AAAA = nil + } + + if record.A == nil && record.AAAA == nil { + newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() + delete(s.ips, domain) + } else { + s.ips[domain] = record + } + } + + if len(s.ips) == 0 { + s.ips = make(map[string]record) + } + + return nil +} + +func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { + elapsed := time.Since(req.start) + newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() + + s.Lock() + rec := s.ips[req.domain] + updated := false + + switch req.reqType { + case dnsmessage.TypeA: + if isNewer(rec.A, ipRec) { + rec.A = ipRec + updated = true + } + case dnsmessage.TypeAAAA: + if isNewer(rec.AAAA, ipRec) { + rec.AAAA = ipRec + updated = true + } + } + + if updated { + s.ips[req.domain] = rec + s.pub.Publish(req.domain, nil) + } + + s.Unlock() + common.Must(s.cleanup.Start()) +} + +func (s *DoHNameServer) newReqID() uint16 { + return uint16(atomic.AddUint32(&s.reqID, 1)) +} + +func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option IPOption) { + newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) + + reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP)) + + var deadline time.Time + if d, ok := ctx.Deadline(); ok { + deadline = d + } else { + deadline = time.Now().Add(time.Second * 8) + } + + for _, req := range reqs { + + go func(r *dnsRequest) { + + // generate new context for each req, using same context + // may cause reqs all aborted if any one encounter an error + dnsCtx := context.Background() + + // reserve internal dns server requested Inbound + if inbound := session.InboundFromContext(ctx); inbound != nil { + dnsCtx = session.ContextWithInbound(dnsCtx, inbound) + } + + dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ + Protocol: "https", + }) + + // forced to use mux for DOH + dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true) + + dnsCtx, cancel := context.WithDeadline(dnsCtx, deadline) + defer cancel() + + b, _ := dns.PackMessage(r.msg) + resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) + if err != nil { + newError("failed to retrive response").Base(err).AtError().WriteToLog() + return + } + rec, err := parseResponse(resp) + if err != nil { + newError("failed to handle DOH response").Base(err).AtError().WriteToLog() + return + } + s.updateIP(r, rec) + }(req) + } +} + +func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) { + + body := bytes.NewBuffer(b) + req, err := http.NewRequest("POST", s.dohURL, body) + if err != nil { + return nil, err + } + + req.Header.Add("Accept", "application/dns-message") + req.Header.Add("Content-Type", "application/dns-message") + + resp, err := s.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("DOH HTTPS server returned with non-OK code %d", resp.StatusCode) + return nil, err + } + + return ioutil.ReadAll(resp.Body) +} + +func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) { + s.RLock() + record, found := s.ips[domain] + s.RUnlock() + + if !found { + return nil, errRecordNotFound + } + + var ips []net.Address + var lastErr error + if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess { + aaaa, err := record.AAAA.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, aaaa...) + } + + if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess { + a, err := record.A.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, a...) + } + + if len(ips) > 0 { + return toNetIP(ips), nil + } + + if lastErr != nil { + return nil, lastErr + } + + return nil, errRecordNotFound +} + +// QueryIP is called from dns.Server->queryIPTimeout +func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { + + fqdn := Fqdn(domain) + + ips, err := s.findIPsForDomain(fqdn, option) + if err != errRecordNotFound { + newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() + return ips, err + } + + sub := s.pub.Subscribe(fqdn) + defer sub.Close() + + s.sendQuery(ctx, fqdn, option) + + for { + ips, err := s.findIPsForDomain(fqdn, option) + if err != errRecordNotFound { + return ips, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-sub.Wait(): + } + } +} diff --git a/app/dns/server.go b/app/dns/server.go index c1c85e93..6d70cffa 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -6,6 +6,8 @@ package dns import ( "context" + "fmt" + "strings" "sync" "time" @@ -87,6 +89,49 @@ func New(ctx context.Context, config *Config) (*Server, error) { address := endpoint.Address.AsAddress() if address.Family().IsDomain() && address.Domain() == "localhost" { server.clients = append(server.clients, NewLocalNameServer()) + newError("DNS: localhost inited").AtInfo().WriteToLog() + } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") { + dohHost := address.Domain()[5:] + server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, server.clientIP)) + newError("DNS: DOH - Local inited for https://", dohHost).AtInfo().WriteToLog() + } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") { + // DOH_ prefix makes net.Address think it's a domain + // need to process the real address here. + dohHost := address.Domain()[4:] + dohAddr := net.ParseAddress(dohHost) + dohIP := dohHost + var dests []net.Destination + + if dohAddr.Family().IsDomain() { + // resolve DOH server in advance + ips, err := net.LookupIP(dohAddr.Domain()) + if err != nil || len(ips) == 0 { + return 0 + } + for _, ip := range ips { + dohIP := ip.String() + if len(ip) == net.IPv6len { + dohIP = fmt.Sprintf("[%s]", dohIP) + } + dohdest, _ := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP)) + dests = append(dests, dohdest) + } + } else { + // rfc8484, DOH service only use port 443 + dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP)) + if err != nil { + return 0 + } + dests = []net.Destination{dest} + } + + // need the core dispatcher, register DOHClient at callback + idx := len(server.clients) + server.clients = append(server.clients, nil) + common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) { + server.clients[idx] = NewDoHNameServer(dests, dohHost, d, server.clientIP) + newError("DNS: DOH - Remote client inited for https://", dohHost).AtInfo().WriteToLog() + })) } else { dest := endpoint.AsDestination() if dest.Network == net.Network_Unknown { @@ -100,6 +145,7 @@ func New(ctx context.Context, config *Config) (*Server, error) { server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP) })) } + newError("DNS: UDP client inited for ", dest.NetAddr()).AtInfo().WriteToLog() } return len(server.clients) - 1 } @@ -272,10 +318,16 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err return nil, newError("empty domain name") } + // normalize the FQDN form query if domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } + // skip domain without any dot + if strings.Index(domain, ".") == -1 { + return nil, newError("invalid domain name") + } + ips := s.lookupStatic(domain, option, 0) if ips != nil && ips[0].Family().IsIP() { newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog() @@ -331,7 +383,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err } } - return nil, newError("returning nil for domain ", domain).Base(lastErr) + return nil, dns.ErrEmptyResponse.Base(lastErr) } func init() { diff --git a/app/dns/udpns.go b/app/dns/udpns.go index 853c95d2..b5321592 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -4,14 +4,13 @@ package dns import ( "context" - "encoding/binary" + "strings" "sync" "sync/atomic" "time" "golang.org/x/net/dns/dnsmessage" "v2ray.com/core/common" - "v2ray.com/core/common/errors" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol/dns" udp_proto "v2ray.com/core/common/protocol/udp" @@ -23,42 +22,12 @@ import ( "v2ray.com/core/transport/internet/udp" ) -type record struct { - A *IPRecord - AAAA *IPRecord -} - -type IPRecord struct { - IP []net.Address - Expire time.Time - RCode dnsmessage.RCode -} - -func (r *IPRecord) getIPs() ([]net.Address, error) { - if r == nil || r.Expire.Before(time.Now()) { - return nil, errRecordNotFound - } - if r.RCode != dnsmessage.RCodeSuccess { - return nil, dns_feature.RCodeError(r.RCode) - } - return r.IP, nil -} - -type pendingRequest struct { - domain string - expire time.Time - recType dnsmessage.Type -} - -var ( - errRecordNotFound = errors.New("record not found") -) - type ClassicNameServer struct { sync.RWMutex + name string address net.Destination ips map[string]record - requests map[uint16]pendingRequest + requests map[uint16]dnsRequest pub *pubsub.Service udpServer *udp.Dispatcher cleanup *task.Periodic @@ -70,9 +39,10 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher s := &ClassicNameServer{ address: address, ips: make(map[string]record), - requests: make(map[uint16]pendingRequest), + requests: make(map[uint16]dnsRequest), clientIP: clientIP, pub: pubsub.NewService(), + name: strings.ToUpper(address.String()), } s.cleanup = &task.Periodic{ Interval: time.Minute, @@ -83,7 +53,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher } func (s *ClassicNameServer) Name() string { - return s.address.String() + return s.name } func (s *ClassicNameServer) Cleanup() error { @@ -92,7 +62,7 @@ func (s *ClassicNameServer) Cleanup() error { defer s.Unlock() if len(s.ips) == 0 && len(s.requests) == 0 { - return newError("nothing to do. stopping...") + return newError(s.name, " nothing to do. stopping...") } for domain, record := range s.ips { @@ -121,123 +91,52 @@ func (s *ClassicNameServer) Cleanup() error { } if len(s.requests) == 0 { - s.requests = make(map[uint16]pendingRequest) + s.requests = make(map[uint16]dnsRequest) } return nil } func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { - payload := packet.Payload - var parser dnsmessage.Parser - header, err := parser.Start(payload.Bytes()) + ipRec, err := parseResponse(packet.Payload.Bytes()) if err != nil { - newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog() - return - } - if err := parser.SkipAllQuestions(); err != nil { - newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog() + newError(s.name, " fail to parse responsed DNS udp").AtError().WriteToLog() return } - id := header.ID s.Lock() - req, f := s.requests[id] - if f { + id := ipRec.ReqID + req, ok := s.requests[id] + if ok { + // remove the pending request delete(s.requests, id) } s.Unlock() - - if !f { + if !ok { + newError(s.name, " cannot find the pending request").AtError().WriteToLog() return } - domain := req.domain - recType := req.recType - - now := time.Now() - ipRecord := &IPRecord{ - RCode: header.RCode, - Expire: now.Add(time.Second * 600), - } - -L: - for { - header, err := parser.AnswerHeader() - if err != nil { - if err != dnsmessage.ErrSectionDone { - newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog() - } - break - } - ttl := header.TTL - if ttl == 0 { - ttl = 600 - } - expire := now.Add(time.Duration(ttl) * time.Second) - if ipRecord.Expire.After(expire) { - ipRecord.Expire = expire - } - - if header.Type != recType { - if err := parser.SkipAnswer(); err != nil { - newError("failed to skip answer").Base(err).WriteToLog() - break L - } - continue - } - - switch header.Type { - case dnsmessage.TypeA: - ans, err := parser.AResource() - if err != nil { - newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() - break L - } - ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) - case dnsmessage.TypeAAAA: - ans, err := parser.AAAAResource() - if err != nil { - newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() - break L - } - ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) - default: - if err := parser.SkipAnswer(); err != nil { - newError("failed to skip answer").Base(err).WriteToLog() - break L - } - } - } - var rec record - switch recType { + switch req.reqType { case dnsmessage.TypeA: - rec.A = ipRecord + rec.A = ipRec case dnsmessage.TypeAAAA: - rec.AAAA = ipRecord + rec.AAAA = ipRec } - if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) { - s.updateIP(domain, rec) + elapsed := time.Since(req.start) + newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() + if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { + s.updateIP(req.domain, rec) } } -func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { - if newRec == nil { - return false - } - if baseRec == nil { - return true - } - return baseRec.Expire.Before(newRec.Expire) -} - func (s *ClassicNameServer) updateIP(domain string, newRec record) { s.Lock() - newError("updating IP records for domain:", domain).AtDebug().WriteToLog() + newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() rec := s.ips[domain] updated := false @@ -259,116 +158,27 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) { common.Must(s.cleanup.Start()) } -func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource { - if len(s.clientIP) == 0 { - return nil - } - - var netmask int - var family uint16 - - if len(s.clientIP) == 4 { - family = 1 - netmask = 24 // 24 for IPV4, 96 for IPv6 - } else { - family = 2 - netmask = 96 - } - - b := make([]byte, 4) - binary.BigEndian.PutUint16(b[0:], family) - b[2] = byte(netmask) - b[3] = 0 - switch family { - case 1: - ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8)) - needLength := (netmask + 8 - 1) / 8 // division rounding up - b = append(b, ip[:needLength]...) - case 2: - ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8)) - needLength := (netmask + 8 - 1) / 8 // division rounding up - b = append(b, ip[:needLength]...) - } - - const EDNS0SUBNET = 0x08 - - opt := new(dnsmessage.Resource) - common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true)) - - opt.Body = &dnsmessage.OPTResource{ - Options: []dnsmessage.Option{ - { - Code: EDNS0SUBNET, - Data: b, - }, - }, - } - - return opt +func (s *ClassicNameServer) newReqID() uint16 { + return uint16(atomic.AddUint32(&s.reqID, 1)) } -func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 { - id := uint16(atomic.AddUint32(&s.reqID, 1)) +func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) { s.Lock() defer s.Unlock() - s.requests[id] = pendingRequest{ - domain: domain, - expire: time.Now().Add(time.Second * 8), - recType: recType, - } - - return id -} - -func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message { - qA := dnsmessage.Question{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - } - - qAAAA := dnsmessage.Question{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeAAAA, - Class: dnsmessage.ClassINET, - } - - var msgs []*dnsmessage.Message - - if option.IPv4Enable { - msg := new(dnsmessage.Message) - msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA) - msg.Header.RecursionDesired = true - msg.Questions = []dnsmessage.Question{qA} - if opt := s.getMsgOptions(); opt != nil { - msg.Additionals = append(msg.Additionals, *opt) - } - msgs = append(msgs, msg) - } - - if option.IPv6Enable { - msg := new(dnsmessage.Message) - msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA) - msg.Header.RecursionDesired = true - msg.Questions = []dnsmessage.Question{qAAAA} - if opt := s.getMsgOptions(); opt != nil { - msg.Additionals = append(msg.Additionals, *opt) - } - msgs = append(msgs, msg) - } - - return msgs + id := req.msg.ID + req.expire = time.Now().Add(time.Second * 8) + s.requests[id] = *req } func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) { - newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) + newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - msgs := s.buildMsgs(domain, option) - - for _, msg := range msgs { - b, _ := dns.PackMessage(msg) + reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP)) + for _, req := range reqs { + s.addPendingRequest(req) + b, _ := dns.PackMessage(req.msg) udpCtx := context.Background() if inbound := session.InboundFromContext(ctx); inbound != nil { udpCtx = session.ContextWithInbound(udpCtx, inbound) @@ -418,18 +228,13 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([] return nil, dns_feature.ErrEmptyResponse } -func Fqdn(domain string) string { - if len(domain) > 0 && domain[len(domain)-1] == '.' { - return domain - } - return domain + "." -} - func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { + fqdn := Fqdn(domain) ips, err := s.findIPsForDomain(fqdn, option) if err != errRecordNotFound { + newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() return ips, err } diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index af82f3f0..eead10b7 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -68,12 +68,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou return nil, newError("not an outbound handler") } - if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil && h.senderSettings.MultiplexSettings.Enabled { + if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil { config := h.senderSettings.MultiplexSettings if config.Concurrency < 1 || config.Concurrency > 1024 { return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning() } h.mux = &mux.ClientManager{ + Enabled: h.senderSettings.MultiplexSettings.Enabled, Picker: &mux.IncrementalWorkerPicker{ Factory: &mux.DialingWorkerFactory{ Proxy: proxyHandler, @@ -98,7 +99,7 @@ func (h *Handler) Tag() string { // Dispatch implements proxy.Outbound.Dispatch. func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { - if h.mux != nil { + if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) { if err := h.mux.Dispatch(ctx, link); err != nil { newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx)) common.Interrupt(link.Writer) diff --git a/common/mux/client.go b/common/mux/client.go index 537f8a09..ae9fe1b5 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -21,7 +21,8 @@ import ( ) type ClientManager struct { - Picker WorkerPicker + Enabled bool // wheather mux is enabled from user config + Picker WorkerPicker } func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error { diff --git a/common/session/context.go b/common/session/context.go index 2bbb40c9..2e69ae00 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -9,6 +9,7 @@ const ( inboundSessionKey outboundSessionKey contentSessionKey + MuxPreferedSessionKey ) // ContextWithID returns a new context with the given ID. @@ -56,3 +57,16 @@ func ContentFromContext(ctx context.Context) *Content { } return nil } + +// ContextWithMuxPrefered returns a new context with the given bool +func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context { + return context.WithValue(ctx, MuxPreferedSessionKey, forced) +} + +// MuxPreferedFromContext returns value in this context, or false if not contained. +func MuxPreferedFromContext(ctx context.Context) bool { + if val, ok := ctx.Value(MuxPreferedSessionKey).(bool); ok { + return val + } + return false +} diff --git a/infra/conf/v2ray.go b/infra/conf/v2ray.go index 01f74b4f..198922d7 100644 --- a/infra/conf/v2ray.go +++ b/infra/conf/v2ray.go @@ -75,15 +75,24 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) { } type MuxConfig struct { - Enabled bool `json:"enabled"` - Concurrency uint16 `json:"concurrency"` + Enabled bool `json:"enabled"` + Concurrency int16 `json:"concurrency"` } -func (c *MuxConfig) GetConcurrency() uint16 { - if c.Concurrency == 0 { - return 8 +func (m *MuxConfig) Build() *proxyman.MultiplexingConfig { + if m.Concurrency < 0 { + return nil + } + + var con uint32 = 8 + if m.Concurrency > 0 { + con = uint32(m.Concurrency) + } + + return &proxyman.MultiplexingConfig{ + Enabled: m.Enabled, + Concurrency: con, } - return c.Concurrency } type InboundDetourAllocationConfig struct { @@ -246,11 +255,8 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { senderSettings.ProxySettings = ps } - if c.MuxSettings != nil && c.MuxSettings.Enabled { - senderSettings.MultiplexSettings = &proxyman.MultiplexingConfig{ - Enabled: true, - Concurrency: uint32(c.MuxSettings.GetConcurrency()), - } + if c.MuxSettings != nil { + senderSettings.MultiplexSettings = c.MuxSettings.Build() } settings := []byte("{}") diff --git a/infra/conf/v2ray_test.go b/infra/conf/v2ray_test.go index 3b82be71..d324d3ea 100644 --- a/infra/conf/v2ray_test.go +++ b/infra/conf/v2ray_test.go @@ -2,15 +2,16 @@ package conf_test import ( "encoding/json" + "reflect" "testing" "github.com/golang/protobuf/proto" - "v2ray.com/core" "v2ray.com/core/app/dispatcher" "v2ray.com/core/app/log" "v2ray.com/core/app/proxyman" "v2ray.com/core/app/router" + "v2ray.com/core/common" clog "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -337,3 +338,34 @@ func TestV2RayConfig(t *testing.T) { }, }) } + +func TestMuxConfig_Build(t *testing.T) { + tests := []struct { + name string + fields string + want *proxyman.MultiplexingConfig + }{ + {"default", `{"enabled": true, "concurrency": 16}`, &proxyman.MultiplexingConfig{ + Enabled: true, + Concurrency: 16, + }}, + {"empty def", `{}`, &proxyman.MultiplexingConfig{ + Enabled: false, + Concurrency: 8, + }}, + {"not enable", `{"enabled": false, "concurrency": 4}`, &proxyman.MultiplexingConfig{ + Enabled: false, + Concurrency: 4, + }}, + {"forbidden", `{"enabled": false, "concurrency": -1}`, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &MuxConfig{} + common.Must(json.Unmarshal([]byte(tt.fields), m)) + if got := m.Build(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("MuxConfig.Build() = %v, want %v", got, tt.want) + } + }) + } +}