clean up dns package

pull/714/merge
Darien Raymond 2017-11-15 00:36:14 +01:00
parent a430e2065a
commit 0dbfb66126
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
16 changed files with 225 additions and 79 deletions

View File

@ -2,12 +2,14 @@ package server
import (
"context"
"fmt"
"sync"
"time"
"github.com/miekg/dns"
"v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/log"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net"
@ -15,7 +17,6 @@ import (
)
const (
DefaultTTL = uint32(3600)
CleanupInterval = time.Second * 120
CleanupThreshold = 512
)
@ -55,7 +56,6 @@ func NewUDPNameServer(address net.Destination, dispatcher dispatcher.Interface)
return s
}
// Private: Visible for testing.
func (v *UDPNameServer) Cleanup() {
expiredRequests := make([]uint16, 0, 16)
now := time.Now()
@ -70,10 +70,8 @@ func (v *UDPNameServer) Cleanup() {
delete(v.requests, id)
}
v.Unlock()
expiredRequests = nil
}
// Private: Visible for testing.
func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
var id uint16
v.Lock()
@ -98,7 +96,6 @@ func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
return id
}
// Private: Visible for testing.
func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
msg := new(dns.Msg)
err := msg.Unpack(payload.Bytes())
@ -110,8 +107,8 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
IPs: make([]net.IP, 0, 16),
}
id := msg.Id
ttl := DefaultTTL
log.Trace(newError("handling response for id ", id, " content: ", msg.String()).AtDebug())
ttl := uint32(3600) // an hour
log.Trace(newError("handling response for id ", id, " content: ", msg).AtDebug())
v.Lock()
request, found := v.requests[id]
@ -126,6 +123,7 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
switch rr := rr.(type) {
case *dns.A:
record.IPs = append(record.IPs, rr.A)
fmt.Println("Adding ans:", rr.A)
if rr.Hdr.Ttl < ttl {
ttl = rr.Hdr.Ttl
}
@ -152,13 +150,18 @@ func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
Name: dns.Fqdn(domain),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
},
{
Name: dns.Fqdn(domain),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}}
buffer := buf.New()
buffer.AppendSupplier(func(b []byte) (int, error) {
common.Must(buffer.Reset(func(b []byte) (int, error) {
writtenBuffer, err := msg.PackBuffer(b)
return len(writtenBuffer), err
})
}))
return buffer
}
@ -167,7 +170,7 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
response := make(chan *ARecord, 1)
id := v.AssignUnusedID(response)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
ctx, cancel := context.WithCancel(context.Background())
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
go func() {
@ -176,11 +179,10 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
v.Lock()
_, found := v.requests[id]
v.Unlock()
if found {
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
} else {
if !found {
break
}
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
}
cancel()
}()
@ -205,7 +207,7 @@ func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
response <- &ARecord{
IPs: ips,
Expire: time.Now().Add(time.Second * time.Duration(DefaultTTL)),
Expire: time.Now().Add(time.Hour),
}
}()

View File

@ -1,20 +0,0 @@
package server
import (
"time"
"v2ray.com/core/common/net"
)
type IPResult struct {
IP []net.IP
TTL time.Duration
}
type Querier interface {
QueryDomain(domain string) <-chan *IPResult
}
type UDPQuerier struct {
server net.Destination
}

View File

@ -21,22 +21,22 @@ const (
)
type DomainRecord struct {
A *ARecord
}
type Record struct {
IP []net.IP
Expire time.Time
LastAccess time.Time
}
func (r *Record) Expired() bool {
func (r *DomainRecord) Expired() bool {
return r.Expire.Before(time.Now())
}
func (r *DomainRecord) Inactive() bool {
now := time.Now()
return r.Expire.Before(now) || r.LastAccess.Add(time.Hour).Before(now)
return r.Expire.Before(now) || r.LastAccess.Add(time.Minute*5).Before(now)
}
type CacheServer struct {
sync.RWMutex
sync.Mutex
hosts map[string]net.IP
records map[string]*DomainRecord
servers []NameServer
@ -90,15 +90,33 @@ func (*CacheServer) Start() error {
func (*CacheServer) Close() {}
func (s *CacheServer) GetCached(domain string) []net.IP {
s.RLock()
defer s.RUnlock()
s.Lock()
defer s.Unlock()
if record, found := s.records[domain]; found && record.A.Expire.After(time.Now()) {
return record.A.IPs
if record, found := s.records[domain]; found && !record.Expired() {
record.LastAccess = time.Now()
return record.IP
}
return nil
}
func (s *CacheServer) tryCleanup() {
s.Lock()
defer s.Unlock()
if len(s.records) > 256 {
domains := make([]string, 0, 256)
for d, r := range s.records {
if r.Expired() {
domains = append(domains, d)
}
}
for _, d := range domains {
delete(s.records, d)
}
}
}
func (s *CacheServer) Get(domain string) []net.IP {
if ip, found := s.hosts[domain]; found {
return []net.IP{ip}
@ -110,6 +128,8 @@ func (s *CacheServer) Get(domain string) []net.IP {
return ips
}
s.tryCleanup()
for _, server := range s.servers {
response := server.QueryA(domain)
select {
@ -119,7 +139,9 @@ func (s *CacheServer) Get(domain string) []net.IP {
}
s.Lock()
s.records[domain] = &DomainRecord{
A: a,
IP: a.IPs,
Expire: a.Expire,
LastAccess: time.Now(),
}
s.Unlock()
log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())

View File

@ -0,0 +1,102 @@
package server_test
import (
"context"
"testing"
"v2ray.com/core/app"
"v2ray.com/core/app/dispatcher"
_ "v2ray.com/core/app/dispatcher/impl"
. "v2ray.com/core/app/dns"
_ "v2ray.com/core/app/dns/server"
"v2ray.com/core/app/proxyman"
_ "v2ray.com/core/app/proxyman/outbound"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy/freedom"
"v2ray.com/core/testing/servers/udp"
. "v2ray.com/ext/assert"
"github.com/miekg/dns"
)
type staticHandler struct {
}
func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ans := new(dns.Msg)
ans.Id = r.Id
for _, q := range r.Question {
if q.Name == "google.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
ans.Answer = append(ans.Answer, rr)
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
ans.Answer = append(ans.Answer, rr)
}
}
w.WriteMsg(ans)
}
func TestUDPServer(t *testing.T) {
assert := With(t)
port := udp.PickPort()
dnsServer := dns.Server{
Addr: "127.0.0.1:" + port.String(),
Net: "udp",
Handler: &staticHandler{},
UDPSize: 1200,
}
go dnsServer.ListenAndServe()
config := &Config{
NameServers: []*net.Endpoint{
{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: uint32(port),
},
},
}
ctx := context.Background()
space := app.NewSpace()
ctx = app.ContextWithSpace(ctx, space)
common.Must(app.AddApplicationToSpace(ctx, config))
common.Must(app.AddApplicationToSpace(ctx, &dispatcher.Config{}))
common.Must(app.AddApplicationToSpace(ctx, &proxyman.OutboundConfig{}))
om := proxyman.OutboundHandlerManagerFromSpace(space)
om.AddHandler(ctx, &proxyman.OutboundHandlerConfig{
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
common.Must(space.Initialize())
common.Must(space.Start())
server := FromSpace(space)
assert(server, IsNotNil)
ips := server.Get("google.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
ips = server.Get("facebook.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{9, 9, 9, 9})
dnsServer.Shutdown()
ips = server.Get("google.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
}

View File

@ -53,8 +53,7 @@ func (t *ActivityTimer) run() {
}
}
func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
ctx, cancel := context.WithCancel(ctx)
func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
timer := &ActivityTimer{
ctx: ctx,
cancel: cancel,
@ -63,5 +62,5 @@ func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.
}
timer.timeout <- timeout
go timer.run()
return ctx, timer
return timer
}

View File

@ -13,7 +13,8 @@ import (
func TestActivityTimer(t *testing.T) {
assert := With(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
ctx, cancel := context.WithCancel(context.Background())
timer := CancelAfterInactivity(ctx, cancel, time.Second*5)
time.Sleep(time.Second * 6)
assert(ctx.Err(), IsNotNil)
runtime.KeepAlive(timer)
@ -22,7 +23,8 @@ func TestActivityTimer(t *testing.T) {
func TestActivityTimerUpdate(t *testing.T) {
assert := With(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
ctx, cancel := context.WithCancel(context.Background())
timer := CancelAfterInactivity(ctx, cancel, time.Second*10)
time.Sleep(time.Second * 3)
assert(ctx.Err(), IsNil)
timer.SetTimeout(time.Second * 1)

View File

@ -64,7 +64,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
if timeout == 0 {
timeout = time.Minute * 5
}
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
inboundRay, err := dispatcher.Dispatch(ctx, dest)
if err != nil {

View File

@ -107,7 +107,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
if timeout == 0 {
timeout = time.Minute * 5
}
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
requestDone := signal.ExecuteAsync(func() error {
var writer buf.Writer

View File

@ -153,7 +153,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
if timeout == 0 {
timeout = time.Minute * 5
}
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
ray, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return err

View File

@ -90,7 +90,8 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
request.Option |= RequestOptionOneTimeAuth
}
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
if request.Command == protocol.RequestCommandTCP {
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))

View File

@ -146,7 +146,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
ctx = protocol.ContextWithUser(ctx, request.User)
userSettings := s.user.GetSettings()
ctx, timer := signal.CancelAfterInactivity(ctx, userSettings.PayloadTimeout)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, userSettings.PayloadTimeout)
ray, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return err

View File

@ -83,7 +83,8 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
return newError("failed to establish connection to server").AtWarning().Base(err)
}
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
var requestFunc func() error
var responseFunc func() error

View File

@ -107,7 +107,8 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
if timeout == 0 {
timeout = time.Minute * 5
}
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
ray, err := dispatcher.Dispatch(ctx, dest)
if err != nil {

View File

@ -204,7 +204,8 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
ctx = protocol.ContextWithUser(ctx, request.User)
ctx, timer := signal.CancelAfterInactivity(ctx, userSettings.PayloadTimeout)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, userSettings.PayloadTimeout)
ray, err := dispatcher.Dispatch(ctx, request.Destination())
if err != nil {
return newError("failed to dispatch request to ", request.Destination()).Base(err)

View File

@ -103,7 +103,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
session := encoding.NewClientSession(protocol.DefaultIDHash)
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
requestDone := signal.ExecuteAsync(func() error {
writer := buf.NewBufferedWriter(buf.NewWriter(conn))

View File

@ -3,25 +3,33 @@ package udp
import (
"context"
"sync"
"time"
"v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/log"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/signal"
"v2ray.com/core/transport/ray"
)
type ResponseCallback func(payload *buf.Buffer)
type connEntry struct {
inbound ray.InboundRay
timer signal.ActivityUpdater
cancel context.CancelFunc
}
type Dispatcher struct {
sync.RWMutex
conns map[net.Destination]ray.InboundRay
conns map[net.Destination]*connEntry
dispatcher dispatcher.Interface
}
func NewDispatcher(dispatcher dispatcher.Interface) *Dispatcher {
return &Dispatcher{
conns: make(map[net.Destination]ray.InboundRay),
conns: make(map[net.Destination]*connEntry),
dispatcher: dispatcher,
}
}
@ -30,51 +38,72 @@ func (v *Dispatcher) RemoveRay(dest net.Destination) {
v.Lock()
defer v.Unlock()
if conn, found := v.conns[dest]; found {
conn.InboundInput().Close()
conn.InboundOutput().Close()
conn.inbound.InboundInput().Close()
conn.inbound.InboundOutput().Close()
delete(v.conns, dest)
}
}
func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (ray.InboundRay, bool) {
func (v *Dispatcher) getInboundRay(dest net.Destination, callback ResponseCallback) *connEntry {
v.Lock()
defer v.Unlock()
if entry, found := v.conns[dest]; found {
return entry, true
return entry
}
log.Trace(newError("establishing new connection for ", dest))
ctx, cancel := context.WithCancel(context.Background())
removeRay := func() {
cancel()
v.RemoveRay(dest)
}
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
inboundRay, _ := v.dispatcher.Dispatch(ctx, dest)
v.conns[dest] = inboundRay
return inboundRay, false
entry := &connEntry{
inbound: inboundRay,
timer: timer,
cancel: removeRay,
}
v.conns[dest] = entry
go handleInput(ctx, entry, callback)
return entry
}
func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer, callback ResponseCallback) {
// TODO: Add user to destString
log.Trace(newError("dispatch request to: ", destination).AtDebug())
inboundRay, existing := v.getInboundRay(ctx, destination)
outputStream := inboundRay.InboundInput()
conn := v.getInboundRay(destination, callback)
outputStream := conn.inbound.InboundInput()
if outputStream != nil {
if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
v.RemoveRay(destination)
log.Trace(newError("failed to write first UDP payload").Base(err))
conn.cancel()
return
}
}
if !existing {
go func() {
handleInput(inboundRay.InboundOutput(), callback)
v.RemoveRay(destination)
}()
}
}
func handleInput(input ray.InputStream, callback ResponseCallback) {
func handleInput(ctx context.Context, conn *connEntry, callback ResponseCallback) {
input := conn.inbound.InboundOutput()
timer := conn.timer
for {
select {
case <-ctx.Done():
return
default:
}
mb, err := input.ReadMultiBuffer()
if err != nil {
break
log.Trace(newError("failed to handl UDP input").Base(err))
conn.cancel()
return
}
timer.Update()
for _, b := range mb {
callback(b)
}