mirror of https://github.com/v2ray/v2ray-core
migrate to the new geoip matcher
parent
5400153827
commit
41956e92a5
|
@ -120,22 +120,6 @@ func (m *DomainMatcher) Apply(ctx context.Context) bool {
|
|||
return m.ApplyDomain(dest.Address.Domain())
|
||||
}
|
||||
|
||||
type CIDRMatcher struct {
|
||||
cidr *net.IPNet
|
||||
onSource bool
|
||||
}
|
||||
|
||||
func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error) {
|
||||
cidr := &net.IPNet{
|
||||
IP: net.IP(ip),
|
||||
Mask: net.CIDRMask(int(mask), len(ip)*8),
|
||||
}
|
||||
return &CIDRMatcher{
|
||||
cidr: cidr,
|
||||
onSource: onSource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sourceFromContext(ctx context.Context) net.Destination {
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
if inbound == nil {
|
||||
|
@ -152,80 +136,6 @@ func targetFromContent(ctx context.Context) net.Destination {
|
|||
return outbound.Target
|
||||
}
|
||||
|
||||
func (v *CIDRMatcher) Apply(ctx context.Context) bool {
|
||||
ips := make([]net.IP, 0, 4)
|
||||
if resolver, ok := ResolvedIPsFromContext(ctx); ok {
|
||||
resolvedIPs := resolver.Resolve()
|
||||
for _, rip := range resolvedIPs {
|
||||
if !rip.Family().IsIPv6() {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, rip.IP())
|
||||
}
|
||||
}
|
||||
|
||||
var dest net.Destination
|
||||
if v.onSource {
|
||||
dest = sourceFromContext(ctx)
|
||||
} else {
|
||||
dest = targetFromContent(ctx)
|
||||
}
|
||||
|
||||
if dest.IsValid() && dest.Address.Family().IsIPv6() {
|
||||
ips = append(ips, dest.Address.IP())
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if v.cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type IPv4Matcher struct {
|
||||
ipv4net *net.IPNetTable
|
||||
onSource bool
|
||||
}
|
||||
|
||||
func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
|
||||
return &IPv4Matcher{
|
||||
ipv4net: ipnet,
|
||||
onSource: onSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *IPv4Matcher) Apply(ctx context.Context) bool {
|
||||
ips := make([]net.IP, 0, 4)
|
||||
if resolver, ok := ResolvedIPsFromContext(ctx); ok {
|
||||
resolvedIPs := resolver.Resolve()
|
||||
for _, rip := range resolvedIPs {
|
||||
if !rip.Family().IsIPv4() {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, rip.IP())
|
||||
}
|
||||
}
|
||||
|
||||
var dest net.Destination
|
||||
if v.onSource {
|
||||
dest = sourceFromContext(ctx)
|
||||
} else {
|
||||
dest = targetFromContent(ctx)
|
||||
}
|
||||
|
||||
if dest.IsValid() && dest.Address.Family().IsIPv4() {
|
||||
ips = append(ips, dest.Address.IP())
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if v.ipv4net.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type MultiGeoIPMatcher struct {
|
||||
matchers []*GeoIPMatcher
|
||||
onSource bool
|
||||
|
|
|
@ -2,8 +2,6 @@ package router
|
|||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"v2ray.com/core/common/net"
|
||||
)
|
||||
|
||||
// CIDRList is an alias of []*CIDR to provide sort.Interface.
|
||||
|
@ -54,40 +52,6 @@ func (r *Rule) Apply(ctx context.Context) bool {
|
|||
return r.Condition.Apply(ctx)
|
||||
}
|
||||
|
||||
func cidrToCondition(cidr []*CIDR, source bool) (Condition, error) {
|
||||
ipv4Net := net.NewIPNetTable()
|
||||
ipv6Cond := NewAnyCondition()
|
||||
hasIpv6 := false
|
||||
|
||||
for _, ip := range cidr {
|
||||
switch len(ip.Ip) {
|
||||
case net.IPv4len:
|
||||
ipv4Net.AddIP(ip.Ip, byte(ip.Prefix))
|
||||
case net.IPv6len:
|
||||
hasIpv6 = true
|
||||
matcher, err := NewCIDRMatcher(ip.Ip, ip.Prefix, source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipv6Cond.Add(matcher)
|
||||
default:
|
||||
return nil, newError("invalid IP length").AtWarning()
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case !ipv4Net.IsEmpty() && hasIpv6:
|
||||
cond := NewAnyCondition()
|
||||
cond.Add(NewIPv4Matcher(ipv4Net, source))
|
||||
cond.Add(ipv6Cond)
|
||||
return cond, nil
|
||||
case !ipv4Net.IsEmpty():
|
||||
return NewIPv4Matcher(ipv4Net, source), nil
|
||||
default:
|
||||
return ipv6Cond, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
||||
conds := NewConditionChan()
|
||||
|
||||
|
@ -122,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
|||
}
|
||||
conds.Add(cond)
|
||||
} else if len(rr.Cidr) > 0 {
|
||||
cond, err := cidrToCondition(rr.Cidr, false)
|
||||
cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.Cidr}}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -136,7 +100,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
|||
}
|
||||
conds.Add(cond)
|
||||
} else if len(rr.SourceCidr) > 0 {
|
||||
cond, err := cidrToCondition(rr.SourceCidr, true)
|
||||
cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.SourceCidr}}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,83 +0,0 @@
|
|||
package net
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"net"
|
||||
)
|
||||
|
||||
type IPNetTable struct {
|
||||
cache map[uint32]byte
|
||||
}
|
||||
|
||||
func NewIPNetTable() *IPNetTable {
|
||||
return &IPNetTable{
|
||||
cache: make(map[uint32]byte, 1024),
|
||||
}
|
||||
}
|
||||
|
||||
func ipToUint32(ip IP) uint32 {
|
||||
value := uint32(0)
|
||||
for _, b := range []byte(ip) {
|
||||
value <<= 8
|
||||
value += uint32(b)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func ipMaskToByte(mask net.IPMask) byte {
|
||||
value := byte(0)
|
||||
for _, b := range []byte(mask) {
|
||||
value += byte(bits.OnesCount8(b))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (n *IPNetTable) Add(ipNet *net.IPNet) {
|
||||
ipv4 := ipNet.IP.To4()
|
||||
if ipv4 == nil {
|
||||
// For now, we don't support IPv6
|
||||
return
|
||||
}
|
||||
mask := ipMaskToByte(ipNet.Mask)
|
||||
n.AddIP(ipv4, mask)
|
||||
}
|
||||
|
||||
func (n *IPNetTable) AddIP(ip []byte, mask byte) {
|
||||
k := ipToUint32(ip)
|
||||
k = (k >> (32 - mask)) << (32 - mask) // normalize ip
|
||||
existing, found := n.cache[k]
|
||||
if !found || existing > mask {
|
||||
n.cache[k] = mask
|
||||
}
|
||||
}
|
||||
|
||||
func (n *IPNetTable) Contains(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
originalValue := ipToUint32(ipv4)
|
||||
|
||||
if entry, found := n.cache[originalValue]; found {
|
||||
if entry == 32 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
mask := uint32(0)
|
||||
for maskbit := byte(1); maskbit <= 32; maskbit++ {
|
||||
mask += 1 << uint32(32-maskbit)
|
||||
|
||||
maskedValue := originalValue & mask
|
||||
if entry, found := n.cache[maskedValue]; found {
|
||||
if entry == maskbit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *IPNetTable) IsEmpty() bool {
|
||||
return len(n.cache) == 0
|
||||
}
|
|
@ -1,133 +0,0 @@
|
|||
package net_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
"v2ray.com/core/app/router"
|
||||
"v2ray.com/core/common/platform"
|
||||
|
||||
"v2ray.com/ext/sysio"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
. "v2ray.com/core/common/net"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func parseCIDR(str string) *net.IPNet {
|
||||
_, ipNet, err := net.ParseCIDR(str)
|
||||
common.Must(err)
|
||||
return ipNet
|
||||
}
|
||||
|
||||
func TestIPNet(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
ipNet := NewIPNetTable()
|
||||
ipNet.Add(parseCIDR(("0.0.0.0/8")))
|
||||
ipNet.Add(parseCIDR(("10.0.0.0/8")))
|
||||
ipNet.Add(parseCIDR(("100.64.0.0/10")))
|
||||
ipNet.Add(parseCIDR(("127.0.0.0/8")))
|
||||
ipNet.Add(parseCIDR(("169.254.0.0/16")))
|
||||
ipNet.Add(parseCIDR(("172.16.0.0/12")))
|
||||
ipNet.Add(parseCIDR(("192.0.0.0/24")))
|
||||
ipNet.Add(parseCIDR(("192.0.2.0/24")))
|
||||
ipNet.Add(parseCIDR(("192.168.0.0/16")))
|
||||
ipNet.Add(parseCIDR(("198.18.0.0/15")))
|
||||
ipNet.Add(parseCIDR(("198.51.100.0/24")))
|
||||
ipNet.Add(parseCIDR(("203.0.113.0/24")))
|
||||
ipNet.Add(parseCIDR(("8.8.8.8/32")))
|
||||
ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
|
||||
assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
|
||||
}
|
||||
|
||||
func TestGeoIPCN(t *testing.T) {
|
||||
assert := With(t)
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
|
||||
|
||||
ips, err := loadGeoIP("CN")
|
||||
common.Must(err)
|
||||
|
||||
ipNet := NewIPNetTable()
|
||||
for _, ip := range ips {
|
||||
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
|
||||
}
|
||||
|
||||
assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
|
||||
}
|
||||
|
||||
func loadGeoIP(country string) ([]*router.CIDR, error) {
|
||||
geoipBytes, err := sysio.ReadAsset("geoip.dat")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var geoipList router.GeoIPList
|
||||
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, geoip := range geoipList.Entry {
|
||||
if geoip.CountryCode == country {
|
||||
return geoip.Cidr, nil
|
||||
}
|
||||
}
|
||||
|
||||
panic("country not found: " + country)
|
||||
}
|
||||
|
||||
func BenchmarkIPNetQuery(b *testing.B) {
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
|
||||
|
||||
ips, err := loadGeoIP("CN")
|
||||
common.Must(err)
|
||||
|
||||
ipNet := NewIPNetTable()
|
||||
for _, ip := range ips {
|
||||
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ipNet.Contains([]byte{8, 8, 8, 8})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCIDRQuery(b *testing.B) {
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
|
||||
|
||||
ips, err := loadGeoIP("CN")
|
||||
common.Must(err)
|
||||
|
||||
ipNet := make([]*net.IPNet, 0, 1024)
|
||||
for _, ip := range ips {
|
||||
if len(ip.Ip) != 4 {
|
||||
continue
|
||||
}
|
||||
ipNet = append(ipNet, &net.IPNet{
|
||||
IP: net.IP(ip.Ip),
|
||||
Mask: net.CIDRMask(int(ip.Prefix), 32),
|
||||
})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, n := range ipNet {
|
||||
if n.Contains([]byte{8, 8, 8, 8}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue