mirror of https://github.com/v2ray/v2ray-core
Merge branch 'raymaster' into flymaster
commit
622591bf03
|
@ -0,0 +1,237 @@
|
|||
// +build !confonly
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/errors"
|
||||
"v2ray.com/core/common/net"
|
||||
dns_feature "v2ray.com/core/features/dns"
|
||||
)
|
||||
|
||||
// Fqdn normalize domain make sure it ends with '.'
|
||||
func Fqdn(domain string) string {
|
||||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||||
return domain
|
||||
}
|
||||
return domain + "."
|
||||
}
|
||||
|
||||
type record struct {
|
||||
A *IPRecord
|
||||
AAAA *IPRecord
|
||||
}
|
||||
|
||||
// IPRecord is a cacheable item for a resolved domain
|
||||
type IPRecord struct {
|
||||
ReqID uint16
|
||||
IP []net.Address
|
||||
Expire time.Time
|
||||
RCode dnsmessage.RCode
|
||||
}
|
||||
|
||||
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
||||
if r == nil || r.Expire.Before(time.Now()) {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
if r.RCode != dnsmessage.RCodeSuccess {
|
||||
return nil, dns_feature.RCodeError(r.RCode)
|
||||
}
|
||||
return r.IP, nil
|
||||
}
|
||||
|
||||
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
||||
if newRec == nil {
|
||||
return false
|
||||
}
|
||||
if baseRec == nil {
|
||||
return true
|
||||
}
|
||||
return baseRec.Expire.Before(newRec.Expire)
|
||||
}
|
||||
|
||||
var (
|
||||
errRecordNotFound = errors.New("record not found")
|
||||
)
|
||||
|
||||
type dnsRequest struct {
|
||||
reqType dnsmessage.Type
|
||||
domain string
|
||||
start time.Time
|
||||
expire time.Time
|
||||
msg *dnsmessage.Message
|
||||
}
|
||||
|
||||
func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
|
||||
if len(clientIP) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var netmask int
|
||||
var family uint16
|
||||
|
||||
if len(clientIP) == 4 {
|
||||
family = 1
|
||||
netmask = 24 // 24 for IPV4, 96 for IPv6
|
||||
} else {
|
||||
family = 2
|
||||
netmask = 96
|
||||
}
|
||||
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(b[0:], family)
|
||||
b[2] = byte(netmask)
|
||||
b[3] = 0
|
||||
switch family {
|
||||
case 1:
|
||||
ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
||||
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
||||
b = append(b, ip[:needLength]...)
|
||||
case 2:
|
||||
ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
||||
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
||||
b = append(b, ip[:needLength]...)
|
||||
}
|
||||
|
||||
const EDNS0SUBNET = 0x08
|
||||
|
||||
opt := new(dnsmessage.Resource)
|
||||
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
||||
|
||||
opt.Body = &dnsmessage.OPTResource{
|
||||
Options: []dnsmessage.Option{
|
||||
{
|
||||
Code: EDNS0SUBNET,
|
||||
Data: b,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
|
||||
qA := dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName(domain),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
|
||||
qAAAA := dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName(domain),
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
|
||||
var reqs []*dnsRequest
|
||||
now := time.Now()
|
||||
|
||||
if option.IPv4Enable {
|
||||
msg := new(dnsmessage.Message)
|
||||
msg.Header.ID = reqIDGen()
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{qA}
|
||||
if reqOpts != nil {
|
||||
msg.Additionals = append(msg.Additionals, *reqOpts)
|
||||
}
|
||||
reqs = append(reqs, &dnsRequest{
|
||||
reqType: dnsmessage.TypeA,
|
||||
domain: domain,
|
||||
start: now,
|
||||
msg: msg,
|
||||
})
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
msg := new(dnsmessage.Message)
|
||||
msg.Header.ID = reqIDGen()
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{qAAAA}
|
||||
if reqOpts != nil {
|
||||
msg.Additionals = append(msg.Additionals, *reqOpts)
|
||||
}
|
||||
reqs = append(reqs, &dnsRequest{
|
||||
reqType: dnsmessage.TypeAAAA,
|
||||
domain: domain,
|
||||
start: now,
|
||||
msg: msg,
|
||||
})
|
||||
}
|
||||
|
||||
return reqs
|
||||
}
|
||||
|
||||
// parseResponse parse DNS answers from the returned payload
|
||||
func parseResponse(payload []byte) (*IPRecord, error) {
|
||||
var parser dnsmessage.Parser
|
||||
h, err := parser.Start(payload)
|
||||
if err != nil {
|
||||
return nil, newError("failed to parse DNS response").Base(err).AtWarning()
|
||||
}
|
||||
if err := parser.SkipAllQuestions(); err != nil {
|
||||
return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var ipRecExpire time.Time
|
||||
if h.RCode != dnsmessage.RCodeSuccess {
|
||||
// A default TTL, maybe a negtive cache
|
||||
ipRecExpire = now.Add(time.Second * 120)
|
||||
}
|
||||
|
||||
ipRecord := &IPRecord{
|
||||
ReqID: h.ID,
|
||||
RCode: h.RCode,
|
||||
Expire: ipRecExpire,
|
||||
}
|
||||
|
||||
L:
|
||||
for {
|
||||
ah, err := parser.AnswerHeader()
|
||||
if err != nil {
|
||||
if err != dnsmessage.ErrSectionDone {
|
||||
newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
switch ah.Type {
|
||||
case dnsmessage.TypeA:
|
||||
ans, err := parser.AResource()
|
||||
if err != nil {
|
||||
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
||||
case dnsmessage.TypeAAAA:
|
||||
ans, err := parser.AAAAResource()
|
||||
if err != nil {
|
||||
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
||||
default:
|
||||
if err := parser.SkipAnswer(); err != nil {
|
||||
newError("failed to skip answer").Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if ipRecord.Expire.IsZero() {
|
||||
ttl := ah.TTL
|
||||
if ttl < 600 {
|
||||
// at least 10 mins TTL
|
||||
ipRecord.Expire = now.Add(time.Minute * 10)
|
||||
} else {
|
||||
ipRecord.Expire = now.Add(time.Duration(ttl) * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ipRecord, nil
|
||||
}
|
|
@ -0,0 +1,166 @@
|
|||
// +build !confonly
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/net"
|
||||
v2net "v2ray.com/core/common/net"
|
||||
)
|
||||
|
||||
func Test_parseResponse(t *testing.T) {
|
||||
type args struct {
|
||||
payload []byte
|
||||
}
|
||||
|
||||
var p [][]byte
|
||||
|
||||
ans := new(dns.Msg)
|
||||
ans.Id = 0
|
||||
p = append(p, common.Must2(ans.Pack()).([]byte))
|
||||
|
||||
p = append(p, []byte{})
|
||||
|
||||
ans = new(dns.Msg)
|
||||
ans.Id = 1
|
||||
ans.Answer = append(ans.Answer,
|
||||
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR),
|
||||
)
|
||||
p = append(p, common.Must2(ans.Pack()).([]byte))
|
||||
|
||||
ans = new(dns.Msg)
|
||||
ans.Id = 2
|
||||
ans.Answer = append(ans.Answer,
|
||||
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR),
|
||||
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR),
|
||||
)
|
||||
p = append(p, common.Must2(ans.Pack()).([]byte))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want *IPRecord
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty",
|
||||
&IPRecord{0, []v2net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
|
||||
false,
|
||||
},
|
||||
{"error",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{"a record",
|
||||
&IPRecord{1, []v2net.Address{v2net.ParseAddress("8.8.8.8"), v2net.ParseAddress("8.8.4.4")},
|
||||
time.Time{}, dnsmessage.RCodeSuccess},
|
||||
false,
|
||||
},
|
||||
{"aaaa record",
|
||||
&IPRecord{2, []v2net.Address{v2net.ParseAddress("2001::123:8888"), v2net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseResponse(p[i])
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("handleResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if got != nil {
|
||||
// reset the time
|
||||
got.Expire = time.Time{}
|
||||
}
|
||||
if cmp.Diff(got, tt.want) != "" {
|
||||
t.Errorf(cmp.Diff(got, tt.want))
|
||||
// t.Errorf("handleResponse() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildReqMsgs(t *testing.T) {
|
||||
|
||||
stubID := func() uint16 {
|
||||
return uint16(rand.Uint32())
|
||||
}
|
||||
type args struct {
|
||||
domain string
|
||||
option IPOption
|
||||
reqOpts *dnsmessage.Resource
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{"dual stack", args{"test.com", IPOption{true, true}, nil}, 2},
|
||||
{"ipv4 only", args{"test.com", IPOption{true, false}, nil}, 1},
|
||||
{"ipv6 only", args{"test.com", IPOption{false, true}, nil}, 1},
|
||||
{"none/error", args{"test.com", IPOption{false, false}, nil}, 0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := buildReqMsgs(tt.args.domain, tt.args.option, stubID, tt.args.reqOpts); !(len(got) == tt.want) {
|
||||
t.Errorf("buildReqMsgs() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_genEDNS0Options(t *testing.T) {
|
||||
type args struct {
|
||||
clientIP net.IP
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *dnsmessage.Resource
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"ipv4", args{net.ParseIP("4.3.2.1")}, nil},
|
||||
{"ipv6", args{net.ParseIP("2001::4321")}, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := genEDNS0Options(tt.args.clientIP); got == nil {
|
||||
t.Errorf("genEDNS0Options() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFqdn(t *testing.T) {
|
||||
type args struct {
|
||||
domain string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"with fqdn", args{"www.v2ray.com."}, "www.v2ray.com."},
|
||||
{"without fqdn", args{"www.v2ray.com"}, "www.v2ray.com."},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := Fqdn(tt.args.domain); got != tt.want {
|
||||
t.Errorf("Fqdn() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,370 @@
|
|||
// +build !confonly
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/dice"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol/dns"
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/common/signal/pubsub"
|
||||
"v2ray.com/core/common/task"
|
||||
"v2ray.com/core/features/routing"
|
||||
)
|
||||
|
||||
// DoHNameServer implimented DNS over HTTPS (RFC8484) Wire Format,
|
||||
// which is compatiable with traditional dns over udp(RFC1035),
|
||||
// thus most of the DOH implimentation is copied from udpns.go
|
||||
type DoHNameServer struct {
|
||||
sync.RWMutex
|
||||
dispatcher routing.Dispatcher
|
||||
dohDests []net.Destination
|
||||
ips map[string]record
|
||||
pub *pubsub.Service
|
||||
cleanup *task.Periodic
|
||||
reqID uint32
|
||||
clientIP net.IP
|
||||
httpClient *http.Client
|
||||
dohURL string
|
||||
name string
|
||||
}
|
||||
|
||||
// NewDoHNameServer creates DOH client object for remote resolving
|
||||
func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
|
||||
|
||||
dohAddr := net.ParseAddress(dohHost)
|
||||
var dests []net.Destination
|
||||
|
||||
if dohPort == 0 {
|
||||
dohPort = 443
|
||||
}
|
||||
|
||||
parseIPDest := func(ip net.IP, port uint32) net.Destination {
|
||||
strIP := ip.String()
|
||||
if len(ip) == net.IPv6len {
|
||||
strIP = fmt.Sprintf("[%s]", strIP)
|
||||
}
|
||||
dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port))
|
||||
common.Must(err)
|
||||
return dest
|
||||
}
|
||||
|
||||
if dohAddr.Family().IsDomain() {
|
||||
// resolve DOH server in advance
|
||||
ips, err := net.LookupIP(dohAddr.Domain())
|
||||
if err != nil || len(ips) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
dests = append(dests, parseIPDest(ip, dohPort))
|
||||
}
|
||||
} else {
|
||||
ip := dohAddr.IP()
|
||||
dests = append(dests, parseIPDest(ip, dohPort))
|
||||
}
|
||||
|
||||
newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
|
||||
s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP)
|
||||
s.dispatcher = dispatcher
|
||||
s.dohDests = dests
|
||||
|
||||
// Dispatched connection will be closed (interupted) after each request
|
||||
// This makes DOH inefficient without a keeped-alive connection
|
||||
// See: core/app/proxyman/outbound/handler.go:113
|
||||
// Using mux (https request wrapped in a stream layer) improves the situation.
|
||||
// Recommand to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on
|
||||
// a normal network eg. the server side of v2ray
|
||||
tr := &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
DialContext: s.DialContext,
|
||||
}
|
||||
|
||||
dispatchedClient := &http.Client{
|
||||
Transport: tr,
|
||||
Timeout: 16 * time.Second,
|
||||
}
|
||||
|
||||
s.httpClient = dispatchedClient
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewDoHLocalNameServer creates DOH client object for local resolving
|
||||
func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer {
|
||||
|
||||
if dohPort == 0 {
|
||||
dohPort = 443
|
||||
}
|
||||
|
||||
s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP)
|
||||
s.httpClient = &http.Client{
|
||||
Timeout: time.Second * 180,
|
||||
}
|
||||
newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
|
||||
return s
|
||||
}
|
||||
|
||||
func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer {
|
||||
|
||||
if dohPort == 0 {
|
||||
dohPort = 443
|
||||
}
|
||||
|
||||
s := &DoHNameServer{
|
||||
ips: make(map[string]record),
|
||||
clientIP: clientIP,
|
||||
pub: pubsub.NewService(),
|
||||
name: fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort),
|
||||
dohURL: fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort),
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Name returns client name
|
||||
func (s *DoHNameServer) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// DialContext offer dispatched connection through core routing
|
||||
func (s *DoHNameServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
|
||||
dest := s.dohDests[dice.Roll(len(s.dohDests))]
|
||||
|
||||
link, err := s.dispatcher.Dispatch(ctx, dest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.NewConnection(
|
||||
net.ConnectionInputMulti(link.Writer),
|
||||
net.ConnectionOutputMulti(link.Reader),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Cleanup clears expired items from cache
|
||||
func (s *DoHNameServer) Cleanup() error {
|
||||
now := time.Now()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
return newError("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
|
||||
delete(s.ips, domain)
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string]record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
|
||||
|
||||
s.Lock()
|
||||
rec := s.ips[req.domain]
|
||||
updated := false
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
if isNewer(rec.A, ipRec) {
|
||||
rec.A = ipRec
|
||||
updated = true
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
if isNewer(rec.AAAA, ipRec) {
|
||||
rec.AAAA = ipRec
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if updated {
|
||||
s.ips[req.domain] = rec
|
||||
s.pub.Publish(req.domain, nil)
|
||||
}
|
||||
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) newReqID() uint16 {
|
||||
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
||||
newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
} else {
|
||||
deadline = time.Now().Add(time.Second * 8)
|
||||
}
|
||||
|
||||
for _, req := range reqs {
|
||||
|
||||
go func(r *dnsRequest) {
|
||||
|
||||
// generate new context for each req, using same context
|
||||
// may cause reqs all aborted if any one encounter an error
|
||||
dnsCtx := context.Background()
|
||||
|
||||
// reserve internal dns server requested Inbound
|
||||
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||
dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
|
||||
}
|
||||
|
||||
dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
|
||||
Protocol: "https",
|
||||
})
|
||||
|
||||
// forced to use mux for DOH
|
||||
dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true)
|
||||
|
||||
dnsCtx, cancel := context.WithDeadline(dnsCtx, deadline)
|
||||
defer cancel()
|
||||
|
||||
b, _ := dns.PackMessage(r.msg)
|
||||
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
|
||||
if err != nil {
|
||||
newError("failed to retrive response").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
rec, err := parseResponse(resp)
|
||||
if err != nil {
|
||||
newError("failed to handle DOH response").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
s.updateIP(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) {
|
||||
|
||||
body := bytes.NewBuffer(b)
|
||||
req, err := http.NewRequest("POST", s.dohURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "application/dns-message")
|
||||
req.Header.Add("Content-Type", "application/dns-message")
|
||||
|
||||
resp, err := s.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err = fmt.Errorf("DOH HTTPS server returned with non-OK code %d", resp.StatusCode)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
|
||||
s.RLock()
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
var ips []net.Address
|
||||
var lastErr error
|
||||
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
|
||||
aaaa, err := record.AAAA.getIPs()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
ips = append(ips, aaaa...)
|
||||
}
|
||||
|
||||
if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
|
||||
a, err := record.A.getIPs()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
ips = append(ips, a...)
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
return toNetIP(ips), nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
// QueryIP is called from dns.Server->queryIPTimeout
|
||||
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
||||
|
||||
fqdn := Fqdn(domain)
|
||||
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
|
||||
return ips, err
|
||||
}
|
||||
|
||||
sub := s.pub.Subscribe(fqdn)
|
||||
defer sub.Close()
|
||||
|
||||
s.sendQuery(ctx, fqdn, option)
|
||||
|
||||
for {
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-sub.Wait():
|
||||
}
|
||||
}
|
||||
}
|
|
@ -49,6 +49,7 @@ func (s *localNameServer) Name() string {
|
|||
}
|
||||
|
||||
func NewLocalNameServer() *localNameServer {
|
||||
newError("DNS: created localhost client").AtInfo().WriteToLog()
|
||||
return &localNameServer{
|
||||
client: localdns.New(),
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -39,7 +41,7 @@ type MultiGeoIPMatcher struct {
|
|||
matchers []*router.GeoIPMatcher
|
||||
}
|
||||
|
||||
var errExpectedIPNonMatch = errors.New("expected ip not match")
|
||||
var errExpectedIPNonMatch = errors.New("expectIPs not match")
|
||||
|
||||
// Match check ip match
|
||||
func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
|
||||
|
@ -71,7 +73,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||
server.tag = generateRandomTag()
|
||||
}
|
||||
if len(config.ClientIp) > 0 {
|
||||
if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
|
||||
if len(config.ClientIp) != net.IPv4len && len(config.ClientIp) != net.IPv6len {
|
||||
return nil, newError("unexpected IP length", len(config.ClientIp))
|
||||
}
|
||||
server.clientIP = net.IP(config.ClientIp)
|
||||
|
@ -87,6 +89,23 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||
address := endpoint.Address.AsAddress()
|
||||
if address.Family().IsDomain() && address.Domain() == "localhost" {
|
||||
server.clients = append(server.clients, NewLocalNameServer())
|
||||
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
|
||||
dohHost := address.Domain()[5:]
|
||||
server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP))
|
||||
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
|
||||
// DOH_ prefix makes net.Address think it's a domain
|
||||
dohHost := address.Domain()[4:]
|
||||
idx := len(server.clients)
|
||||
server.clients = append(server.clients, nil)
|
||||
|
||||
// need the core dispatcher, register DOHClient at callback
|
||||
common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
|
||||
c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP)
|
||||
if err != nil {
|
||||
log.Fatalln(newError("DNS config error").Base(err))
|
||||
}
|
||||
server.clients[idx] = c
|
||||
}))
|
||||
} else {
|
||||
dest := endpoint.AsDestination()
|
||||
if dest.Network == net.Network_Unknown {
|
||||
|
@ -129,16 +148,19 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||
domainIndexMap[midx] = uint32(idx)
|
||||
}
|
||||
|
||||
var matchers []*router.GeoIPMatcher
|
||||
for _, geoip := range ns.Geoip {
|
||||
matcher, err := geoIPMatcherContainer.Add(geoip)
|
||||
if err != nil {
|
||||
return nil, newError("failed to create ip matcher").Base(err).AtWarning()
|
||||
// only add to ipIndexMap if GeoIP is configured
|
||||
if len(ns.Geoip) > 0 {
|
||||
var matchers []*router.GeoIPMatcher
|
||||
for _, geoip := range ns.Geoip {
|
||||
matcher, err := geoIPMatcherContainer.Add(geoip)
|
||||
if err != nil {
|
||||
return nil, newError("failed to create ip matcher").Base(err).AtWarning()
|
||||
}
|
||||
matchers = append(matchers, matcher)
|
||||
}
|
||||
matchers = append(matchers, matcher)
|
||||
matcher := &MultiGeoIPMatcher{matchers: matchers}
|
||||
ipIndexMap[uint32(idx)] = matcher
|
||||
}
|
||||
matcher := &MultiGeoIPMatcher{matchers: matchers}
|
||||
ipIndexMap[uint32(idx)] = matcher
|
||||
}
|
||||
|
||||
server.domainMatcher = domainMatcher
|
||||
|
@ -177,12 +199,11 @@ func (s *Server) IsOwnLink(ctx context.Context) bool {
|
|||
func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
|
||||
matcher, exist := s.ipIndexMap[idx]
|
||||
if !exist {
|
||||
newError("domain ", domain, " server not in ipIndexMap: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
if !matcher.HasMatcher() {
|
||||
newError("domain ", domain, " server has not valid matcher: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
|
||||
newError("domain ", domain, " server has no valid matcher: ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
|
@ -190,14 +211,12 @@ func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) (
|
|||
for _, ip := range ips {
|
||||
if matcher.Match(ip) {
|
||||
newIps = append(newIps, ip)
|
||||
newError("domain ", domain, " ip ", ip, " is match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||
} else {
|
||||
newError("domain ", domain, " ip ", ip, " is not match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||
}
|
||||
}
|
||||
if len(newIps) == 0 {
|
||||
return nil, errExpectedIPNonMatch
|
||||
}
|
||||
newError("domain ", domain, " expectIPs ", newIps, " matched at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||
return newIps, nil
|
||||
}
|
||||
|
||||
|
@ -272,10 +291,16 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||
return nil, newError("empty domain name")
|
||||
}
|
||||
|
||||
// normalize the FQDN form query
|
||||
if domain[len(domain)-1] == '.' {
|
||||
domain = domain[:len(domain)-1]
|
||||
}
|
||||
|
||||
// skip domain without any dot
|
||||
if strings.Index(domain, ".") == -1 {
|
||||
return nil, newError("invalid domain name").AtWarning()
|
||||
}
|
||||
|
||||
ips := s.lookupStatic(domain, option, 0)
|
||||
if ips != nil && ips[0].Family().IsIP() {
|
||||
newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog()
|
||||
|
@ -294,7 +319,6 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||
idx := s.domainMatcher.Match(domain)
|
||||
if idx > 0 {
|
||||
matchedClient = s.clients[s.domainIndexMap[idx]]
|
||||
newError("domain matched, direct lookup ip for domain ", domain, " at ", matchedClient.Name()).WriteToLog()
|
||||
ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
|
||||
if len(ips) > 0 {
|
||||
return ips, nil
|
||||
|
@ -315,10 +339,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||
continue
|
||||
}
|
||||
|
||||
newError("try to lookup ip for domain ", domain, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||
ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
|
||||
if len(ips) > 0 {
|
||||
newError("lookup ip for domain ", domain, " success: ", ips, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
|
@ -331,7 +353,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||
}
|
||||
}
|
||||
|
||||
return nil, newError("returning nil for domain ", domain).Base(lastErr)
|
||||
return nil, dns.ErrEmptyResponse.Base(lastErr)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
276
app/dns/udpns.go
276
app/dns/udpns.go
|
@ -4,14 +4,13 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/errors"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol/dns"
|
||||
udp_proto "v2ray.com/core/common/protocol/udp"
|
||||
|
@ -23,42 +22,12 @@ import (
|
|||
"v2ray.com/core/transport/internet/udp"
|
||||
)
|
||||
|
||||
type record struct {
|
||||
A *IPRecord
|
||||
AAAA *IPRecord
|
||||
}
|
||||
|
||||
type IPRecord struct {
|
||||
IP []net.Address
|
||||
Expire time.Time
|
||||
RCode dnsmessage.RCode
|
||||
}
|
||||
|
||||
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
||||
if r == nil || r.Expire.Before(time.Now()) {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
if r.RCode != dnsmessage.RCodeSuccess {
|
||||
return nil, dns_feature.RCodeError(r.RCode)
|
||||
}
|
||||
return r.IP, nil
|
||||
}
|
||||
|
||||
type pendingRequest struct {
|
||||
domain string
|
||||
expire time.Time
|
||||
recType dnsmessage.Type
|
||||
}
|
||||
|
||||
var (
|
||||
errRecordNotFound = errors.New("record not found")
|
||||
)
|
||||
|
||||
type ClassicNameServer struct {
|
||||
sync.RWMutex
|
||||
name string
|
||||
address net.Destination
|
||||
ips map[string]record
|
||||
requests map[uint16]pendingRequest
|
||||
requests map[uint16]dnsRequest
|
||||
pub *pubsub.Service
|
||||
udpServer *udp.Dispatcher
|
||||
cleanup *task.Periodic
|
||||
|
@ -67,23 +36,31 @@ type ClassicNameServer struct {
|
|||
}
|
||||
|
||||
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
|
||||
|
||||
// default to 53 if unspecific
|
||||
if address.Port == 0 {
|
||||
address.Port = net.Port(53)
|
||||
}
|
||||
|
||||
s := &ClassicNameServer{
|
||||
address: address,
|
||||
ips: make(map[string]record),
|
||||
requests: make(map[uint16]pendingRequest),
|
||||
requests: make(map[uint16]dnsRequest),
|
||||
clientIP: clientIP,
|
||||
pub: pubsub.NewService(),
|
||||
name: strings.ToUpper(address.String()),
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
}
|
||||
s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
|
||||
newError("DNS: created udp client inited for ", address.NetAddr()).AtInfo().WriteToLog()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) Name() string {
|
||||
return s.address.String()
|
||||
return s.name
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) Cleanup() error {
|
||||
|
@ -92,7 +69,7 @@ func (s *ClassicNameServer) Cleanup() error {
|
|||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 && len(s.requests) == 0 {
|
||||
return newError("nothing to do. stopping...")
|
||||
return newError(s.name, " nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
|
@ -121,123 +98,52 @@ func (s *ClassicNameServer) Cleanup() error {
|
|||
}
|
||||
|
||||
if len(s.requests) == 0 {
|
||||
s.requests = make(map[uint16]pendingRequest)
|
||||
s.requests = make(map[uint16]dnsRequest)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
||||
payload := packet.Payload
|
||||
|
||||
var parser dnsmessage.Parser
|
||||
header, err := parser.Start(payload.Bytes())
|
||||
ipRec, err := parseResponse(packet.Payload.Bytes())
|
||||
if err != nil {
|
||||
newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
|
||||
return
|
||||
}
|
||||
if err := parser.SkipAllQuestions(); err != nil {
|
||||
newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
|
||||
newError(s.name, " fail to parse responsed DNS udp").AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
|
||||
id := header.ID
|
||||
s.Lock()
|
||||
req, f := s.requests[id]
|
||||
if f {
|
||||
id := ipRec.ReqID
|
||||
req, ok := s.requests[id]
|
||||
if ok {
|
||||
// remove the pending request
|
||||
delete(s.requests, id)
|
||||
}
|
||||
s.Unlock()
|
||||
|
||||
if !f {
|
||||
if !ok {
|
||||
newError(s.name, " cannot find the pending request").AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
|
||||
domain := req.domain
|
||||
recType := req.recType
|
||||
|
||||
now := time.Now()
|
||||
ipRecord := &IPRecord{
|
||||
RCode: header.RCode,
|
||||
Expire: now.Add(time.Second * 600),
|
||||
}
|
||||
|
||||
L:
|
||||
for {
|
||||
header, err := parser.AnswerHeader()
|
||||
if err != nil {
|
||||
if err != dnsmessage.ErrSectionDone {
|
||||
newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
|
||||
}
|
||||
break
|
||||
}
|
||||
ttl := header.TTL
|
||||
if ttl == 0 {
|
||||
ttl = 600
|
||||
}
|
||||
expire := now.Add(time.Duration(ttl) * time.Second)
|
||||
if ipRecord.Expire.After(expire) {
|
||||
ipRecord.Expire = expire
|
||||
}
|
||||
|
||||
if header.Type != recType {
|
||||
if err := parser.SkipAnswer(); err != nil {
|
||||
newError("failed to skip answer").Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch header.Type {
|
||||
case dnsmessage.TypeA:
|
||||
ans, err := parser.AResource()
|
||||
if err != nil {
|
||||
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
||||
case dnsmessage.TypeAAAA:
|
||||
ans, err := parser.AAAAResource()
|
||||
if err != nil {
|
||||
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
||||
default:
|
||||
if err := parser.SkipAnswer(); err != nil {
|
||||
newError("failed to skip answer").Base(err).WriteToLog()
|
||||
break L
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var rec record
|
||||
switch recType {
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
rec.A = ipRecord
|
||||
rec.A = ipRec
|
||||
case dnsmessage.TypeAAAA:
|
||||
rec.AAAA = ipRecord
|
||||
rec.AAAA = ipRec
|
||||
}
|
||||
|
||||
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
||||
s.updateIP(domain, rec)
|
||||
elapsed := time.Since(req.start)
|
||||
newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
|
||||
if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
||||
s.updateIP(req.domain, rec)
|
||||
}
|
||||
}
|
||||
|
||||
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
||||
if newRec == nil {
|
||||
return false
|
||||
}
|
||||
if baseRec == nil {
|
||||
return true
|
||||
}
|
||||
return baseRec.Expire.Before(newRec.Expire)
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
||||
s.Lock()
|
||||
|
||||
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
||||
newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
|
||||
rec := s.ips[domain]
|
||||
|
||||
updated := false
|
||||
|
@ -259,116 +165,27 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|||
common.Must(s.cleanup.Start())
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
||||
if len(s.clientIP) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var netmask int
|
||||
var family uint16
|
||||
|
||||
if len(s.clientIP) == 4 {
|
||||
family = 1
|
||||
netmask = 24 // 24 for IPV4, 96 for IPv6
|
||||
} else {
|
||||
family = 2
|
||||
netmask = 96
|
||||
}
|
||||
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(b[0:], family)
|
||||
b[2] = byte(netmask)
|
||||
b[3] = 0
|
||||
switch family {
|
||||
case 1:
|
||||
ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
||||
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
||||
b = append(b, ip[:needLength]...)
|
||||
case 2:
|
||||
ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
||||
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
||||
b = append(b, ip[:needLength]...)
|
||||
}
|
||||
|
||||
const EDNS0SUBNET = 0x08
|
||||
|
||||
opt := new(dnsmessage.Resource)
|
||||
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
||||
|
||||
opt.Body = &dnsmessage.OPTResource{
|
||||
Options: []dnsmessage.Option{
|
||||
{
|
||||
Code: EDNS0SUBNET,
|
||||
Data: b,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return opt
|
||||
func (s *ClassicNameServer) newReqID() uint16 {
|
||||
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
|
||||
id := uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
s.requests[id] = pendingRequest{
|
||||
domain: domain,
|
||||
expire: time.Now().Add(time.Second * 8),
|
||||
recType: recType,
|
||||
}
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
|
||||
qA := dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName(domain),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
|
||||
qAAAA := dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName(domain),
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
}
|
||||
|
||||
var msgs []*dnsmessage.Message
|
||||
|
||||
if option.IPv4Enable {
|
||||
msg := new(dnsmessage.Message)
|
||||
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{qA}
|
||||
if opt := s.getMsgOptions(); opt != nil {
|
||||
msg.Additionals = append(msg.Additionals, *opt)
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
msg := new(dnsmessage.Message)
|
||||
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{qAAAA}
|
||||
if opt := s.getMsgOptions(); opt != nil {
|
||||
msg.Additionals = append(msg.Additionals, *opt)
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
|
||||
return msgs
|
||||
id := req.msg.ID
|
||||
req.expire = time.Now().Add(time.Second * 8)
|
||||
s.requests[id] = *req
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
||||
newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
||||
newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
||||
|
||||
msgs := s.buildMsgs(domain, option)
|
||||
|
||||
for _, msg := range msgs {
|
||||
b, _ := dns.PackMessage(msg)
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
|
||||
|
||||
for _, req := range reqs {
|
||||
s.addPendingRequest(req)
|
||||
b, _ := dns.PackMessage(req.msg)
|
||||
udpCtx := context.Background()
|
||||
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||
udpCtx = session.ContextWithInbound(udpCtx, inbound)
|
||||
|
@ -418,18 +235,13 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
|
|||
return nil, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
func Fqdn(domain string) string {
|
||||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||||
return domain
|
||||
}
|
||||
return domain + "."
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
||||
|
||||
fqdn := Fqdn(domain)
|
||||
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
|
||||
return ips, err
|
||||
}
|
||||
|
||||
|
|
|
@ -68,12 +68,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
|
|||
return nil, newError("not an outbound handler")
|
||||
}
|
||||
|
||||
if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil && h.senderSettings.MultiplexSettings.Enabled {
|
||||
if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil {
|
||||
config := h.senderSettings.MultiplexSettings
|
||||
if config.Concurrency < 1 || config.Concurrency > 1024 {
|
||||
return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
|
||||
}
|
||||
h.mux = &mux.ClientManager{
|
||||
Enabled: h.senderSettings.MultiplexSettings.Enabled,
|
||||
Picker: &mux.IncrementalWorkerPicker{
|
||||
Factory: &mux.DialingWorkerFactory{
|
||||
Proxy: proxyHandler,
|
||||
|
@ -98,7 +99,7 @@ func (h *Handler) Tag() string {
|
|||
|
||||
// Dispatch implements proxy.Outbound.Dispatch.
|
||||
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
||||
if h.mux != nil {
|
||||
if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) {
|
||||
if err := h.mux.Dispatch(ctx, link); err != nil {
|
||||
newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||
common.Interrupt(link.Writer)
|
||||
|
|
|
@ -21,7 +21,8 @@ import (
|
|||
)
|
||||
|
||||
type ClientManager struct {
|
||||
Picker WorkerPicker
|
||||
Enabled bool // wheather mux is enabled from user config
|
||||
Picker WorkerPicker
|
||||
}
|
||||
|
||||
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
|
||||
|
|
|
@ -9,6 +9,7 @@ const (
|
|||
inboundSessionKey
|
||||
outboundSessionKey
|
||||
contentSessionKey
|
||||
muxPreferedSessionKey
|
||||
)
|
||||
|
||||
// ContextWithID returns a new context with the given ID.
|
||||
|
@ -56,3 +57,16 @@ func ContentFromContext(ctx context.Context) *Content {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithMuxPrefered returns a new context with the given bool
|
||||
func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context {
|
||||
return context.WithValue(ctx, muxPreferedSessionKey, forced)
|
||||
}
|
||||
|
||||
// MuxPreferedFromContext returns value in this context, or false if not contained.
|
||||
func MuxPreferedFromContext(ctx context.Context) bool {
|
||||
if val, ok := ctx.Value(muxPreferedSessionKey).(bool); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
|
|||
var address Address
|
||||
if err := json.Unmarshal(data, &address); err == nil {
|
||||
c.Address = &address
|
||||
c.Port = 53
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -75,15 +75,25 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) {
|
|||
}
|
||||
|
||||
type MuxConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Concurrency uint16 `json:"concurrency"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Concurrency int16 `json:"concurrency"`
|
||||
}
|
||||
|
||||
func (c *MuxConfig) GetConcurrency() uint16 {
|
||||
if c.Concurrency == 0 {
|
||||
return 8
|
||||
// Build creates MultiplexingConfig, Concurrency < 0 completely disables mux.
|
||||
func (m *MuxConfig) Build() *proxyman.MultiplexingConfig {
|
||||
if m.Concurrency < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var con uint32 = 8
|
||||
if m.Concurrency > 0 {
|
||||
con = uint32(m.Concurrency)
|
||||
}
|
||||
|
||||
return &proxyman.MultiplexingConfig{
|
||||
Enabled: m.Enabled,
|
||||
Concurrency: con,
|
||||
}
|
||||
return c.Concurrency
|
||||
}
|
||||
|
||||
type InboundDetourAllocationConfig struct {
|
||||
|
@ -246,11 +256,8 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
|||
senderSettings.ProxySettings = ps
|
||||
}
|
||||
|
||||
if c.MuxSettings != nil && c.MuxSettings.Enabled {
|
||||
senderSettings.MultiplexSettings = &proxyman.MultiplexingConfig{
|
||||
Enabled: true,
|
||||
Concurrency: uint32(c.MuxSettings.GetConcurrency()),
|
||||
}
|
||||
if c.MuxSettings != nil {
|
||||
senderSettings.MultiplexSettings = c.MuxSettings.Build()
|
||||
}
|
||||
|
||||
settings := []byte("{}")
|
||||
|
|
|
@ -2,15 +2,16 @@ package conf_test
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"v2ray.com/core"
|
||||
"v2ray.com/core/app/dispatcher"
|
||||
"v2ray.com/core/app/log"
|
||||
"v2ray.com/core/app/proxyman"
|
||||
"v2ray.com/core/app/router"
|
||||
"v2ray.com/core/common"
|
||||
clog "v2ray.com/core/common/log"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
|
@ -337,3 +338,34 @@ func TestV2RayConfig(t *testing.T) {
|
|||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestMuxConfig_Build(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields string
|
||||
want *proxyman.MultiplexingConfig
|
||||
}{
|
||||
{"default", `{"enabled": true, "concurrency": 16}`, &proxyman.MultiplexingConfig{
|
||||
Enabled: true,
|
||||
Concurrency: 16,
|
||||
}},
|
||||
{"empty def", `{}`, &proxyman.MultiplexingConfig{
|
||||
Enabled: false,
|
||||
Concurrency: 8,
|
||||
}},
|
||||
{"not enable", `{"enabled": false, "concurrency": 4}`, &proxyman.MultiplexingConfig{
|
||||
Enabled: false,
|
||||
Concurrency: 4,
|
||||
}},
|
||||
{"forbidden", `{"enabled": false, "concurrency": -1}`, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &MuxConfig{}
|
||||
common.Must(json.Unmarshal([]byte(tt.fields), m))
|
||||
if got := m.Build(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("MuxConfig.Build() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue