From fcfb0a302a2d167211153f1bac1d57efeea8ff9d Mon Sep 17 00:00:00 2001 From: Meow <197331664+Meo597@users.noreply.github.com> Date: Fri, 21 Nov 2025 10:54:01 +0800 Subject: [PATCH] perf(GeoIPMatcher): faster heuristic matching with reduced memory usage (#5289) --- app/dns/nameserver.go | 38 +- app/router/condition.go | 69 +- app/router/condition_geoip.go | 1005 +++++++++++++++++++++++++--- app/router/condition_geoip_test.go | 71 +- app/router/condition_test.go | 2 +- app/router/config.go | 14 +- 6 files changed, 996 insertions(+), 203 deletions(-) diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index cf1b665b..3f025d74 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -29,8 +29,8 @@ type Client struct { server Server skipFallback bool domains []string - expectedIPs []*router.GeoIPMatcher - unexpectedIPs []*router.GeoIPMatcher + expectedIPs router.GeoIPMatcher + unexpectedIPs router.GeoIPMatcher actPrior bool actUnprior bool tag string @@ -154,23 +154,21 @@ func NewClient( } // Establish expected IPs - var expectedMatchers []*router.GeoIPMatcher - for _, geoip := range ns.ExpectedGeoip { - matcher, err := router.GlobalGeoIPContainer.Add(geoip) + var expectedMatcher router.GeoIPMatcher + if len(ns.ExpectedGeoip) > 0 { + expectedMatcher, err = router.BuildOptimizedGeoIPMatcher(ns.ExpectedGeoip...) if err != nil { return errors.New("failed to create expected ip matcher").Base(err).AtWarning() } - expectedMatchers = append(expectedMatchers, matcher) } // Establish unexpected IPs - var unexpectedMatchers []*router.GeoIPMatcher - for _, geoip := range ns.UnexpectedGeoip { - matcher, err := router.GlobalGeoIPContainer.Add(geoip) + var unexpectedMatcher router.GeoIPMatcher + if len(ns.UnexpectedGeoip) > 0 { + unexpectedMatcher, err = router.BuildOptimizedGeoIPMatcher(ns.UnexpectedGeoip...) if err != nil { return errors.New("failed to create unexpected ip matcher").Base(err).AtWarning() } - unexpectedMatchers = append(unexpectedMatchers, matcher) } if len(clientIP) > 0 { @@ -192,8 +190,8 @@ func NewClient( client.server = server client.skipFallback = ns.SkipFallback client.domains = rules - client.expectedIPs = expectedMatchers - client.unexpectedIPs = unexpectedMatchers + client.expectedIPs = expectedMatcher + client.unexpectedIPs = unexpectedMatcher client.actPrior = ns.ActPrior client.actUnprior = ns.ActUnprior client.tag = tag @@ -243,32 +241,32 @@ func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption return nil, 0, dns.ErrEmptyResponse } - if len(c.expectedIPs) > 0 && !c.actPrior { - ips = router.MatchIPs(c.expectedIPs, ips, false) + if c.expectedIPs != nil && !c.actPrior { + ips, _ = c.expectedIPs.FilterIPs(ips) errors.LogDebug(context.Background(), "domain ", domain, " expectedIPs ", ips, " matched at server ", c.Name()) if len(ips) == 0 { return nil, 0, dns.ErrEmptyResponse } } - if len(c.unexpectedIPs) > 0 && !c.actUnprior { - ips = router.MatchIPs(c.unexpectedIPs, ips, true) + if c.unexpectedIPs != nil && !c.actUnprior { + _, ips = c.unexpectedIPs.FilterIPs(ips) errors.LogDebug(context.Background(), "domain ", domain, " unexpectedIPs ", ips, " matched at server ", c.Name()) if len(ips) == 0 { return nil, 0, dns.ErrEmptyResponse } } - if len(c.expectedIPs) > 0 && c.actPrior { - ipsNew := router.MatchIPs(c.expectedIPs, ips, false) + if c.expectedIPs != nil && c.actPrior { + ipsNew, _ := c.expectedIPs.FilterIPs(ips) if len(ipsNew) > 0 { ips = ipsNew errors.LogDebug(context.Background(), "domain ", domain, " priorIPs ", ips, " matched at server ", c.Name()) } } - if len(c.unexpectedIPs) > 0 && c.actUnprior { - ipsNew := router.MatchIPs(c.unexpectedIPs, ips, true) + if c.unexpectedIPs != nil && c.actUnprior { + _, ipsNew := c.unexpectedIPs.FilterIPs(ips) if len(ipsNew) > 0 { ips = ipsNew errors.LogDebug(context.Background(), "domain ", domain, " unpriorIPs ", ips, " matched at server ", c.Name()) diff --git a/app/router/condition.go b/app/router/condition.go index 4127ca47..c8cf4e8d 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -96,61 +96,53 @@ func (m *DomainMatcher) Apply(ctx routing.Context) bool { return m.ApplyDomain(domain) } -type MultiGeoIPMatcher struct { - matchers []*GeoIPMatcher - asType string // local, source, target +type MatcherAsType byte + +const ( + MatcherAsType_Local MatcherAsType = iota + MatcherAsType_Source + MatcherAsType_Target + MatcherAsType_VlessRoute // for port +) + +type IPMatcher struct { + matcher GeoIPMatcher + asType MatcherAsType } -func NewMultiGeoIPMatcher(geoips []*GeoIP, asType string) (*MultiGeoIPMatcher, error) { - var matchers []*GeoIPMatcher - for _, geoip := range geoips { - matcher, err := GlobalGeoIPContainer.Add(geoip) - if err != nil { - return nil, err - } - matchers = append(matchers, matcher) +func NewIPMatcher(geoips []*GeoIP, asType MatcherAsType) (*IPMatcher, error) { + matcher, err := BuildOptimizedGeoIPMatcher(geoips...) + if err != nil { + return nil, err } - - matcher := &MultiGeoIPMatcher{ - matchers: matchers, - asType: asType, - } - - return matcher, nil + return &IPMatcher{matcher: matcher, asType: asType}, nil } // Apply implements Condition. -func (m *MultiGeoIPMatcher) Apply(ctx routing.Context) bool { +func (m *IPMatcher) Apply(ctx routing.Context) bool { var ips []net.IP switch m.asType { - case "local": + case MatcherAsType_Local: ips = ctx.GetLocalIPs() - case "source": + case MatcherAsType_Source: ips = ctx.GetSourceIPs() - case "target": + case MatcherAsType_Target: ips = ctx.GetTargetIPs() default: - panic("unreachable, asType should be local or source or target") + panic("unk asType") } - for _, ip := range ips { - for _, matcher := range m.matchers { - if matcher.Match(ip) { - return true - } - } - } - return false + return m.matcher.AnyMatch(ips) } type PortMatcher struct { port net.MemoryPortList - asType string // local, source, target + asType MatcherAsType } // NewPortMatcher create a new port matcher that can match source or local or destination port -func NewPortMatcher(list *net.PortList, asType string) *PortMatcher { +func NewPortMatcher(list *net.PortList, asType MatcherAsType) *PortMatcher { return &PortMatcher{ port: net.PortListFromProto(list), asType: asType, @@ -160,18 +152,17 @@ func NewPortMatcher(list *net.PortList, asType string) *PortMatcher { // Apply implements Condition. func (v *PortMatcher) Apply(ctx routing.Context) bool { switch v.asType { - case "local": + case MatcherAsType_Local: return v.port.Contains(ctx.GetLocalPort()) - case "source": + case MatcherAsType_Source: return v.port.Contains(ctx.GetSourcePort()) - case "target": + case MatcherAsType_Target: return v.port.Contains(ctx.GetTargetPort()) - case "vlessRoute": + case MatcherAsType_VlessRoute: return v.port.Contains(ctx.GetVlessRoute()) default: - panic("unreachable, asType should be local or source or target") + panic("unk asType") } - } type NetworkMatcher struct { diff --git a/app/router/condition_geoip.go b/app/router/condition_geoip.go index 38f7f0ce..823548cf 100644 --- a/app/router/condition_geoip.go +++ b/app/router/condition_geoip.go @@ -1,144 +1,961 @@ package router import ( + "context" "net/netip" - "strconv" + "sort" + "strings" + "sync" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "go4.org/netipx" ) -type GeoIPMatcher struct { - countryCode string - reverseMatch bool - ip4 *netipx.IPSet - ip6 *netipx.IPSet +type GeoIPMatcher interface { + // TODO: (PERF) all net.IP -> netipx.Addr + + // Invalid IP always return false. + Match(ip net.IP) bool + + // Returns true if *any* IP is valid and match. + AnyMatch(ips []net.IP) bool + + // Returns true only if *all* IPs are valid and match. Any invalid IP, or non-matching valid IP, causes false. + Matches(ips []net.IP) bool + + // Filters IPs. Invalid IPs are silently dropped and not included in either result. + FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) + + ToggleReverse() + + SetReverse(reverse bool) } -func (m *GeoIPMatcher) Init(cidrs []*CIDR) error { - var builder4, builder6 netipx.IPSetBuilder +type GeoIPSet struct { + ipv4, ipv6 *netipx.IPSet + max4, max6 uint8 +} - for _, cidr := range cidrs { - ip := net.IP(cidr.GetIp()) - ipPrefixString := ip.String() + "/" + strconv.Itoa(int(cidr.GetPrefix())) - ipPrefix, err := netip.ParsePrefix(ipPrefixString) - if err != nil { - return err +type HeuristicGeoIPMatcher struct { + ipset *GeoIPSet + reverse bool +} + +type ipBucket struct { + rep netip.Addr + ips []net.IP +} + +// Match implements GeoIPMatcher. +func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool { + ipx, ok := netipx.FromStdIP(ip) + if !ok { + return false + } + return m.matchAddr(ipx) +} + +func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool { + if ipx.Is4() { + return m.ipset.ipv4.Contains(ipx) != m.reverse + } + if ipx.Is6() { + return m.ipset.ipv6.Contains(ipx) != m.reverse + } + return false +} + +// AnyMatch implements GeoIPMatcher. +func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool { + n := len(ips) + if n == 0 { + return false + } + + if n == 1 { + return m.Match(ips[0]) + } + + heur4 := m.ipset.max4 <= 24 + heur6 := m.ipset.max6 <= 64 + if !heur4 && !heur6 { + for _, ip := range ips { + if ipx, ok := netipx.FromStdIP(ip); ok { + if m.matchAddr(ipx) { + return true + } + } + } + return false + } + + buckets := make(map[[9]byte]struct{}, n) + for _, ip := range ips { + key, ok := prefixKeyFromIP(ip) + if !ok { + continue + } + heur := (key[0] == 4 && heur4) || (key[0] == 6 && heur6) + if heur { + if _, exists := buckets[key]; exists { + continue + } + } + ipx, ok := netipx.FromStdIP(ip) + if !ok { + continue + } + if m.matchAddr(ipx) { + return true + } + if heur { + buckets[key] = struct{}{} + } + } + return false +} + +// Matches implements GeoIPMatcher. +func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool { + n := len(ips) + if n == 0 { + return false + } + + if n == 1 { + return m.Match(ips[0]) + } + + heur4 := m.ipset.max4 <= 24 + heur6 := m.ipset.max6 <= 64 + if !heur4 && !heur6 { + for _, ip := range ips { + ipx, ok := netipx.FromStdIP(ip) + if !ok { + return false + } + if !m.matchAddr(ipx) { + return false + } + } + return true + } + + buckets := make(map[[9]byte]netip.Addr, n) + precise := make([]netip.Addr, 0, n) + + for _, ip := range ips { + key, ok := prefixKeyFromIP(ip) + if !ok { + return false } - switch len(ip) { - case net.IPv4len: - builder4.AddPrefix(ipPrefix) - case net.IPv6len: - builder6.AddPrefix(ipPrefix) + if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) { + if _, exists := buckets[key]; !exists { + ipx, ok := netipx.FromStdIP(ip) + if !ok { + return false + } + buckets[key] = ipx + } + } else { + ipx, ok := netipx.FromStdIP(ip) + if !ok { + return false + } + precise = append(precise, ipx) } } - if ip4, err := builder4.IPSet(); err != nil { - return err - } else { - m.ip4 = ip4 + for _, ipx := range buckets { + if !m.matchAddr(ipx) { + return false + } } - - if ip6, err := builder6.IPSet(); err != nil { - return err - } else { - m.ip6 = ip6 + for _, ipx := range precise { + if !m.matchAddr(ipx) { + return false + } } - - return nil + return true } -func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) { - m.reverseMatch = isReverseMatch +func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) { + if ip4 := ip.To4(); ip4 != nil { + key[0] = 4 + key[1] = ip4[0] + key[2] = ip4[1] + key[3] = ip4[2] // /24 + return key, true + } + if ip16 := ip.To16(); ip16 != nil { + key[0] = 6 + key[1] = ip16[0] + key[2] = ip16[1] + key[3] = ip16[2] + key[4] = ip16[3] + key[5] = ip16[4] + key[6] = ip16[5] + key[7] = ip16[6] + key[8] = ip16[7] // /64 + return key, true + } + return key, false // illegal } -func (m *GeoIPMatcher) match4(ip net.IP) bool { - nip, ok := netipx.FromStdIP(ip) +// FilterIPs implements GeoIPMatcher. +func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { + n := len(ips) + if n == 0 { + return []net.IP{}, []net.IP{} + } + + if n == 1 { + ipx, ok := netipx.FromStdIP(ips[0]) + if !ok { + return []net.IP{}, []net.IP{} + } + if m.matchAddr(ipx) { + return ips, []net.IP{} + } + return []net.IP{}, ips + } + + heur4 := m.ipset.max4 <= 24 + heur6 := m.ipset.max6 <= 64 + if !heur4 && !heur6 { + matched = make([]net.IP, 0, n) + unmatched = make([]net.IP, 0, n) + for _, ip := range ips { + ipx, ok := netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + if m.matchAddr(ipx) { + matched = append(matched, ip) + } else { + unmatched = append(unmatched, ip) + } + } + return + } + + buckets := make(map[[9]byte]*ipBucket, n) + precise := make([]net.IP, 0, n) + + for _, ip := range ips { + key, ok := prefixKeyFromIP(ip) + if !ok { + continue // illegal ip, ignore + } + + if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) { + precise = append(precise, ip) + continue + } + + b, exists := buckets[key] + if !exists { + // build bucket + ipx, ok := netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + b = &ipBucket{ + rep: ipx, + ips: make([]net.IP, 0, 4), // for dns answer + } + buckets[key] = b + } + b.ips = append(b.ips, ip) + } + + matched = make([]net.IP, 0, n) + unmatched = make([]net.IP, 0, n) + for _, b := range buckets { + if m.matchAddr(b.rep) { + matched = append(matched, b.ips...) + } else { + unmatched = append(unmatched, b.ips...) + } + } + for _, ip := range precise { + ipx, ok := netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + if m.matchAddr(ipx) { + matched = append(matched, ip) + } else { + unmatched = append(unmatched, ip) + } + } + return +} + +// ToggleReverse implements GeoIPMatcher. +func (m *HeuristicGeoIPMatcher) ToggleReverse() { + m.reverse = !m.reverse +} + +// SetReverse implements GeoIPMatcher. +func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) { + m.reverse = reverse +} + +type GeneralMultiGeoIPMatcher struct { + matchers []GeoIPMatcher +} + +// Match implements GeoIPMatcher. +func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool { + for _, m := range mm.matchers { + if m.Match(ip) { + return true + } + } + return false +} + +// AnyMatch implements GeoIPMatcher. +func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool { + for _, m := range mm.matchers { + if m.AnyMatch(ips) { + return true + } + } + return false +} + +// Matches implements GeoIPMatcher. +func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool { + for _, m := range mm.matchers { + if m.Matches(ips) { + return true + } + } + return false +} + +// FilterIPs implements GeoIPMatcher. +func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { + matched = make([]net.IP, 0, len(ips)) + unmatched = ips + for _, m := range mm.matchers { + if len(unmatched) == 0 { + break + } + var mtch []net.IP + mtch, unmatched = m.FilterIPs(unmatched) + if len(mtch) > 0 { + matched = append(matched, mtch...) + } + } + return +} + +// ToggleReverse implements GeoIPMatcher. +func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() { + for _, m := range mm.matchers { + m.ToggleReverse() + } +} + +// SetReverse implements GeoIPMatcher. +func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) { + for _, m := range mm.matchers { + m.SetReverse(reverse) + } +} + +type HeuristicMultiGeoIPMatcher struct { + matchers []*HeuristicGeoIPMatcher +} + +// Match implements GeoIPMatcher. +func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool { + ipx, ok := netipx.FromStdIP(ip) if !ok { return false } - return m.ip4.Contains(nip) + for _, m := range mm.matchers { + if m.matchAddr(ipx) { + return true + } + } + return false } -func (m *GeoIPMatcher) match6(ip net.IP) bool { - nip, ok := netipx.FromStdIP(ip) - if !ok { +// AnyMatch implements GeoIPMatcher. +func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool { + n := len(ips) + if n == 0 { return false } - return m.ip6.Contains(nip) -} - -// Match returns true if the given ip is included by the GeoIP. -func (m *GeoIPMatcher) Match(ip net.IP) bool { - isMatched := false - switch len(ip) { - case net.IPv4len: - isMatched = m.match4(ip) - case net.IPv6len: - isMatched = m.match6(ip) + if n == 1 { + return mm.Match(ips[0]) } - if m.reverseMatch { - return !isMatched + + buckets := make(map[[9]byte]struct{}, n) + for _, ip := range ips { + var ipx netip.Addr + state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped + for _, m := range mm.matchers { + heur4 := m.ipset.max4 <= 24 + heur6 := m.ipset.max6 <= 64 + + if state == 0 && (heur4 || heur6) { + key, ok := prefixKeyFromIP(ip) + if !ok { + break + } + if _, exists := buckets[key]; exists { + state = key[0] + } else { + buckets[key] = struct{}{} + state = 1 + } + } + if (heur4 && state == 4) || (heur6 && state == 6) { + continue + } + + if !ipx.IsValid() { + nipx, ok := netipx.FromStdIP(ip) + if !ok { + break + } + ipx = nipx + } + if m.matchAddr(ipx) { + return true + } + } } - return isMatched + return false } -// GeoIPMatcherContainer is a container for GeoIPMatchers. It keeps unique copies of GeoIPMatcher by country code. -type GeoIPMatcherContainer struct { - matchers []*GeoIPMatcher +// Matches implements GeoIPMatcher. +func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool { + n := len(ips) + if n == 0 { + return false + } + + if n == 1 { + return mm.Match(ips[0]) + } + + var views ipViews + for _, m := range mm.matchers { + if !views.ensureForMatcher(m, ips) { + return false + } + + matched := true + if m.ipset.max4 <= 24 { + for _, ipx := range views.buckets4 { + if !m.matchAddr(ipx) { + matched = false + break + } + } + } else { + for _, ipx := range views.precise4 { + if !m.matchAddr(ipx) { + matched = false + break + } + } + } + if !matched { + continue + } + + if m.ipset.max6 <= 64 { + for _, ipx := range views.buckets6 { + if !m.matchAddr(ipx) { + matched = false + break + } + } + } else { + for _, ipx := range views.precise6 { + if !m.matchAddr(ipx) { + matched = false + break + } + } + } + if matched { + return true + } + } + return false } -// Add adds a new GeoIP set into the container. -// If the country code of GeoIP is not empty, GeoIPMatcherContainer will try to find an existing one, instead of adding a new one. -func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) { - if len(geoip.CountryCode) > 0 { - for _, m := range c.matchers { - if m.countryCode == geoip.CountryCode && m.reverseMatch == geoip.ReverseMatch { - return m, nil +type ipViews struct { + buckets4, buckets6 map[[9]byte]netip.Addr + precise4, precise6 []netip.Addr +} + +func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool { + needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil + needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil + needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil + needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil + + if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 { + return true + } + + if needHeur4 { + v.buckets4 = make(map[[9]byte]netip.Addr, len(ips)) + } + if needHeur6 { + v.buckets6 = make(map[[9]byte]netip.Addr, len(ips)) + } + if needPrec4 { + v.precise4 = make([]netip.Addr, 0, len(ips)) + } + if needPrec6 { + v.precise6 = make([]netip.Addr, 0, len(ips)) + } + + for _, ip := range ips { + key, ok := prefixKeyFromIP(ip) + if !ok { + return false + } + + switch key[0] { + case 4: + var ipx netip.Addr + if needHeur4 { + if _, exists := v.buckets4[key]; !exists { + ipx, ok = netipx.FromStdIP(ip) + if !ok { + return false + } + v.buckets4[key] = ipx + } + } + if needPrec4 { + if !ipx.IsValid() { + ipx, ok = netipx.FromStdIP(ip) + if !ok { + return false + } + } + v.precise4 = append(v.precise4, ipx) + } + case 6: + var ipx netip.Addr + if needHeur6 { + if _, exists := v.buckets6[key]; !exists { + ipx, ok = netipx.FromStdIP(ip) + if !ok { + return false + } + v.buckets6[key] = ipx + } + } + if needPrec6 { + if !ipx.IsValid() { + ipx, ok = netipx.FromStdIP(ip) + if !ok { + return false + } + } + v.precise6 = append(v.precise6, ipx) + } + default: + return false + } + } + + return true +} + +// FilterIPs implements GeoIPMatcher. +func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { + n := len(ips) + if n == 0 { + return []net.IP{}, []net.IP{} + } + + if n == 1 { + ipx, ok := netipx.FromStdIP(ips[0]) + if !ok { + return []net.IP{}, []net.IP{} + } + for _, m := range mm.matchers { + if m.matchAddr(ipx) { + return ips, []net.IP{} + } + } + return []net.IP{}, ips + } + + var views ipBucketViews + + matched = make([]net.IP, 0, n) + for _, m := range mm.matchers { + views.ensureForMatcher(m, ips) + + if m.ipset.max4 <= 24 { + for key, b := range views.buckets4 { + if b == nil { + continue + } + if m.matchAddr(b.rep) { + views.buckets4[key] = nil + matched = append(matched, b.ips...) + } + } + } else { + for ipx, ip := range views.precise4 { + if ip == nil { + continue + } + if m.matchAddr(ipx) { + views.precise4[ipx] = nil + matched = append(matched, ip) + } + } + } + + if m.ipset.max6 <= 64 { + for key, b := range views.buckets6 { + if b == nil { + continue + } + if m.matchAddr(b.rep) { + views.buckets6[key] = nil + matched = append(matched, b.ips...) + } + } + } else { + for ipx, ip := range views.precise6 { + if ip == nil { + continue + } + if m.matchAddr(ipx) { + views.precise6[ipx] = nil + matched = append(matched, ip) + } } } } - m := &GeoIPMatcher{ - countryCode: geoip.CountryCode, - reverseMatch: geoip.ReverseMatch, + unmatched = make([]net.IP, 0, n-len(matched)) + if views.buckets4 != nil { + for _, b := range views.buckets4 { + if b == nil { + continue + } + unmatched = append(unmatched, b.ips...) + } } - if err := m.Init(geoip.Cidr); err != nil { + if views.precise4 != nil { + for _, ip := range views.precise4 { + if ip == nil { + continue + } + unmatched = append(unmatched, ip) + } + } + if views.buckets6 != nil { + for _, b := range views.buckets6 { + if b == nil { + continue + } + unmatched = append(unmatched, b.ips...) + } + } + if views.precise6 != nil { + for _, ip := range views.precise6 { + if ip == nil { + continue + } + unmatched = append(unmatched, ip) + } + } + + return +} + +type ipBucketViews struct { + buckets4, buckets6 map[[9]byte]*ipBucket + precise4, precise6 map[netip.Addr]net.IP +} + +func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) { + needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil + needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil + needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil + needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil + + if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 { + return + } + + if needHeur4 { + v.buckets4 = make(map[[9]byte]*ipBucket, len(ips)) + } + if needHeur6 { + v.buckets6 = make(map[[9]byte]*ipBucket, len(ips)) + } + if needPrec4 { + v.precise4 = make(map[netip.Addr]net.IP, len(ips)) + } + if needPrec6 { + v.precise6 = make(map[netip.Addr]net.IP, len(ips)) + } + + for _, ip := range ips { + key, ok := prefixKeyFromIP(ip) + if !ok { + continue // illegal ip, ignore + } + + switch key[0] { + case 4: + var ipx netip.Addr + if needHeur4 { + b, exists := v.buckets4[key] + if !exists { + // build bucket + ipx, ok = netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + b = &ipBucket{ + rep: ipx, + ips: make([]net.IP, 0, 4), // for dns answer + } + v.buckets4[key] = b + } + b.ips = append(b.ips, ip) + } + if needPrec4 { + if !ipx.IsValid() { + ipx, ok = netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + } + v.precise4[ipx] = ip + } + case 6: + var ipx netip.Addr + if needHeur6 { + b, exists := v.buckets6[key] + if !exists { + // build bucket + ipx, ok = netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + b = &ipBucket{ + rep: ipx, + ips: make([]net.IP, 0, 4), // for dns answer + } + v.buckets6[key] = b + } + b.ips = append(b.ips, ip) + } + if needPrec6 { + if !ipx.IsValid() { + ipx, ok = netipx.FromStdIP(ip) + if !ok { + continue // illegal ip, ignore + } + } + v.precise6[ipx] = ip + } + } + } +} + +// ToggleReverse implements GeoIPMatcher. +func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() { + for _, m := range mm.matchers { + m.ToggleReverse() + } +} + +// SetReverse implements GeoIPMatcher. +func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) { + for _, m := range mm.matchers { + m.SetReverse(reverse) + } +} + +type GeoIPSetFactory struct { + sync.Mutex + shared map[string]*GeoIPSet // TODO: cleanup +} + +var ipsetFactory = GeoIPSetFactory{shared: make(map[string]*GeoIPSet)} + +func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) { + f.Lock() + defer f.Unlock() + + if ipset := f.shared[key]; ipset != nil { + return ipset, nil + } + + ipset, err := f.Create(cidrGroups...) + if err == nil { + f.shared[key] = ipset + } + return ipset, err +} + +func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) { + var ipv4Builder, ipv6Builder netipx.IPSetBuilder + + for _, cidrGroup := range cidrGroups { + for _, cidrEntry := range cidrGroup { + ipBytes := cidrEntry.GetIp() + prefixLen := int(cidrEntry.GetPrefix()) + + addr, ok := netip.AddrFromSlice(ipBytes) + if !ok { + errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes) + continue + } + + prefix := netip.PrefixFrom(addr, prefixLen) + if !prefix.IsValid() { + errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen) + continue + } + + if addr.Is4() { + ipv4Builder.AddPrefix(prefix) + } else if addr.Is6() { + ipv6Builder.AddPrefix(prefix) + } + } + } + + ipv4, err := ipv4Builder.IPSet() + if err != nil { + return nil, errors.New("failed to build IPv4 set").Base(err) + } + ipv6, err := ipv6Builder.IPSet() + if err != nil { + return nil, errors.New("failed to build IPv6 set").Base(err) + } + + var max4, max6 int + + for _, p := range ipv4.Prefixes() { + if b := p.Bits(); b > max4 { + max4 = b + } + } + for _, p := range ipv6.Prefixes() { + if b := p.Bits(); b > max6 { + max6 = b + } + } + + if max4 == 0 { + max4 = 0xff + } + if max6 == 0 { + max6 = 0xff + } + + return &GeoIPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil +} + +func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) { + n := len(geoips) + if n == 0 { + return nil, errors.New("no geoip configs provided") + } + + var subs []*HeuristicGeoIPMatcher + pos := make([]*GeoIP, 0, n) + neg := make([]*GeoIP, 0, n/2) + + for _, geoip := range geoips { + if geoip == nil { + return nil, errors.New("geoip entry is nil") + } + if geoip.CountryCode == "" { + ipset, err := ipsetFactory.Create(geoip.Cidr) + if err != nil { + return nil, err + } + subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: geoip.ReverseMatch}) + continue + } + if !geoip.ReverseMatch { + pos = append(pos, geoip) + } else { + neg = append(neg, geoip) + } + } + + buildIPSet := func(mergeables []*GeoIP) (*GeoIPSet, error) { + n := len(mergeables) + if n == 0 { + return nil, nil + } + + sort.Slice(mergeables, func(i, j int) bool { + gi, gj := mergeables[i], mergeables[j] + return gi.CountryCode < gj.CountryCode + }) + + var sb strings.Builder + sb.Grow(n * 3) // xx, + cidrGroups := make([][]*CIDR, 0, n) + var last *GeoIP + for i, geoip := range mergeables { + if i == 0 || (geoip.CountryCode != last.CountryCode) { + last = geoip + sb.WriteString(geoip.CountryCode) + sb.WriteString(",") + cidrGroups = append(cidrGroups, geoip.Cidr) + } + } + + return ipsetFactory.GetOrCreate(sb.String(), cidrGroups) + } + + ipset, err := buildIPSet(pos) + if err != nil { return nil, err } - if len(geoip.CountryCode) > 0 { - c.matchers = append(c.matchers, m) + if ipset != nil { + subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: false}) } - return m, nil -} -var GlobalGeoIPContainer GeoIPMatcherContainer + ipset, err = buildIPSet(neg) + if err != nil { + return nil, err + } + if ipset != nil { + subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: true}) + } -func MatchIPs(matchers []*GeoIPMatcher, ips []net.IP, reverse bool) []net.IP { - if len(matchers) == 0 { - panic("GeoIP matchers should not be empty to avoid ambiguity") + switch len(subs) { + case 0: + return nil, errors.New("no valid geoip matcher") + case 1: + return subs[0], nil + default: + return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil } - newIPs := make([]net.IP, 0, len(ips)) - var isFound bool - for _, ip := range ips { - isFound = false - for _, matcher := range matchers { - if matcher.Match(ip) { - isFound = true - break - } - } - if isFound && !reverse { - newIPs = append(newIPs, ip) - continue - } - if !isFound && reverse { - newIPs = append(newIPs, ip) - continue - } - } - return newIPs } diff --git a/app/router/condition_geoip_test.go b/app/router/condition_geoip_test.go index 07f40b83..b712db9e 100644 --- a/app/router/condition_geoip_test.go +++ b/app/router/condition_geoip_test.go @@ -35,33 +35,6 @@ func getAssetPath(file string) (string, error) { return path, nil } -func TestGeoIPMatcherContainer(t *testing.T) { - container := &router.GeoIPMatcherContainer{} - - m1, err := container.Add(&router.GeoIP{ - CountryCode: "CN", - }) - common.Must(err) - - m2, err := container.Add(&router.GeoIP{ - CountryCode: "US", - }) - common.Must(err) - - m3, err := container.Add(&router.GeoIP{ - CountryCode: "CN", - }) - common.Must(err) - - if m1 != m3 { - t.Error("expect same matcher for same geoip, but not") - } - - if m1 == m2 { - t.Error("expect different matcher for different geoip, but actually same") - } -} - func TestGeoIPMatcher(t *testing.T) { cidrList := []*router.CIDR{ {Ip: []byte{0, 0, 0, 0}, Prefix: 8}, @@ -80,8 +53,10 @@ func TestGeoIPMatcher(t *testing.T) { {Ip: []byte{91, 108, 4, 0}, Prefix: 16}, } - matcher := &router.GeoIPMatcher{} - common.Must(matcher.Init(cidrList)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: cidrList, + }) + common.Must(err) testCases := []struct { Input string @@ -140,8 +115,10 @@ func TestGeoIPMatcherRegression(t *testing.T) { {Ip: []byte{98, 108, 20, 0}, Prefix: 23}, } - matcher := &router.GeoIPMatcher{} - common.Must(matcher.Init(cidrList)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: cidrList, + }) + common.Must(err) testCases := []struct { Input string @@ -171,9 +148,11 @@ func TestGeoIPReverseMatcher(t *testing.T) { {Ip: []byte{8, 8, 8, 8}, Prefix: 32}, {Ip: []byte{91, 108, 4, 0}, Prefix: 16}, } - matcher := &router.GeoIPMatcher{} - matcher.SetReverseMatch(true) // Reverse match - common.Must(matcher.Init(cidrList)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: cidrList, + }) + common.Must(err) + matcher.SetReverse(true) // Reverse match testCases := []struct { Input string @@ -206,8 +185,10 @@ func TestGeoIPMatcher4CN(t *testing.T) { ips, err := loadGeoIP("CN") common.Must(err) - matcher := &router.GeoIPMatcher{} - common.Must(matcher.Init(ips)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: ips, + }) + common.Must(err) if matcher.Match([]byte{8, 8, 8, 8}) { t.Error("expect CN geoip doesn't contain 8.8.8.8, but actually does") @@ -218,8 +199,10 @@ func TestGeoIPMatcher6US(t *testing.T) { ips, err := loadGeoIP("US") common.Must(err) - matcher := &router.GeoIPMatcher{} - common.Must(matcher.Init(ips)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: ips, + }) + common.Must(err) if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) { t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not") @@ -254,8 +237,10 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) { ips, err := loadGeoIP("CN") common.Must(err) - matcher := &router.GeoIPMatcher{} - common.Must(matcher.Init(ips)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: ips, + }) + common.Must(err) b.ResetTimer() @@ -268,8 +253,10 @@ func BenchmarkGeoIPMatcher6US(b *testing.B) { ips, err := loadGeoIP("US") common.Must(err) - matcher := &router.GeoIPMatcher{} - common.Must(matcher.Init(ips)) + matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ + Cidr: ips, + }) + common.Must(err) b.ResetTimer() diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 7e90351a..1272aef6 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -447,7 +447,7 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) { }) } - matcher, err := NewMultiGeoIPMatcher(geoips, "target") + matcher, err := NewIPMatcher(geoips, MatcherAsType_Target) common.Must(err) ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.ParseAddress("8.8.8.8"), 80)}) diff --git a/app/router/config.go b/app/router/config.go index b7338e35..e9f0e02c 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -46,7 +46,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if rr.VlessRouteList != nil { - conds.Add(NewPortMatcher(rr.VlessRouteList, "vlessRoute")) + conds.Add(NewPortMatcher(rr.VlessRouteList, MatcherAsType_VlessRoute)) } if len(rr.InboundTag) > 0 { @@ -54,15 +54,15 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if rr.PortList != nil { - conds.Add(NewPortMatcher(rr.PortList, "target")) + conds.Add(NewPortMatcher(rr.PortList, MatcherAsType_Target)) } if rr.SourcePortList != nil { - conds.Add(NewPortMatcher(rr.SourcePortList, "source")) + conds.Add(NewPortMatcher(rr.SourcePortList, MatcherAsType_Source)) } if rr.LocalPortList != nil { - conds.Add(NewPortMatcher(rr.LocalPortList, "local")) + conds.Add(NewPortMatcher(rr.LocalPortList, MatcherAsType_Local)) } if len(rr.Networks) > 0 { @@ -70,7 +70,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if len(rr.Geoip) > 0 { - cond, err := NewMultiGeoIPMatcher(rr.Geoip, "target") + cond, err := NewIPMatcher(rr.Geoip, MatcherAsType_Target) if err != nil { return nil, err } @@ -78,7 +78,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if len(rr.SourceGeoip) > 0 { - cond, err := NewMultiGeoIPMatcher(rr.SourceGeoip, "source") + cond, err := NewIPMatcher(rr.SourceGeoip, MatcherAsType_Source) if err != nil { return nil, err } @@ -86,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if len(rr.LocalGeoip) > 0 { - cond, err := NewMultiGeoIPMatcher(rr.LocalGeoip, "local") + cond, err := NewIPMatcher(rr.LocalGeoip, MatcherAsType_Local) if err != nil { return nil, err }