diff --git a/app/dns/dns.go b/app/dns/dns.go index 54191926..de499a40 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -68,6 +68,11 @@ func (this *DnsCache) cleanup() { } func (this *DnsCache) Add(context app.Context, domain string, ip net.IP) { + callerTag := context.CallerTag() + if !this.config.IsTrustedSource(callerTag) { + return + } + this.RLock() entry, found := this.cache[domain] this.RUnlock() diff --git a/app/dns/dns_test.go b/app/dns/dns_test.go index f466a70d..698afa10 100644 --- a/app/dns/dns_test.go +++ b/app/dns/dns_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/v2ray/v2ray-core/app/dns" + dnstesting "github.com/v2ray/v2ray-core/app/dns/testing" apptesting "github.com/v2ray/v2ray-core/app/testing" netassert "github.com/v2ray/v2ray-core/common/net/testing/assert" v2testing "github.com/v2ray/v2ray-core/testing" @@ -14,11 +15,19 @@ func TestDnsAdd(t *testing.T) { v2testing.Current(t) domain := "v2ray.com" - cache := dns.NewCache(nil) + cache := dns.NewCache(&dnstesting.CacheConfig{ + TrustedTags: map[string]bool{ + "testtag": true, + }, + }) ip := cache.Get(&apptesting.Context{}, domain) netassert.IP(ip).IsNil() - cache.Add(&apptesting.Context{}, domain, []byte{1, 2, 3, 4}) + cache.Add(&apptesting.Context{CallerTagValue: "notvalidtag"}, domain, []byte{1, 2, 3, 4}) + ip = cache.Get(&apptesting.Context{}, domain) + netassert.IP(ip).IsNil() + + cache.Add(&apptesting.Context{CallerTagValue: "testtag"}, domain, []byte{1, 2, 3, 4}) ip = cache.Get(&apptesting.Context{}, domain) netassert.IP(ip).Equals(net.IP([]byte{1, 2, 3, 4})) } diff --git a/app/dns/testing/config.go b/app/dns/testing/config.go new file mode 100644 index 00000000..e1c51e98 --- /dev/null +++ b/app/dns/testing/config.go @@ -0,0 +1,10 @@ +package testing + +type CacheConfig struct { + TrustedTags map[string]bool +} + +func (this *CacheConfig) IsTrustedSource(tag string) bool { + _, found := this.TrustedTags[tag] + return found +}