mirror of https://github.com/v2ray/v2ray-core
use session.Outbound.ResolvedIPs
parent
98d89aebc2
commit
82d562d1f0
|
@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination {
|
||||||
return outbound.Target
|
return outbound.Target
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolvedIPFromContext(ctx context.Context) []net.IP {
|
||||||
|
outbound := session.OutboundFromContext(ctx)
|
||||||
|
if outbound == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return outbound.ResolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
type MultiGeoIPMatcher struct {
|
type MultiGeoIPMatcher struct {
|
||||||
matchers []*GeoIPMatcher
|
matchers []*GeoIPMatcher
|
||||||
destFunc func(context.Context) net.Destination
|
destFunc func(context.Context) net.Destination
|
||||||
|
resolvedIPFunc func(context.Context) []net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
|
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
|
||||||
|
@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
|
||||||
matchers = append(matchers, matcher)
|
matchers = append(matchers, matcher)
|
||||||
}
|
}
|
||||||
|
|
||||||
var destFunc func(context.Context) net.Destination
|
matcher := &MultiGeoIPMatcher{
|
||||||
if onSource {
|
matchers: matchers,
|
||||||
destFunc = sourceFromContext
|
|
||||||
} else {
|
|
||||||
destFunc = targetFromContent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &MultiGeoIPMatcher{
|
if onSource {
|
||||||
matchers: matchers,
|
matcher.destFunc = sourceFromContext
|
||||||
destFunc: destFunc,
|
} else {
|
||||||
}, nil
|
matcher.destFunc = targetFromContent
|
||||||
|
matcher.resolvedIPFunc = resolvedIPFromContext
|
||||||
|
}
|
||||||
|
|
||||||
|
return matcher, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
|
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
|
||||||
|
@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
|
||||||
|
|
||||||
if dest.IsValid() && dest.Address.Family().IsIP() {
|
if dest.IsValid() && dest.Address.Family().IsIP() {
|
||||||
ips = append(ips, dest.Address.IP())
|
ips = append(ips, dest.Address.IP())
|
||||||
} else if resolver, ok := ResolvedIPsFromContext(ctx); ok {
|
}
|
||||||
resolvedIPs := resolver.Resolve()
|
|
||||||
for _, rip := range resolvedIPs {
|
if m.resolvedIPFunc != nil {
|
||||||
ips = append(ips, rip.IP())
|
rips := m.resolvedIPFunc(ctx)
|
||||||
|
if len(rips) > 0 {
|
||||||
|
ips = append(ips, rips...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,32 +7,12 @@ import (
|
||||||
|
|
||||||
"v2ray.com/core"
|
"v2ray.com/core"
|
||||||
"v2ray.com/core/common"
|
"v2ray.com/core/common"
|
||||||
"v2ray.com/core/common/net"
|
|
||||||
"v2ray.com/core/common/session"
|
"v2ray.com/core/common/session"
|
||||||
"v2ray.com/core/features/dns"
|
"v2ray.com/core/features/dns"
|
||||||
"v2ray.com/core/features/outbound"
|
"v2ray.com/core/features/outbound"
|
||||||
"v2ray.com/core/features/routing"
|
"v2ray.com/core/features/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type key uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
resolvedIPsKey key = iota
|
|
||||||
)
|
|
||||||
|
|
||||||
type IPResolver interface {
|
|
||||||
Resolve() []net.Address
|
|
||||||
}
|
|
||||||
|
|
||||||
func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context {
|
|
||||||
return context.WithValue(ctx, resolvedIPsKey, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) {
|
|
||||||
ips, ok := ctx.Value(resolvedIPsKey).(IPResolver)
|
|
||||||
return ips, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||||
r := new(Router)
|
r := new(Router)
|
||||||
|
@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ipResolver struct {
|
|
||||||
dns dns.Client
|
|
||||||
ip []net.Address
|
|
||||||
domain string
|
|
||||||
resolved bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ipResolver) Resolve() []net.Address {
|
|
||||||
if r.resolved {
|
|
||||||
return r.ip
|
|
||||||
}
|
|
||||||
|
|
||||||
newError("looking for IP for domain: ", r.domain).WriteToLog()
|
|
||||||
r.resolved = true
|
|
||||||
ips, err := r.dns.LookupIP(r.domain)
|
|
||||||
if err != nil {
|
|
||||||
newError("failed to get IP address").Base(err).WriteToLog()
|
|
||||||
}
|
|
||||||
if len(ips) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
r.ip = make([]net.Address, len(ips))
|
|
||||||
for i, ip := range ips {
|
|
||||||
r.ip[i] = net.IPAddress(ip)
|
|
||||||
}
|
|
||||||
return r.ip
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) PickRoute(ctx context.Context) (string, error) {
|
func (r *Router) PickRoute(ctx context.Context) (string, error) {
|
||||||
rule, err := r.pickRouteInternal(ctx)
|
rule, err := r.pickRouteInternal(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
|
||||||
return rule.GetTag()
|
return rule.GetTag()
|
||||||
}
|
}
|
||||||
|
|
||||||
// PickRoute implements routing.Router.
|
func isDomainOutbound(outbound *session.Outbound) bool {
|
||||||
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
|
||||||
resolver := &ipResolver{
|
}
|
||||||
dns: r.dns,
|
|
||||||
|
func (r *Router) resolveIP(outbound *session.Outbound) error {
|
||||||
|
domain := outbound.Target.Address.Domain()
|
||||||
|
ips, err := r.dns.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outbound.ResolvedIPs = ips
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PickRoute implements routing.Router.
|
||||||
|
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
||||||
outbound := session.OutboundFromContext(ctx)
|
outbound := session.OutboundFromContext(ctx)
|
||||||
if r.domainStrategy == Config_IpOnDemand {
|
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
|
||||||
if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() {
|
if err := r.resolveIP(outbound); err != nil {
|
||||||
resolver.domain = outbound.Target.Address.Domain()
|
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||||
ctx = ContextWithResolveIPs(ctx, resolver)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if outbound == nil || !outbound.Target.IsValid() {
|
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
|
||||||
return nil, common.ErrNoClue
|
return nil, common.ErrNoClue
|
||||||
}
|
}
|
||||||
|
|
||||||
dest := outbound.Target
|
if err := r.resolveIP(outbound); err != nil {
|
||||||
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
|
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||||
resolver.domain = dest.Address.Domain()
|
return nil, common.ErrNoClue
|
||||||
ips := resolver.Resolve()
|
}
|
||||||
if len(ips) > 0 {
|
|
||||||
ctx = ContextWithResolveIPs(ctx, resolver)
|
// Try applying rules again if we have IPs.
|
||||||
for _, rule := range r.rules {
|
for _, rule := range r.rules {
|
||||||
if rule.Apply(ctx) {
|
if rule.Apply(ctx) {
|
||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) {
|
||||||
t.Error("expect tag 'test', bug actually ", tag)
|
t.Error("expect tag 'test', bug actually ", tag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIPIfNonMatchDomain(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
DomainStrategy: Config_IpIfNonMatch,
|
||||||
|
Rule: []*RoutingRule{
|
||||||
|
{
|
||||||
|
TargetTag: &RoutingRule_Tag{
|
||||||
|
Tag: "test",
|
||||||
|
},
|
||||||
|
Cidr: []*CIDR{
|
||||||
|
{
|
||||||
|
Ip: []byte{192, 168, 0, 0},
|
||||||
|
Prefix: 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockCtl := gomock.NewController(t)
|
||||||
|
defer mockCtl.Finish()
|
||||||
|
|
||||||
|
mockDns := mocks.NewDNSClient(mockCtl)
|
||||||
|
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
|
||||||
|
|
||||||
|
r := new(Router)
|
||||||
|
common.Must(r.Init(config, mockDns, nil))
|
||||||
|
|
||||||
|
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||||
|
tag, err := r.PickRoute(ctx)
|
||||||
|
common.Must(err)
|
||||||
|
if tag != "test" {
|
||||||
|
t.Error("expect tag 'test', bug actually ", tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPIfNonMatchIP(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
DomainStrategy: Config_IpIfNonMatch,
|
||||||
|
Rule: []*RoutingRule{
|
||||||
|
{
|
||||||
|
TargetTag: &RoutingRule_Tag{
|
||||||
|
Tag: "test",
|
||||||
|
},
|
||||||
|
Cidr: []*CIDR{
|
||||||
|
{
|
||||||
|
Ip: []byte{127, 0, 0, 0},
|
||||||
|
Prefix: 8,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockCtl := gomock.NewController(t)
|
||||||
|
defer mockCtl.Finish()
|
||||||
|
|
||||||
|
mockDns := mocks.NewDNSClient(mockCtl)
|
||||||
|
|
||||||
|
r := new(Router)
|
||||||
|
common.Must(r.Init(config, mockDns, nil))
|
||||||
|
|
||||||
|
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
|
||||||
|
tag, err := r.PickRoute(ctx)
|
||||||
|
common.Must(err)
|
||||||
|
if tag != "test" {
|
||||||
|
t.Error("expect tag 'test', bug actually ", tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue