From d6df5d7cf9e249ce7f31b9634466813afae89d1e Mon Sep 17 00:00:00 2001
From: vcptr <51714622+vcptr@users.noreply.github.com>
Date: Mon, 9 Dec 2019 09:37:35 +0800
Subject: [PATCH] doh URL controls full path

---
 app/dns/dohdns.go | 43 ++++++++++++++++++-------------------------
 app/dns/server.go | 32 +++++++-------------------------
 2 files changed, 25 insertions(+), 50 deletions(-)

diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go
index c7a3f5c2..61237ba8 100644
--- a/app/dns/dohdns.go
+++ b/app/dns/dohdns.go
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"io/ioutil"
 	"net/http"
+	"net/url"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -41,25 +42,25 @@ type DoHNameServer struct {
 }
 
 // NewDoHNameServer creates DOH client object for remote resolving
-func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
+func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
 
-	dohAddr := net.ParseAddress(dohHost)
-	var dests []net.Destination
-
-	if dohPort == 0 {
-		dohPort = 443
+	dohAddr := net.ParseAddress(url.Hostname())
+	dohPort := "443"
+	if url.Port() != "" {
+		dohPort = url.Port()
 	}
 
-	parseIPDest := func(ip net.IP, port uint32) net.Destination {
+	parseIPDest := func(ip net.IP, port string) net.Destination {
 		strIP := ip.String()
 		if len(ip) == net.IPv6len {
 			strIP = fmt.Sprintf("[%s]", strIP)
 		}
-		dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port))
+		dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%s", strIP, port))
 		common.Must(err)
 		return dest
 	}
 
+	var dests []net.Destination
 	if dohAddr.Family().IsDomain() {
 		// resolve DOH server in advance
 		ips, err := net.LookupIP(dohAddr.Domain())
@@ -74,8 +75,8 @@ func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatc
 		dests = append(dests, parseIPDest(ip, dohPort))
 	}
 
-	newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
-	s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP)
+	newError("DNS: created Remote DOH client for ", url.String(), ", preresolved Dests: ", dests).AtInfo().WriteToLog()
+	s := baseDOHNameServer(url, "DOH", clientIP)
 	s.dispatcher = dispatcher
 	s.dohDests = dests
 
@@ -102,32 +103,24 @@ func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatc
 }
 
 // NewDoHLocalNameServer creates DOH client object for local resolving
-func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer {
-
-	if dohPort == 0 {
-		dohPort = 443
-	}
-
-	s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP)
+func NewDoHLocalNameServer(url *url.URL, clientIP net.IP) *DoHNameServer {
+	url.Scheme = "https"
+	s := baseDOHNameServer(url, "DOHL", clientIP)
 	s.httpClient = &http.Client{
 		Timeout: time.Second * 180,
 	}
-	newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
+	newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog()
 	return s
 }
 
-func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer {
-
-	if dohPort == 0 {
-		dohPort = 443
-	}
+func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameServer {
 
 	s := &DoHNameServer{
 		ips:      make(map[string]record),
 		clientIP: clientIP,
 		pub:      pubsub.NewService(),
-		name:     fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort),
-		dohURL:   fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort),
+		name:     fmt.Sprintf("%s//%s", prefix, url.Host),
+		dohURL:   url.String(),
 	}
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
diff --git a/app/dns/server.go b/app/dns/server.go
index 9ba791ab..c2bf83f9 100644
--- a/app/dns/server.go
+++ b/app/dns/server.go
@@ -8,7 +8,6 @@ import (
 	"context"
 	"log"
 	"net/url"
-	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -87,40 +86,22 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 	}
 	server.hosts = hosts
 
-	parseDOHURI := func(d string, endpoint *net.Endpoint) (host string, port uint32, err error) {
-		u, err := url.Parse(d)
-		if err != nil {
-			return "", 0, err
-		}
-		host = u.Hostname()
-		port = 443
-		if u.Port() != "" {
-			p, err := strconv.ParseUint(u.Port(), 10, 16)
-			if err != nil {
-				return "", 0, err
-			}
-			port = uint32(p)
-		}
-		if endpoint.Port != 0 {
-			port = endpoint.Port
-		}
-		return
-	}
-
 	addNameServer := func(endpoint *net.Endpoint) int {
 		address := endpoint.Address.AsAddress()
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 			server.clients = append(server.clients, NewLocalNameServer())
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") {
 			// URI schemed string treated as domain
-			dohlHost, dohlPort, err := parseDOHURI(address.Domain(), endpoint)
+			// DOH Local mode
+			u, err := url.Parse(address.Domain())
 			if err != nil {
 				log.Fatalln(newError("DNS config error").Base(err))
 			}
-			server.clients = append(server.clients, NewDoHLocalNameServer(dohlHost, dohlPort, server.clientIP))
+			server.clients = append(server.clients, NewDoHLocalNameServer(u, server.clientIP))
 		} else if address.Family().IsDomain() &&
 			strings.HasPrefix(address.Domain(), "https://") {
-			dohHost, dohPort, err := parseDOHURI(address.Domain(), endpoint)
+			// DOH Remote mode
+			u, err := url.Parse(address.Domain())
 			if err != nil {
 				log.Fatalln(newError("DNS config error").Base(err))
 			}
@@ -129,13 +110,14 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 
 			// need the core dispatcher, register DOHClient at callback
 			common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
-				c, err := NewDoHNameServer(dohHost, dohPort, d, server.clientIP)
+				c, err := NewDoHNameServer(u, d, server.clientIP)
 				if err != nil {
 					log.Fatalln(newError("DNS config error").Base(err))
 				}
 				server.clients[idx] = c
 			}))
 		} else {
+			// UDP classic DNS mode
 			dest := endpoint.AsDestination()
 			if dest.Network == net.Network_Unknown {
 				dest.Network = net.Network_UDP