diff --git a/app/dns/hosts.go b/app/dns/hosts.go index 7c398319..aff9521a 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -9,7 +9,7 @@ import ( // StaticHosts represents static domain-ip mapping in DNS server. type StaticHosts struct { - ips [][]net.IP + ips [][]net.Address matchers *strmatcher.MatcherGroup } @@ -36,7 +36,7 @@ func toStrMatcher(t DomainMatchingType, domain string) (strmatcher.Matcher, erro func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDomain) (*StaticHosts, error) { g := new(strmatcher.MatcherGroup) sh := &StaticHosts{ - ips: make([][]net.IP, len(hosts)+len(legacy)+16), + ips: make([][]net.Address, len(hosts)+len(legacy)+16), matchers: g, } @@ -50,10 +50,10 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma address := ip.AsAddress() if address.Family().IsDomain() { - return nil, newError("ignoring domain address in static hosts: ", address.Domain()).AtWarning() + return nil, newError("invalid domain address in static hosts: ", address.Domain()).AtWarning() } - sh.ips[id] = []net.IP{address.IP()} + sh.ips[id] = []net.Address{address} } } @@ -63,9 +63,13 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma return nil, newError("failed to create domain matcher").Base(err) } id := g.Add(matcher) - ips := make([]net.IP, len(mapping.Ip)) - for idx, ip := range mapping.Ip { - ips[idx] = net.IP(ip) + ips := make([]net.Address, 0, len(mapping.Ip)) + for _, ip := range mapping.Ip { + addr := net.IPAddress(ip) + if addr == nil { + return nil, newError("invalid IP address in static hosts: ", ip).AtWarning() + } + ips = append(ips, addr) } sh.ips[id] = ips } @@ -73,12 +77,11 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma return sh, nil } -func filterIP(ips []net.IP, option IPOption) []net.IP { +func filterIP(ips []net.Address, option IPOption) []net.IP { filtered := make([]net.IP, 0, len(ips)) for _, ip := range ips { - parsed := net.IPAddress(ip) - if (parsed.Family().IsIPv4() && option.IPv4Enable) || (parsed.Family().IsIPv6() && option.IPv6Enable) { - filtered = append(filtered, parsed.IP()) + if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) { + filtered = append(filtered, ip.IP()) } } if len(filtered) == 0 { diff --git a/app/dns/udpns.go b/app/dns/udpns.go index 70cd8c57..3a246886 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -20,7 +20,7 @@ import ( ) type IPRecord struct { - IP net.IP + IP net.Address Expire time.Time } @@ -149,7 +149,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf break } ips = append(ips, IPRecord{ - IP: net.IP(ans.A[:]), + IP: net.IPAddress(ans.A[:]), Expire: now.Add(time.Duration(ttl) * time.Second), }) case dnsmessage.TypeAAAA: @@ -159,7 +159,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf break } ips = append(ips, IPRecord{ - IP: net.IP(ans.AAAA[:]), + IP: net.IPAddress(ans.AAAA[:]), Expire: now.Add(time.Duration(ttl) * time.Second), }) default: @@ -323,7 +323,7 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []n s.RUnlock() if found && len(records) > 0 { - var ips []net.IP + var ips []net.Address now := time.Now() for _, rec := range records { if rec.Expire.After(now) {