// Mirror of https://github.com/XTLS/Xray-core (Go source, 962 lines, 19 KiB).

package router
|
|
|
|
import (
|
|
"context"
|
|
"net/netip"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"github.com/xtls/xray-core/common/net"
|
|
|
|
"go4.org/netipx"
|
|
)
|
|
|
|
// GeoIPMatcher decides whether IP addresses belong to a configured address
// set. Implementations in this file can invert their answer ("reverse" mode).
type GeoIPMatcher interface {
	// TODO: (PERF) all net.IP -> netipx.Addr

	// Match reports whether ip is matched. An invalid IP never matches.
	Match(ip net.IP) bool

	// AnyMatch reports whether at least one of ips is both valid and matched.
	AnyMatch(ips []net.IP) bool

	// Matches reports whether every one of ips is valid and matched. A single
	// invalid IP, or a single valid IP that does not match, yields false.
	Matches(ips []net.IP) bool

	// FilterIPs partitions ips into those that match and those that do not.
	// Invalid IPs are dropped silently and appear in neither result.
	FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)

	// ToggleReverse flips the matcher's reverse flag.
	ToggleReverse()

	// SetReverse sets the matcher's reverse flag to the given value.
	SetReverse(reverse bool)
}
|
|
|
|
type GeoIPSet struct {
|
|
ipv4, ipv6 *netipx.IPSet
|
|
max4, max6 uint8
|
|
}
|
|
|
|
type HeuristicGeoIPMatcher struct {
|
|
ipset *GeoIPSet
|
|
reverse bool
|
|
}
|
|
|
|
// ipBucket groups input addresses that share one prefix key.
type ipBucket struct {
	// rep is the parsed form of the first address placed in the bucket; it is
	// the only address of the bucket that gets tested against a set.
	rep netip.Addr
	// ips lists every original address that fell into the bucket.
	ips []net.IP
}
|
|
|
|
// Match implements GeoIPMatcher.
|
|
func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return m.matchAddr(ipx)
|
|
}
|
|
|
|
func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool {
|
|
if ipx.Is4() {
|
|
return m.ipset.ipv4.Contains(ipx) != m.reverse
|
|
}
|
|
if ipx.Is6() {
|
|
return m.ipset.ipv6.Contains(ipx) != m.reverse
|
|
}
|
|
return false
|
|
}
|
|
|
|
// AnyMatch implements GeoIPMatcher.
|
|
func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool {
|
|
n := len(ips)
|
|
if n == 0 {
|
|
return false
|
|
}
|
|
|
|
if n == 1 {
|
|
return m.Match(ips[0])
|
|
}
|
|
|
|
heur4 := m.ipset.max4 <= 24
|
|
heur6 := m.ipset.max6 <= 64
|
|
if !heur4 && !heur6 {
|
|
for _, ip := range ips {
|
|
if ipx, ok := netipx.FromStdIP(ip); ok {
|
|
if m.matchAddr(ipx) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
buckets := make(map[[9]byte]struct{}, n)
|
|
for _, ip := range ips {
|
|
key, ok := prefixKeyFromIP(ip)
|
|
if !ok {
|
|
continue
|
|
}
|
|
heur := (key[0] == 4 && heur4) || (key[0] == 6 && heur6)
|
|
if heur {
|
|
if _, exists := buckets[key]; exists {
|
|
continue
|
|
}
|
|
}
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
return true
|
|
}
|
|
if heur {
|
|
buckets[key] = struct{}{}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Matches implements GeoIPMatcher.
|
|
func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool {
|
|
n := len(ips)
|
|
if n == 0 {
|
|
return false
|
|
}
|
|
|
|
if n == 1 {
|
|
return m.Match(ips[0])
|
|
}
|
|
|
|
heur4 := m.ipset.max4 <= 24
|
|
heur6 := m.ipset.max6 <= 64
|
|
if !heur4 && !heur6 {
|
|
for _, ip := range ips {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
if !m.matchAddr(ipx) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
buckets := make(map[[9]byte]netip.Addr, n)
|
|
precise := make([]netip.Addr, 0, n)
|
|
|
|
for _, ip := range ips {
|
|
key, ok := prefixKeyFromIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) {
|
|
if _, exists := buckets[key]; !exists {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
buckets[key] = ipx
|
|
}
|
|
} else {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
precise = append(precise, ipx)
|
|
}
|
|
}
|
|
|
|
for _, ipx := range buckets {
|
|
if !m.matchAddr(ipx) {
|
|
return false
|
|
}
|
|
}
|
|
for _, ipx := range precise {
|
|
if !m.matchAddr(ipx) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// prefixKeyFromIP derives a fixed-size bucket key from ip.
//
// key[0] is the address family (4 or 6). For IPv4 (including 16-byte forms
// that To4 accepts) key[1:4] holds the first three bytes, i.e. the /24 prefix,
// and the rest stays zero. For IPv6 key[1:9] holds the first eight bytes, i.e.
// the /64 prefix. ok is false, and key is all zero, when ip is neither a valid
// IPv4 nor a valid IPv6 address.
func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
	if v4 := ip.To4(); v4 != nil {
		key[0] = 4
		copy(key[1:4], v4[:3]) // /24
		return key, true
	}
	if v6 := ip.To16(); v6 != nil {
		key[0] = 6
		copy(key[1:], v6[:8]) // /64
		return key, true
	}
	return key, false // illegal
}
|
|
|
|
// FilterIPs implements GeoIPMatcher.
|
|
func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
|
n := len(ips)
|
|
if n == 0 {
|
|
return []net.IP{}, []net.IP{}
|
|
}
|
|
|
|
if n == 1 {
|
|
ipx, ok := netipx.FromStdIP(ips[0])
|
|
if !ok {
|
|
return []net.IP{}, []net.IP{}
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
return ips, []net.IP{}
|
|
}
|
|
return []net.IP{}, ips
|
|
}
|
|
|
|
heur4 := m.ipset.max4 <= 24
|
|
heur6 := m.ipset.max6 <= 64
|
|
if !heur4 && !heur6 {
|
|
matched = make([]net.IP, 0, n)
|
|
unmatched = make([]net.IP, 0, n)
|
|
for _, ip := range ips {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
matched = append(matched, ip)
|
|
} else {
|
|
unmatched = append(unmatched, ip)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
buckets := make(map[[9]byte]*ipBucket, n)
|
|
precise := make([]net.IP, 0, n)
|
|
|
|
for _, ip := range ips {
|
|
key, ok := prefixKeyFromIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
|
|
if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) {
|
|
precise = append(precise, ip)
|
|
continue
|
|
}
|
|
|
|
b, exists := buckets[key]
|
|
if !exists {
|
|
// build bucket
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
b = &ipBucket{
|
|
rep: ipx,
|
|
ips: make([]net.IP, 0, 4), // for dns answer
|
|
}
|
|
buckets[key] = b
|
|
}
|
|
b.ips = append(b.ips, ip)
|
|
}
|
|
|
|
matched = make([]net.IP, 0, n)
|
|
unmatched = make([]net.IP, 0, n)
|
|
for _, b := range buckets {
|
|
if m.matchAddr(b.rep) {
|
|
matched = append(matched, b.ips...)
|
|
} else {
|
|
unmatched = append(unmatched, b.ips...)
|
|
}
|
|
}
|
|
for _, ip := range precise {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
matched = append(matched, ip)
|
|
} else {
|
|
unmatched = append(unmatched, ip)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// ToggleReverse implements GeoIPMatcher.
|
|
func (m *HeuristicGeoIPMatcher) ToggleReverse() {
|
|
m.reverse = !m.reverse
|
|
}
|
|
|
|
// SetReverse implements GeoIPMatcher.
|
|
func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) {
|
|
m.reverse = reverse
|
|
}
|
|
|
|
type GeneralMultiGeoIPMatcher struct {
|
|
matchers []GeoIPMatcher
|
|
}
|
|
|
|
// Match implements GeoIPMatcher.
|
|
func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool {
|
|
for _, m := range mm.matchers {
|
|
if m.Match(ip) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// AnyMatch implements GeoIPMatcher.
|
|
func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
|
|
for _, m := range mm.matchers {
|
|
if m.AnyMatch(ips) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Matches implements GeoIPMatcher.
|
|
func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool {
|
|
for _, m := range mm.matchers {
|
|
if m.Matches(ips) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// FilterIPs implements GeoIPMatcher.
|
|
func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
|
matched = make([]net.IP, 0, len(ips))
|
|
unmatched = ips
|
|
for _, m := range mm.matchers {
|
|
if len(unmatched) == 0 {
|
|
break
|
|
}
|
|
var mtch []net.IP
|
|
mtch, unmatched = m.FilterIPs(unmatched)
|
|
if len(mtch) > 0 {
|
|
matched = append(matched, mtch...)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// ToggleReverse implements GeoIPMatcher.
|
|
func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() {
|
|
for _, m := range mm.matchers {
|
|
m.ToggleReverse()
|
|
}
|
|
}
|
|
|
|
// SetReverse implements GeoIPMatcher.
|
|
func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) {
|
|
for _, m := range mm.matchers {
|
|
m.SetReverse(reverse)
|
|
}
|
|
}
|
|
|
|
type HeuristicMultiGeoIPMatcher struct {
|
|
matchers []*HeuristicGeoIPMatcher
|
|
}
|
|
|
|
// Match implements GeoIPMatcher.
|
|
func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool {
|
|
ipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
for _, m := range mm.matchers {
|
|
if m.matchAddr(ipx) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// AnyMatch implements GeoIPMatcher.
|
|
func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
|
|
n := len(ips)
|
|
if n == 0 {
|
|
return false
|
|
}
|
|
|
|
if n == 1 {
|
|
return mm.Match(ips[0])
|
|
}
|
|
|
|
buckets := make(map[[9]byte]struct{}, n)
|
|
for _, ip := range ips {
|
|
var ipx netip.Addr
|
|
state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
|
|
for _, m := range mm.matchers {
|
|
heur4 := m.ipset.max4 <= 24
|
|
heur6 := m.ipset.max6 <= 64
|
|
|
|
if state == 0 && (heur4 || heur6) {
|
|
key, ok := prefixKeyFromIP(ip)
|
|
if !ok {
|
|
break
|
|
}
|
|
if _, exists := buckets[key]; exists {
|
|
state = key[0]
|
|
} else {
|
|
buckets[key] = struct{}{}
|
|
state = 1
|
|
}
|
|
}
|
|
if (heur4 && state == 4) || (heur6 && state == 6) {
|
|
continue
|
|
}
|
|
|
|
if !ipx.IsValid() {
|
|
nipx, ok := netipx.FromStdIP(ip)
|
|
if !ok {
|
|
break
|
|
}
|
|
ipx = nipx
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Matches implements GeoIPMatcher.
|
|
func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool {
|
|
n := len(ips)
|
|
if n == 0 {
|
|
return false
|
|
}
|
|
|
|
if n == 1 {
|
|
return mm.Match(ips[0])
|
|
}
|
|
|
|
var views ipViews
|
|
for _, m := range mm.matchers {
|
|
if !views.ensureForMatcher(m, ips) {
|
|
return false
|
|
}
|
|
|
|
matched := true
|
|
if m.ipset.max4 <= 24 {
|
|
for _, ipx := range views.buckets4 {
|
|
if !m.matchAddr(ipx) {
|
|
matched = false
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
for _, ipx := range views.precise4 {
|
|
if !m.matchAddr(ipx) {
|
|
matched = false
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if !matched {
|
|
continue
|
|
}
|
|
|
|
if m.ipset.max6 <= 64 {
|
|
for _, ipx := range views.buckets6 {
|
|
if !m.matchAddr(ipx) {
|
|
matched = false
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
for _, ipx := range views.precise6 {
|
|
if !m.matchAddr(ipx) {
|
|
matched = false
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if matched {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// ipViews caches parsed forms of one address list. Each view starts nil and is
// filled the first time a matcher needs it.
type ipViews struct {
	// buckets4 maps each IPv4 /24 prefix key to one representative address.
	buckets4 map[[9]byte]netip.Addr
	// buckets6 maps each IPv6 /64 prefix key to one representative address.
	buckets6 map[[9]byte]netip.Addr
	// precise4 lists every IPv4 address individually.
	precise4 []netip.Addr
	// precise6 lists every IPv6 address individually.
	precise6 []netip.Addr
}
|
|
|
|
func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool {
|
|
needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
|
|
needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
|
|
needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
|
|
needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
|
|
|
|
if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
|
|
return true
|
|
}
|
|
|
|
if needHeur4 {
|
|
v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
|
|
}
|
|
if needHeur6 {
|
|
v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
|
|
}
|
|
if needPrec4 {
|
|
v.precise4 = make([]netip.Addr, 0, len(ips))
|
|
}
|
|
if needPrec6 {
|
|
v.precise6 = make([]netip.Addr, 0, len(ips))
|
|
}
|
|
|
|
for _, ip := range ips {
|
|
key, ok := prefixKeyFromIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
switch key[0] {
|
|
case 4:
|
|
var ipx netip.Addr
|
|
if needHeur4 {
|
|
if _, exists := v.buckets4[key]; !exists {
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
v.buckets4[key] = ipx
|
|
}
|
|
}
|
|
if needPrec4 {
|
|
if !ipx.IsValid() {
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
}
|
|
v.precise4 = append(v.precise4, ipx)
|
|
}
|
|
case 6:
|
|
var ipx netip.Addr
|
|
if needHeur6 {
|
|
if _, exists := v.buckets6[key]; !exists {
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
v.buckets6[key] = ipx
|
|
}
|
|
}
|
|
if needPrec6 {
|
|
if !ipx.IsValid() {
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
return false
|
|
}
|
|
}
|
|
v.precise6 = append(v.precise6, ipx)
|
|
}
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// FilterIPs implements GeoIPMatcher.
|
|
func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
|
n := len(ips)
|
|
if n == 0 {
|
|
return []net.IP{}, []net.IP{}
|
|
}
|
|
|
|
if n == 1 {
|
|
ipx, ok := netipx.FromStdIP(ips[0])
|
|
if !ok {
|
|
return []net.IP{}, []net.IP{}
|
|
}
|
|
for _, m := range mm.matchers {
|
|
if m.matchAddr(ipx) {
|
|
return ips, []net.IP{}
|
|
}
|
|
}
|
|
return []net.IP{}, ips
|
|
}
|
|
|
|
var views ipBucketViews
|
|
|
|
matched = make([]net.IP, 0, n)
|
|
for _, m := range mm.matchers {
|
|
views.ensureForMatcher(m, ips)
|
|
|
|
if m.ipset.max4 <= 24 {
|
|
for key, b := range views.buckets4 {
|
|
if b == nil {
|
|
continue
|
|
}
|
|
if m.matchAddr(b.rep) {
|
|
views.buckets4[key] = nil
|
|
matched = append(matched, b.ips...)
|
|
}
|
|
}
|
|
} else {
|
|
for ipx, ip := range views.precise4 {
|
|
if ip == nil {
|
|
continue
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
views.precise4[ipx] = nil
|
|
matched = append(matched, ip)
|
|
}
|
|
}
|
|
}
|
|
|
|
if m.ipset.max6 <= 64 {
|
|
for key, b := range views.buckets6 {
|
|
if b == nil {
|
|
continue
|
|
}
|
|
if m.matchAddr(b.rep) {
|
|
views.buckets6[key] = nil
|
|
matched = append(matched, b.ips...)
|
|
}
|
|
}
|
|
} else {
|
|
for ipx, ip := range views.precise6 {
|
|
if ip == nil {
|
|
continue
|
|
}
|
|
if m.matchAddr(ipx) {
|
|
views.precise6[ipx] = nil
|
|
matched = append(matched, ip)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unmatched = make([]net.IP, 0, n-len(matched))
|
|
if views.buckets4 != nil {
|
|
for _, b := range views.buckets4 {
|
|
if b == nil {
|
|
continue
|
|
}
|
|
unmatched = append(unmatched, b.ips...)
|
|
}
|
|
}
|
|
if views.precise4 != nil {
|
|
for _, ip := range views.precise4 {
|
|
if ip == nil {
|
|
continue
|
|
}
|
|
unmatched = append(unmatched, ip)
|
|
}
|
|
}
|
|
if views.buckets6 != nil {
|
|
for _, b := range views.buckets6 {
|
|
if b == nil {
|
|
continue
|
|
}
|
|
unmatched = append(unmatched, b.ips...)
|
|
}
|
|
}
|
|
if views.precise6 != nil {
|
|
for _, ip := range views.precise6 {
|
|
if ip == nil {
|
|
continue
|
|
}
|
|
unmatched = append(unmatched, ip)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
type ipBucketViews struct {
|
|
buckets4, buckets6 map[[9]byte]*ipBucket
|
|
precise4, precise6 map[netip.Addr]net.IP
|
|
}
|
|
|
|
func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) {
|
|
needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
|
|
needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
|
|
needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
|
|
needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
|
|
|
|
if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
|
|
return
|
|
}
|
|
|
|
if needHeur4 {
|
|
v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
|
|
}
|
|
if needHeur6 {
|
|
v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
|
|
}
|
|
if needPrec4 {
|
|
v.precise4 = make(map[netip.Addr]net.IP, len(ips))
|
|
}
|
|
if needPrec6 {
|
|
v.precise6 = make(map[netip.Addr]net.IP, len(ips))
|
|
}
|
|
|
|
for _, ip := range ips {
|
|
key, ok := prefixKeyFromIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
|
|
switch key[0] {
|
|
case 4:
|
|
var ipx netip.Addr
|
|
if needHeur4 {
|
|
b, exists := v.buckets4[key]
|
|
if !exists {
|
|
// build bucket
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
b = &ipBucket{
|
|
rep: ipx,
|
|
ips: make([]net.IP, 0, 4), // for dns answer
|
|
}
|
|
v.buckets4[key] = b
|
|
}
|
|
b.ips = append(b.ips, ip)
|
|
}
|
|
if needPrec4 {
|
|
if !ipx.IsValid() {
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
}
|
|
v.precise4[ipx] = ip
|
|
}
|
|
case 6:
|
|
var ipx netip.Addr
|
|
if needHeur6 {
|
|
b, exists := v.buckets6[key]
|
|
if !exists {
|
|
// build bucket
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
b = &ipBucket{
|
|
rep: ipx,
|
|
ips: make([]net.IP, 0, 4), // for dns answer
|
|
}
|
|
v.buckets6[key] = b
|
|
}
|
|
b.ips = append(b.ips, ip)
|
|
}
|
|
if needPrec6 {
|
|
if !ipx.IsValid() {
|
|
ipx, ok = netipx.FromStdIP(ip)
|
|
if !ok {
|
|
continue // illegal ip, ignore
|
|
}
|
|
}
|
|
v.precise6[ipx] = ip
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ToggleReverse implements GeoIPMatcher.
|
|
func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() {
|
|
for _, m := range mm.matchers {
|
|
m.ToggleReverse()
|
|
}
|
|
}
|
|
|
|
// SetReverse implements GeoIPMatcher.
|
|
func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) {
|
|
for _, m := range mm.matchers {
|
|
m.SetReverse(reverse)
|
|
}
|
|
}
|
|
|
|
type GeoIPSetFactory struct {
|
|
sync.Mutex
|
|
shared map[string]*GeoIPSet // TODO: cleanup
|
|
}
|
|
|
|
// ipsetFactory is the package-wide factory and cache of GeoIPSet values.
var ipsetFactory = GeoIPSetFactory{shared: map[string]*GeoIPSet{}}
|
|
|
|
func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
|
|
if ipset := f.shared[key]; ipset != nil {
|
|
return ipset, nil
|
|
}
|
|
|
|
ipset, err := f.Create(cidrGroups...)
|
|
if err == nil {
|
|
f.shared[key] = ipset
|
|
}
|
|
return ipset, err
|
|
}
|
|
|
|
func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) {
|
|
var ipv4Builder, ipv6Builder netipx.IPSetBuilder
|
|
|
|
for _, cidrGroup := range cidrGroups {
|
|
for _, cidrEntry := range cidrGroup {
|
|
ipBytes := cidrEntry.GetIp()
|
|
prefixLen := int(cidrEntry.GetPrefix())
|
|
|
|
addr, ok := netip.AddrFromSlice(ipBytes)
|
|
if !ok {
|
|
errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes)
|
|
continue
|
|
}
|
|
|
|
prefix := netip.PrefixFrom(addr, prefixLen)
|
|
if !prefix.IsValid() {
|
|
errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen)
|
|
continue
|
|
}
|
|
|
|
if addr.Is4() {
|
|
ipv4Builder.AddPrefix(prefix)
|
|
} else if addr.Is6() {
|
|
ipv6Builder.AddPrefix(prefix)
|
|
}
|
|
}
|
|
}
|
|
|
|
ipv4, err := ipv4Builder.IPSet()
|
|
if err != nil {
|
|
return nil, errors.New("failed to build IPv4 set").Base(err)
|
|
}
|
|
ipv6, err := ipv6Builder.IPSet()
|
|
if err != nil {
|
|
return nil, errors.New("failed to build IPv6 set").Base(err)
|
|
}
|
|
|
|
var max4, max6 int
|
|
|
|
for _, p := range ipv4.Prefixes() {
|
|
if b := p.Bits(); b > max4 {
|
|
max4 = b
|
|
}
|
|
}
|
|
for _, p := range ipv6.Prefixes() {
|
|
if b := p.Bits(); b > max6 {
|
|
max6 = b
|
|
}
|
|
}
|
|
|
|
if max4 == 0 {
|
|
max4 = 0xff
|
|
}
|
|
if max6 == 0 {
|
|
max6 = 0xff
|
|
}
|
|
|
|
return &GeoIPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil
|
|
}
|
|
|
|
func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) {
|
|
n := len(geoips)
|
|
if n == 0 {
|
|
return nil, errors.New("no geoip configs provided")
|
|
}
|
|
|
|
var subs []*HeuristicGeoIPMatcher
|
|
pos := make([]*GeoIP, 0, n)
|
|
neg := make([]*GeoIP, 0, n/2)
|
|
|
|
for _, geoip := range geoips {
|
|
if geoip == nil {
|
|
return nil, errors.New("geoip entry is nil")
|
|
}
|
|
if geoip.CountryCode == "" {
|
|
ipset, err := ipsetFactory.Create(geoip.Cidr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: geoip.ReverseMatch})
|
|
continue
|
|
}
|
|
if !geoip.ReverseMatch {
|
|
pos = append(pos, geoip)
|
|
} else {
|
|
neg = append(neg, geoip)
|
|
}
|
|
}
|
|
|
|
buildIPSet := func(mergeables []*GeoIP) (*GeoIPSet, error) {
|
|
n := len(mergeables)
|
|
if n == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
sort.Slice(mergeables, func(i, j int) bool {
|
|
gi, gj := mergeables[i], mergeables[j]
|
|
return gi.CountryCode < gj.CountryCode
|
|
})
|
|
|
|
var sb strings.Builder
|
|
sb.Grow(n * 3) // xx,
|
|
cidrGroups := make([][]*CIDR, 0, n)
|
|
var last *GeoIP
|
|
for i, geoip := range mergeables {
|
|
if i == 0 || (geoip.CountryCode != last.CountryCode) {
|
|
last = geoip
|
|
sb.WriteString(geoip.CountryCode)
|
|
sb.WriteString(",")
|
|
cidrGroups = append(cidrGroups, geoip.Cidr)
|
|
}
|
|
}
|
|
|
|
return ipsetFactory.GetOrCreate(sb.String(), cidrGroups)
|
|
}
|
|
|
|
ipset, err := buildIPSet(pos)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if ipset != nil {
|
|
subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: false})
|
|
}
|
|
|
|
ipset, err = buildIPSet(neg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if ipset != nil {
|
|
subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: true})
|
|
}
|
|
|
|
switch len(subs) {
|
|
case 0:
|
|
return nil, errors.New("no valid geoip matcher")
|
|
case 1:
|
|
return subs[0], nil
|
|
default:
|
|
return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil
|
|
}
|
|
}
|