Browse Source

migrate to the new geoip matcher

pull/1350/head
Darien Raymond 6 years ago
parent
commit
41956e92a5
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
  1. 90
      app/router/condition.go
  2. 40
      app/router/config.go
  3. 83
      common/net/ipnet.go
  4. 133
      common/net/ipnet_test.go

90
app/router/condition.go

@ -120,22 +120,6 @@ func (m *DomainMatcher) Apply(ctx context.Context) bool {
return m.ApplyDomain(dest.Address.Domain())
}
type CIDRMatcher struct {
cidr *net.IPNet
onSource bool
}
func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error) {
cidr := &net.IPNet{
IP: net.IP(ip),
Mask: net.CIDRMask(int(mask), len(ip)*8),
}
return &CIDRMatcher{
cidr: cidr,
onSource: onSource,
}, nil
}
func sourceFromContext(ctx context.Context) net.Destination {
inbound := session.InboundFromContext(ctx)
if inbound == nil {
@ -152,80 +136,6 @@ func targetFromContent(ctx context.Context) net.Destination {
return outbound.Target
}
func (v *CIDRMatcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4)
if resolver, ok := ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
if !rip.Family().IsIPv6() {
continue
}
ips = append(ips, rip.IP())
}
}
var dest net.Destination
if v.onSource {
dest = sourceFromContext(ctx)
} else {
dest = targetFromContent(ctx)
}
if dest.IsValid() && dest.Address.Family().IsIPv6() {
ips = append(ips, dest.Address.IP())
}
for _, ip := range ips {
if v.cidr.Contains(ip) {
return true
}
}
return false
}
type IPv4Matcher struct {
ipv4net *net.IPNetTable
onSource bool
}
func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
return &IPv4Matcher{
ipv4net: ipnet,
onSource: onSource,
}
}
func (v *IPv4Matcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4)
if resolver, ok := ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
if !rip.Family().IsIPv4() {
continue
}
ips = append(ips, rip.IP())
}
}
var dest net.Destination
if v.onSource {
dest = sourceFromContext(ctx)
} else {
dest = targetFromContent(ctx)
}
if dest.IsValid() && dest.Address.Family().IsIPv4() {
ips = append(ips, dest.Address.IP())
}
for _, ip := range ips {
if v.ipv4net.Contains(ip) {
return true
}
}
return false
}
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
onSource bool

40
app/router/config.go

@ -2,8 +2,6 @@ package router
import (
"context"
"v2ray.com/core/common/net"
)
// CIDRList is an alias of []*CIDR to provide sort.Interface.
@ -54,40 +52,6 @@ func (r *Rule) Apply(ctx context.Context) bool {
return r.Condition.Apply(ctx)
}
func cidrToCondition(cidr []*CIDR, source bool) (Condition, error) {
ipv4Net := net.NewIPNetTable()
ipv6Cond := NewAnyCondition()
hasIpv6 := false
for _, ip := range cidr {
switch len(ip.Ip) {
case net.IPv4len:
ipv4Net.AddIP(ip.Ip, byte(ip.Prefix))
case net.IPv6len:
hasIpv6 = true
matcher, err := NewCIDRMatcher(ip.Ip, ip.Prefix, source)
if err != nil {
return nil, err
}
ipv6Cond.Add(matcher)
default:
return nil, newError("invalid IP length").AtWarning()
}
}
switch {
case !ipv4Net.IsEmpty() && hasIpv6:
cond := NewAnyCondition()
cond.Add(NewIPv4Matcher(ipv4Net, source))
cond.Add(ipv6Cond)
return cond, nil
case !ipv4Net.IsEmpty():
return NewIPv4Matcher(ipv4Net, source), nil
default:
return ipv6Cond, nil
}
}
func (rr *RoutingRule) BuildCondition() (Condition, error) {
conds := NewConditionChan()
@ -122,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
conds.Add(cond)
} else if len(rr.Cidr) > 0 {
cond, err := cidrToCondition(rr.Cidr, false)
cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.Cidr}}, false)
if err != nil {
return nil, err
}
@ -136,7 +100,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
conds.Add(cond)
} else if len(rr.SourceCidr) > 0 {
cond, err := cidrToCondition(rr.SourceCidr, true)
cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.SourceCidr}}, true)
if err != nil {
return nil, err
}

83
common/net/ipnet.go

@ -1,83 +0,0 @@
package net
import (
"math/bits"
"net"
)
type IPNetTable struct {
cache map[uint32]byte
}
func NewIPNetTable() *IPNetTable {
return &IPNetTable{
cache: make(map[uint32]byte, 1024),
}
}
func ipToUint32(ip IP) uint32 {
value := uint32(0)
for _, b := range []byte(ip) {
value <<= 8
value += uint32(b)
}
return value
}
func ipMaskToByte(mask net.IPMask) byte {
value := byte(0)
for _, b := range []byte(mask) {
value += byte(bits.OnesCount8(b))
}
return value
}
func (n *IPNetTable) Add(ipNet *net.IPNet) {
ipv4 := ipNet.IP.To4()
if ipv4 == nil {
// For now, we don't support IPv6
return
}
mask := ipMaskToByte(ipNet.Mask)
n.AddIP(ipv4, mask)
}
func (n *IPNetTable) AddIP(ip []byte, mask byte) {
k := ipToUint32(ip)
k = (k >> (32 - mask)) << (32 - mask) // normalize ip
existing, found := n.cache[k]
if !found || existing > mask {
n.cache[k] = mask
}
}
func (n *IPNetTable) Contains(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
originalValue := ipToUint32(ipv4)
if entry, found := n.cache[originalValue]; found {
if entry == 32 {
return true
}
}
mask := uint32(0)
for maskbit := byte(1); maskbit <= 32; maskbit++ {
mask += 1 << uint32(32-maskbit)
maskedValue := originalValue & mask
if entry, found := n.cache[maskedValue]; found {
if entry == maskbit {
return true
}
}
}
return false
}
func (n *IPNetTable) IsEmpty() bool {
return len(n.cache) == 0
}

133
common/net/ipnet_test.go

@ -1,133 +0,0 @@
package net_test
import (
"net"
"os"
"path/filepath"
"testing"
proto "github.com/golang/protobuf/proto"
"v2ray.com/core/app/router"
"v2ray.com/core/common/platform"
"v2ray.com/ext/sysio"
"v2ray.com/core/common"
. "v2ray.com/core/common/net"
. "v2ray.com/ext/assert"
)
func parseCIDR(str string) *net.IPNet {
_, ipNet, err := net.ParseCIDR(str)
common.Must(err)
return ipNet
}
func TestIPNet(t *testing.T) {
assert := With(t)
ipNet := NewIPNetTable()
ipNet.Add(parseCIDR(("0.0.0.0/8")))
ipNet.Add(parseCIDR(("10.0.0.0/8")))
ipNet.Add(parseCIDR(("100.64.0.0/10")))
ipNet.Add(parseCIDR(("127.0.0.0/8")))
ipNet.Add(parseCIDR(("169.254.0.0/16")))
ipNet.Add(parseCIDR(("172.16.0.0/12")))
ipNet.Add(parseCIDR(("192.0.0.0/24")))
ipNet.Add(parseCIDR(("192.0.2.0/24")))
ipNet.Add(parseCIDR(("192.168.0.0/16")))
ipNet.Add(parseCIDR(("198.18.0.0/15")))
ipNet.Add(parseCIDR(("198.51.100.0/24")))
ipNet.Add(parseCIDR(("203.0.113.0/24")))
ipNet.Add(parseCIDR(("8.8.8.8/32")))
ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
}
func TestGeoIPCN(t *testing.T) {
assert := With(t)
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := NewIPNetTable()
for _, ip := range ips {
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
}
assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
}
func loadGeoIP(country string) ([]*router.CIDR, error) {
geoipBytes, err := sysio.ReadAsset("geoip.dat")
if err != nil {
return nil, err
}
var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
}
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
}
}
panic("country not found: " + country)
}
func BenchmarkIPNetQuery(b *testing.B) {
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := NewIPNetTable()
for _, ip := range ips {
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ipNet.Contains([]byte{8, 8, 8, 8})
}
}
func BenchmarkCIDRQuery(b *testing.B) {
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := make([]*net.IPNet, 0, 1024)
for _, ip := range ips {
if len(ip.Ip) != 4 {
continue
}
ipNet = append(ipNet, &net.IPNet{
IP: net.IP(ip.Ip),
Mask: net.CIDRMask(int(ip.Prefix), 32),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, n := range ipNet {
if n.Contains([]byte{8, 8, 8, 8}) {
break
}
}
}
}
Loading…
Cancel
Save