mirror of https://github.com/v2ray/v2ray-core
migrate to the new geoip matcher
parent
5400153827
commit
41956e92a5
|
@ -120,22 +120,6 @@ func (m *DomainMatcher) Apply(ctx context.Context) bool {
|
||||||
return m.ApplyDomain(dest.Address.Domain())
|
return m.ApplyDomain(dest.Address.Domain())
|
||||||
}
|
}
|
||||||
|
|
||||||
type CIDRMatcher struct {
|
|
||||||
cidr *net.IPNet
|
|
||||||
onSource bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error) {
|
|
||||||
cidr := &net.IPNet{
|
|
||||||
IP: net.IP(ip),
|
|
||||||
Mask: net.CIDRMask(int(mask), len(ip)*8),
|
|
||||||
}
|
|
||||||
return &CIDRMatcher{
|
|
||||||
cidr: cidr,
|
|
||||||
onSource: onSource,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func sourceFromContext(ctx context.Context) net.Destination {
|
func sourceFromContext(ctx context.Context) net.Destination {
|
||||||
inbound := session.InboundFromContext(ctx)
|
inbound := session.InboundFromContext(ctx)
|
||||||
if inbound == nil {
|
if inbound == nil {
|
||||||
|
@ -152,80 +136,6 @@ func targetFromContent(ctx context.Context) net.Destination {
|
||||||
return outbound.Target
|
return outbound.Target
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *CIDRMatcher) Apply(ctx context.Context) bool {
|
|
||||||
ips := make([]net.IP, 0, 4)
|
|
||||||
if resolver, ok := ResolvedIPsFromContext(ctx); ok {
|
|
||||||
resolvedIPs := resolver.Resolve()
|
|
||||||
for _, rip := range resolvedIPs {
|
|
||||||
if !rip.Family().IsIPv6() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ips = append(ips, rip.IP())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var dest net.Destination
|
|
||||||
if v.onSource {
|
|
||||||
dest = sourceFromContext(ctx)
|
|
||||||
} else {
|
|
||||||
dest = targetFromContent(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dest.IsValid() && dest.Address.Family().IsIPv6() {
|
|
||||||
ips = append(ips, dest.Address.IP())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ip := range ips {
|
|
||||||
if v.cidr.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type IPv4Matcher struct {
|
|
||||||
ipv4net *net.IPNetTable
|
|
||||||
onSource bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
|
|
||||||
return &IPv4Matcher{
|
|
||||||
ipv4net: ipnet,
|
|
||||||
onSource: onSource,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *IPv4Matcher) Apply(ctx context.Context) bool {
|
|
||||||
ips := make([]net.IP, 0, 4)
|
|
||||||
if resolver, ok := ResolvedIPsFromContext(ctx); ok {
|
|
||||||
resolvedIPs := resolver.Resolve()
|
|
||||||
for _, rip := range resolvedIPs {
|
|
||||||
if !rip.Family().IsIPv4() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ips = append(ips, rip.IP())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var dest net.Destination
|
|
||||||
if v.onSource {
|
|
||||||
dest = sourceFromContext(ctx)
|
|
||||||
} else {
|
|
||||||
dest = targetFromContent(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dest.IsValid() && dest.Address.Family().IsIPv4() {
|
|
||||||
ips = append(ips, dest.Address.IP())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ip := range ips {
|
|
||||||
if v.ipv4net.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type MultiGeoIPMatcher struct {
|
type MultiGeoIPMatcher struct {
|
||||||
matchers []*GeoIPMatcher
|
matchers []*GeoIPMatcher
|
||||||
onSource bool
|
onSource bool
|
||||||
|
|
|
@ -2,8 +2,6 @@ package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"v2ray.com/core/common/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CIDRList is an alias of []*CIDR to provide sort.Interface.
|
// CIDRList is an alias of []*CIDR to provide sort.Interface.
|
||||||
|
@ -54,40 +52,6 @@ func (r *Rule) Apply(ctx context.Context) bool {
|
||||||
return r.Condition.Apply(ctx)
|
return r.Condition.Apply(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cidrToCondition(cidr []*CIDR, source bool) (Condition, error) {
|
|
||||||
ipv4Net := net.NewIPNetTable()
|
|
||||||
ipv6Cond := NewAnyCondition()
|
|
||||||
hasIpv6 := false
|
|
||||||
|
|
||||||
for _, ip := range cidr {
|
|
||||||
switch len(ip.Ip) {
|
|
||||||
case net.IPv4len:
|
|
||||||
ipv4Net.AddIP(ip.Ip, byte(ip.Prefix))
|
|
||||||
case net.IPv6len:
|
|
||||||
hasIpv6 = true
|
|
||||||
matcher, err := NewCIDRMatcher(ip.Ip, ip.Prefix, source)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ipv6Cond.Add(matcher)
|
|
||||||
default:
|
|
||||||
return nil, newError("invalid IP length").AtWarning()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case !ipv4Net.IsEmpty() && hasIpv6:
|
|
||||||
cond := NewAnyCondition()
|
|
||||||
cond.Add(NewIPv4Matcher(ipv4Net, source))
|
|
||||||
cond.Add(ipv6Cond)
|
|
||||||
return cond, nil
|
|
||||||
case !ipv4Net.IsEmpty():
|
|
||||||
return NewIPv4Matcher(ipv4Net, source), nil
|
|
||||||
default:
|
|
||||||
return ipv6Cond, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
||||||
conds := NewConditionChan()
|
conds := NewConditionChan()
|
||||||
|
|
||||||
|
@ -122,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
||||||
}
|
}
|
||||||
conds.Add(cond)
|
conds.Add(cond)
|
||||||
} else if len(rr.Cidr) > 0 {
|
} else if len(rr.Cidr) > 0 {
|
||||||
cond, err := cidrToCondition(rr.Cidr, false)
|
cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.Cidr}}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -136,7 +100,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
||||||
}
|
}
|
||||||
conds.Add(cond)
|
conds.Add(cond)
|
||||||
} else if len(rr.SourceCidr) > 0 {
|
} else if len(rr.SourceCidr) > 0 {
|
||||||
cond, err := cidrToCondition(rr.SourceCidr, true)
|
cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.SourceCidr}}, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,83 +0,0 @@
|
||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/bits"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IPNetTable struct {
|
|
||||||
cache map[uint32]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewIPNetTable() *IPNetTable {
|
|
||||||
return &IPNetTable{
|
|
||||||
cache: make(map[uint32]byte, 1024),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ipToUint32(ip IP) uint32 {
|
|
||||||
value := uint32(0)
|
|
||||||
for _, b := range []byte(ip) {
|
|
||||||
value <<= 8
|
|
||||||
value += uint32(b)
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
func ipMaskToByte(mask net.IPMask) byte {
|
|
||||||
value := byte(0)
|
|
||||||
for _, b := range []byte(mask) {
|
|
||||||
value += byte(bits.OnesCount8(b))
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *IPNetTable) Add(ipNet *net.IPNet) {
|
|
||||||
ipv4 := ipNet.IP.To4()
|
|
||||||
if ipv4 == nil {
|
|
||||||
// For now, we don't support IPv6
|
|
||||||
return
|
|
||||||
}
|
|
||||||
mask := ipMaskToByte(ipNet.Mask)
|
|
||||||
n.AddIP(ipv4, mask)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *IPNetTable) AddIP(ip []byte, mask byte) {
|
|
||||||
k := ipToUint32(ip)
|
|
||||||
k = (k >> (32 - mask)) << (32 - mask) // normalize ip
|
|
||||||
existing, found := n.cache[k]
|
|
||||||
if !found || existing > mask {
|
|
||||||
n.cache[k] = mask
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *IPNetTable) Contains(ip net.IP) bool {
|
|
||||||
ipv4 := ip.To4()
|
|
||||||
if ipv4 == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
originalValue := ipToUint32(ipv4)
|
|
||||||
|
|
||||||
if entry, found := n.cache[originalValue]; found {
|
|
||||||
if entry == 32 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mask := uint32(0)
|
|
||||||
for maskbit := byte(1); maskbit <= 32; maskbit++ {
|
|
||||||
mask += 1 << uint32(32-maskbit)
|
|
||||||
|
|
||||||
maskedValue := originalValue & mask
|
|
||||||
if entry, found := n.cache[maskedValue]; found {
|
|
||||||
if entry == maskbit {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *IPNetTable) IsEmpty() bool {
|
|
||||||
return len(n.cache) == 0
|
|
||||||
}
|
|
|
@ -1,133 +0,0 @@
|
||||||
package net_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
proto "github.com/golang/protobuf/proto"
|
|
||||||
"v2ray.com/core/app/router"
|
|
||||||
"v2ray.com/core/common/platform"
|
|
||||||
|
|
||||||
"v2ray.com/ext/sysio"
|
|
||||||
|
|
||||||
"v2ray.com/core/common"
|
|
||||||
. "v2ray.com/core/common/net"
|
|
||||||
. "v2ray.com/ext/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func parseCIDR(str string) *net.IPNet {
|
|
||||||
_, ipNet, err := net.ParseCIDR(str)
|
|
||||||
common.Must(err)
|
|
||||||
return ipNet
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIPNet(t *testing.T) {
|
|
||||||
assert := With(t)
|
|
||||||
|
|
||||||
ipNet := NewIPNetTable()
|
|
||||||
ipNet.Add(parseCIDR(("0.0.0.0/8")))
|
|
||||||
ipNet.Add(parseCIDR(("10.0.0.0/8")))
|
|
||||||
ipNet.Add(parseCIDR(("100.64.0.0/10")))
|
|
||||||
ipNet.Add(parseCIDR(("127.0.0.0/8")))
|
|
||||||
ipNet.Add(parseCIDR(("169.254.0.0/16")))
|
|
||||||
ipNet.Add(parseCIDR(("172.16.0.0/12")))
|
|
||||||
ipNet.Add(parseCIDR(("192.0.0.0/24")))
|
|
||||||
ipNet.Add(parseCIDR(("192.0.2.0/24")))
|
|
||||||
ipNet.Add(parseCIDR(("192.168.0.0/16")))
|
|
||||||
ipNet.Add(parseCIDR(("198.18.0.0/15")))
|
|
||||||
ipNet.Add(parseCIDR(("198.51.100.0/24")))
|
|
||||||
ipNet.Add(parseCIDR(("203.0.113.0/24")))
|
|
||||||
ipNet.Add(parseCIDR(("8.8.8.8/32")))
|
|
||||||
ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
|
|
||||||
assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
|
|
||||||
assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
|
|
||||||
assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
|
|
||||||
assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
|
|
||||||
assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
|
|
||||||
assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
|
|
||||||
assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
|
|
||||||
assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
|
|
||||||
assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGeoIPCN(t *testing.T) {
|
|
||||||
assert := With(t)
|
|
||||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
|
|
||||||
|
|
||||||
ips, err := loadGeoIP("CN")
|
|
||||||
common.Must(err)
|
|
||||||
|
|
||||||
ipNet := NewIPNetTable()
|
|
||||||
for _, ip := range ips {
|
|
||||||
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadGeoIP(country string) ([]*router.CIDR, error) {
|
|
||||||
geoipBytes, err := sysio.ReadAsset("geoip.dat")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var geoipList router.GeoIPList
|
|
||||||
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, geoip := range geoipList.Entry {
|
|
||||||
if geoip.CountryCode == country {
|
|
||||||
return geoip.Cidr, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
panic("country not found: " + country)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIPNetQuery(b *testing.B) {
|
|
||||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
|
|
||||||
|
|
||||||
ips, err := loadGeoIP("CN")
|
|
||||||
common.Must(err)
|
|
||||||
|
|
||||||
ipNet := NewIPNetTable()
|
|
||||||
for _, ip := range ips {
|
|
||||||
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
ipNet.Contains([]byte{8, 8, 8, 8})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkCIDRQuery(b *testing.B) {
|
|
||||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
|
|
||||||
|
|
||||||
ips, err := loadGeoIP("CN")
|
|
||||||
common.Must(err)
|
|
||||||
|
|
||||||
ipNet := make([]*net.IPNet, 0, 1024)
|
|
||||||
for _, ip := range ips {
|
|
||||||
if len(ip.Ip) != 4 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ipNet = append(ipNet, &net.IPNet{
|
|
||||||
IP: net.IP(ip.Ip),
|
|
||||||
Mask: net.CIDRMask(int(ip.Prefix), 32),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
for _, n := range ipNet {
|
|
||||||
if n.Contains([]byte{8, 8, 8, 8}) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue