mirror of https://github.com/v2ray/v2ray-core
				
				
				
			use session.Outbound.ResolvedIPs
							parent
							
								
									98d89aebc2
								
							
						
					
					
						commit
						82d562d1f0
					
				|  | @ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination { | |||
| 	return outbound.Target | ||||
| } | ||||
| 
 | ||||
| func resolvedIPFromContext(ctx context.Context) []net.IP { | ||||
| 	outbound := session.OutboundFromContext(ctx) | ||||
| 	if outbound == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return outbound.ResolvedIPs | ||||
| } | ||||
| 
 | ||||
| type MultiGeoIPMatcher struct { | ||||
| 	matchers []*GeoIPMatcher | ||||
| 	destFunc func(context.Context) net.Destination | ||||
| 	matchers       []*GeoIPMatcher | ||||
| 	destFunc       func(context.Context) net.Destination | ||||
| 	resolvedIPFunc func(context.Context) []net.IP | ||||
| } | ||||
| 
 | ||||
| func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) { | ||||
|  | @ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e | |||
| 		matchers = append(matchers, matcher) | ||||
| 	} | ||||
| 
 | ||||
| 	var destFunc func(context.Context) net.Destination | ||||
| 	if onSource { | ||||
| 		destFunc = sourceFromContext | ||||
| 	} else { | ||||
| 		destFunc = targetFromContent | ||||
| 	matcher := &MultiGeoIPMatcher{ | ||||
| 		matchers: matchers, | ||||
| 	} | ||||
| 
 | ||||
| 	return &MultiGeoIPMatcher{ | ||||
| 		matchers: matchers, | ||||
| 		destFunc: destFunc, | ||||
| 	}, nil | ||||
| 	if onSource { | ||||
| 		matcher.destFunc = sourceFromContext | ||||
| 	} else { | ||||
| 		matcher.destFunc = targetFromContent | ||||
| 		matcher.resolvedIPFunc = resolvedIPFromContext | ||||
| 	} | ||||
| 
 | ||||
| 	return matcher, nil | ||||
| } | ||||
| 
 | ||||
| func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { | ||||
|  | @ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { | |||
| 
 | ||||
| 	if dest.IsValid() && dest.Address.Family().IsIP() { | ||||
| 		ips = append(ips, dest.Address.IP()) | ||||
| 	} else if resolver, ok := ResolvedIPsFromContext(ctx); ok { | ||||
| 		resolvedIPs := resolver.Resolve() | ||||
| 		for _, rip := range resolvedIPs { | ||||
| 			ips = append(ips, rip.IP()) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.resolvedIPFunc != nil { | ||||
| 		rips := m.resolvedIPFunc(ctx) | ||||
| 		if len(rips) > 0 { | ||||
| 			ips = append(ips, rips...) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -7,32 +7,12 @@ import ( | |||
| 
 | ||||
| 	"v2ray.com/core" | ||||
| 	"v2ray.com/core/common" | ||||
| 	"v2ray.com/core/common/net" | ||||
| 	"v2ray.com/core/common/session" | ||||
| 	"v2ray.com/core/features/dns" | ||||
| 	"v2ray.com/core/features/outbound" | ||||
| 	"v2ray.com/core/features/routing" | ||||
| ) | ||||
| 
 | ||||
| type key uint32 | ||||
| 
 | ||||
| const ( | ||||
| 	resolvedIPsKey key = iota | ||||
| ) | ||||
| 
 | ||||
| type IPResolver interface { | ||||
| 	Resolve() []net.Address | ||||
| } | ||||
| 
 | ||||
| func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context { | ||||
| 	return context.WithValue(ctx, resolvedIPsKey, f) | ||||
| } | ||||
| 
 | ||||
| func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) { | ||||
| 	ips, ok := ctx.Value(resolvedIPsKey).(IPResolver) | ||||
| 	return ips, ok | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { | ||||
| 		r := new(Router) | ||||
|  | @ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type ipResolver struct { | ||||
| 	dns      dns.Client | ||||
| 	ip       []net.Address | ||||
| 	domain   string | ||||
| 	resolved bool | ||||
| } | ||||
| 
 | ||||
| func (r *ipResolver) Resolve() []net.Address { | ||||
| 	if r.resolved { | ||||
| 		return r.ip | ||||
| 	} | ||||
| 
 | ||||
| 	newError("looking for IP for domain: ", r.domain).WriteToLog() | ||||
| 	r.resolved = true | ||||
| 	ips, err := r.dns.LookupIP(r.domain) | ||||
| 	if err != nil { | ||||
| 		newError("failed to get IP address").Base(err).WriteToLog() | ||||
| 	} | ||||
| 	if len(ips) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	r.ip = make([]net.Address, len(ips)) | ||||
| 	for i, ip := range ips { | ||||
| 		r.ip[i] = net.IPAddress(ip) | ||||
| 	} | ||||
| 	return r.ip | ||||
| } | ||||
| 
 | ||||
| func (r *Router) PickRoute(ctx context.Context) (string, error) { | ||||
| 	rule, err := r.pickRouteInternal(ctx) | ||||
| 	if err != nil { | ||||
|  | @ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) { | |||
| 	return rule.GetTag() | ||||
| } | ||||
| 
 | ||||
| // PickRoute implements routing.Router.
 | ||||
| func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { | ||||
| 	resolver := &ipResolver{ | ||||
| 		dns: r.dns, | ||||
| func isDomainOutbound(outbound *session.Outbound) bool { | ||||
| 	return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() | ||||
| } | ||||
| 
 | ||||
| func (r *Router) resolveIP(outbound *session.Outbound) error { | ||||
| 	domain := outbound.Target.Address.Domain() | ||||
| 	ips, err := r.dns.LookupIP(domain) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	outbound.ResolvedIPs = ips | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // PickRoute implements routing.Router.
 | ||||
| func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { | ||||
| 	outbound := session.OutboundFromContext(ctx) | ||||
| 	if r.domainStrategy == Config_IpOnDemand { | ||||
| 		if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() { | ||||
| 			resolver.domain = outbound.Target.Address.Domain() | ||||
| 			ctx = ContextWithResolveIPs(ctx, resolver) | ||||
| 	if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) { | ||||
| 		if err := r.resolveIP(outbound); err != nil { | ||||
| 			newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  | @ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if outbound == nil || !outbound.Target.IsValid() { | ||||
| 	if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) { | ||||
| 		return nil, common.ErrNoClue | ||||
| 	} | ||||
| 
 | ||||
| 	dest := outbound.Target | ||||
| 	if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() { | ||||
| 		resolver.domain = dest.Address.Domain() | ||||
| 		ips := resolver.Resolve() | ||||
| 		if len(ips) > 0 { | ||||
| 			ctx = ContextWithResolveIPs(ctx, resolver) | ||||
| 			for _, rule := range r.rules { | ||||
| 				if rule.Apply(ctx) { | ||||
| 					return rule, nil | ||||
| 				} | ||||
| 			} | ||||
| 	if err := r.resolveIP(outbound); err != nil { | ||||
| 		newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) | ||||
| 		return nil, common.ErrNoClue | ||||
| 	} | ||||
| 
 | ||||
| 	// Try applying rules again if we have IPs.
 | ||||
| 	for _, rule := range r.rules { | ||||
| 		if rule.Apply(ctx) { | ||||
| 			return rule, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) { | |||
| 		t.Error("expect tag 'test', bug actually ", tag) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestIPIfNonMatchDomain(t *testing.T) { | ||||
| 	config := &Config{ | ||||
| 		DomainStrategy: Config_IpIfNonMatch, | ||||
| 		Rule: []*RoutingRule{ | ||||
| 			{ | ||||
| 				TargetTag: &RoutingRule_Tag{ | ||||
| 					Tag: "test", | ||||
| 				}, | ||||
| 				Cidr: []*CIDR{ | ||||
| 					{ | ||||
| 						Ip:     []byte{192, 168, 0, 0}, | ||||
| 						Prefix: 16, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	mockCtl := gomock.NewController(t) | ||||
| 	defer mockCtl.Finish() | ||||
| 
 | ||||
| 	mockDns := mocks.NewDNSClient(mockCtl) | ||||
| 	mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes() | ||||
| 
 | ||||
| 	r := new(Router) | ||||
| 	common.Must(r.Init(config, mockDns, nil)) | ||||
| 
 | ||||
| 	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) | ||||
| 	tag, err := r.PickRoute(ctx) | ||||
| 	common.Must(err) | ||||
| 	if tag != "test" { | ||||
| 		t.Error("expect tag 'test', bug actually ", tag) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestIPIfNonMatchIP(t *testing.T) { | ||||
| 	config := &Config{ | ||||
| 		DomainStrategy: Config_IpIfNonMatch, | ||||
| 		Rule: []*RoutingRule{ | ||||
| 			{ | ||||
| 				TargetTag: &RoutingRule_Tag{ | ||||
| 					Tag: "test", | ||||
| 				}, | ||||
| 				Cidr: []*CIDR{ | ||||
| 					{ | ||||
| 						Ip:     []byte{127, 0, 0, 0}, | ||||
| 						Prefix: 8, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	mockCtl := gomock.NewController(t) | ||||
| 	defer mockCtl.Finish() | ||||
| 
 | ||||
| 	mockDns := mocks.NewDNSClient(mockCtl) | ||||
| 
 | ||||
| 	r := new(Router) | ||||
| 	common.Must(r.Init(config, mockDns, nil)) | ||||
| 
 | ||||
| 	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) | ||||
| 	tag, err := r.PickRoute(ctx) | ||||
| 	common.Must(err) | ||||
| 	if tag != "test" { | ||||
| 		t.Error("expect tag 'test', bug actually ", tag) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Darien Raymond
						Darien Raymond