Merge pull request #108 from Vigilans/vigilans/routing-context

Extract session information during routing as routing context
pull/2725/head
Kslr 2020-09-04 13:39:12 +08:00 committed by GitHub
commit b083aa2376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 263 additions and 133 deletions

View File

@ -20,6 +20,7 @@ import (
"v2ray.com/core/features/outbound"
"v2ray.com/core/features/policy"
"v2ray.com/core/features/routing"
routing_session "v2ray.com/core/features/routing/session"
"v2ray.com/core/features/stats"
"v2ray.com/core/transport"
"v2ray.com/core/transport/pipe"
@ -265,7 +266,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
}
if d.router != nil && !skipRoutePick {
if tag, err := d.router.PickRoute(ctx); err == nil {
if tag, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil {
if h := d.ohm.GetHandler(tag); h != nil {
newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
handler = h

View File

@ -10,10 +10,11 @@ import (
"v2ray.com/core/common/net"
"v2ray.com/core/common/strmatcher"
"v2ray.com/core/features/routing"
)
type Condition interface {
Apply(ctx *Context) bool
Apply(ctx routing.Context) bool
}
type ConditionChan []Condition
@ -28,7 +29,8 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan {
return v
}
func (v *ConditionChan) Apply(ctx *Context) bool {
// Apply applies all conditions registered in this chan.
func (v *ConditionChan) Apply(ctx routing.Context) bool {
for _, cond := range *v {
if !cond.Apply(ctx) {
return false
@ -85,36 +87,18 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool {
return len(m.matchers.Match(domain)) > 0
}
func (m *DomainMatcher) Apply(ctx *Context) bool {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
// Apply implements Condition.
func (m *DomainMatcher) Apply(ctx routing.Context) bool {
domain := ctx.GetTargetDomain()
if len(domain) == 0 {
return false
}
dest := ctx.Outbound.Target
if !dest.Address.Family().IsDomain() {
return false
}
return m.ApplyDomain(dest.Address.Domain())
}
func getIPsFromSource(ctx *Context) []net.IP {
if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
return nil
}
dest := ctx.Inbound.Source
if dest.Address.Family().IsDomain() {
return nil
}
return []net.IP{dest.Address.IP()}
}
func getIPsFromTarget(ctx *Context) []net.IP {
return ctx.GetTargetIPs()
return m.ApplyDomain(domain)
}
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
ipFunc func(*Context) []net.IP
onSource bool
}
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@ -129,20 +113,20 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
matcher := &MultiGeoIPMatcher{
matchers: matchers,
}
if onSource {
matcher.ipFunc = getIPsFromSource
} else {
matcher.ipFunc = getIPsFromTarget
onSource: onSource,
}
return matcher, nil
}
func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool {
ips := m.ipFunc(ctx)
// Apply implements Condition.
func (m *MultiGeoIPMatcher) Apply(ctx routing.Context) bool {
var ips []net.IP
if m.onSource {
ips = ctx.GetSourceIPs()
} else {
ips = ctx.GetTargetIPs()
}
for _, ip := range ips {
for _, matcher := range m.matchers {
if matcher.Match(ip) {
@ -166,20 +150,13 @@ func NewPortMatcher(list *net.PortList, onSource bool) *PortMatcher {
}
}
func (v *PortMatcher) Apply(ctx *Context) bool {
var port net.Port
// Apply implements Condition.
func (v *PortMatcher) Apply(ctx routing.Context) bool {
if v.onSource {
if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
return false
}
port = ctx.Inbound.Source.Port
return v.port.Contains(ctx.GetSourcePort())
} else {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return false
}
port = ctx.Outbound.Target.Port
return v.port.Contains(ctx.GetTargetPort())
}
return v.port.Contains(port)
}
type NetworkMatcher struct {
@ -194,11 +171,9 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher {
return matcher
}
func (v NetworkMatcher) Apply(ctx *Context) bool {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return false
}
return v.list[int(ctx.Outbound.Target.Network)]
// Apply implements Condition.
func (v NetworkMatcher) Apply(ctx routing.Context) bool {
return v.list[int(ctx.GetNetwork())]
}
type UserMatcher struct {
@ -217,17 +192,14 @@ func NewUserMatcher(users []string) *UserMatcher {
}
}
func (v *UserMatcher) Apply(ctx *Context) bool {
if ctx.Inbound == nil {
return false
}
user := ctx.Inbound.User
if user == nil {
// Apply implements Condition.
func (v *UserMatcher) Apply(ctx routing.Context) bool {
user := ctx.GetUser()
if len(user) == 0 {
return false
}
for _, u := range v.user {
if u == user.Email {
if u == user {
return true
}
}
@ -250,11 +222,12 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher {
}
}
func (v *InboundTagMatcher) Apply(ctx *Context) bool {
if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 {
// Apply implements Condition.
func (v *InboundTagMatcher) Apply(ctx routing.Context) bool {
tag := ctx.GetInboundTag()
if len(tag) == 0 {
return false
}
tag := ctx.Inbound.Tag
for _, t := range v.tags {
if t == tag {
return true
@ -281,18 +254,17 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher {
}
}
func (m *ProtocolMatcher) Apply(ctx *Context) bool {
if ctx.Content == nil {
// Apply implements Condition.
func (m *ProtocolMatcher) Apply(ctx routing.Context) bool {
protocol := ctx.GetProtocol()
if len(protocol) == 0 {
return false
}
protocol := ctx.Content.Protocol
for _, p := range m.protocols {
if strings.HasPrefix(protocol, p) {
return true
}
}
return false
}
@ -343,9 +315,11 @@ func (m *AttributeMatcher) Match(attrs map[string]interface{}) bool {
return satisfied != nil && bool(satisfied.Truth())
}
func (m *AttributeMatcher) Apply(ctx *Context) bool {
if ctx.Content == nil {
// Apply implements Condition.
func (m *AttributeMatcher) Apply(ctx routing.Context) bool {
attributes := ctx.GetAttributes()
if attributes == nil {
return false
}
return m.Match(ctx.Content.Attributes)
return m.Match(attributes)
}

View File

@ -17,6 +17,8 @@ import (
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/protocol/http"
"v2ray.com/core/common/session"
"v2ray.com/core/features/routing"
routing_session "v2ray.com/core/features/routing/session"
)
func init() {
@ -31,17 +33,25 @@ func init() {
}
}
func withOutbound(outbound *session.Outbound) *Context {
return &Context{Outbound: outbound}
func withBackground() routing.Context {
return &routing_session.Context{}
}
func withInbound(inbound *session.Inbound) *Context {
return &Context{Inbound: inbound}
func withOutbound(outbound *session.Outbound) routing.Context {
return &routing_session.Context{Outbound: outbound}
}
func withInbound(inbound *session.Inbound) routing.Context {
return &routing_session.Context{Inbound: inbound}
}
func withContent(content *session.Content) routing.Context {
return &routing_session.Context{Content: content}
}
func TestRoutingRule(t *testing.T) {
type ruleTest struct {
input *Context
input routing.Context
output bool
}
@ -92,7 +102,7 @@ func TestRoutingRule(t *testing.T) {
output: false,
},
{
input: &Context{},
input: withBackground(),
output: false,
},
},
@ -128,7 +138,7 @@ func TestRoutingRule(t *testing.T) {
output: true,
},
{
input: &Context{},
input: withBackground(),
output: false,
},
},
@ -168,7 +178,7 @@ func TestRoutingRule(t *testing.T) {
output: true,
},
{
input: &Context{},
input: withBackground(),
output: false,
},
},
@ -209,7 +219,7 @@ func TestRoutingRule(t *testing.T) {
output: false,
},
{
input: &Context{},
input: withBackground(),
output: false,
},
},
@ -220,7 +230,7 @@ func TestRoutingRule(t *testing.T) {
},
test: []ruleTest{
{
input: &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}},
input: withContent(&session.Content{Protocol: (&http.SniffHeader{}).Protocol()}),
output: true,
},
},
@ -303,7 +313,7 @@ func TestRoutingRule(t *testing.T) {
},
test: []ruleTest{
{
input: &Context{Content: &session.Content{Protocol: "http/1.1", Attributes: map[string]interface{}{":path": "/test/1"}}},
input: withContent(&session.Content{Protocol: "http/1.1", Attributes: map[string]interface{}{":path": "/test/1"}}),
output: true,
},
},

View File

@ -5,6 +5,7 @@ package router
import (
"v2ray.com/core/common/net"
"v2ray.com/core/features/outbound"
"v2ray.com/core/features/routing"
)
// CIDRList is an alias of []*CIDR to provide sort.Interface.
@ -59,7 +60,8 @@ func (r *Rule) GetTag() (string, error) {
return r.Tag, nil
}
func (r *Rule) Apply(ctx *Context) bool {
// Apply checks rule matching of current routing context.
func (r *Rule) Apply(ctx routing.Context) bool {
return r.Condition.Apply(ctx)
}

View File

@ -10,7 +10,6 @@ import (
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/dns"
"v2ray.com/core/features/outbound"
"v2ray.com/core/features/routing"
@ -74,7 +73,8 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
return nil
}
func (r *Router) PickRoute(ctx context.Context) (string, error) {
// PickRoute implements routing.Router.
func (r *Router) PickRoute(ctx routing.Context) (string, error) {
rule, err := r.pickRouteInternal(ctx)
if err != nil {
return "", err
@ -82,37 +82,26 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
return rule.GetTag()
}
func isDomainOutbound(outbound *session.Outbound) bool {
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
sessionContext := &Context{
Inbound: session.InboundFromContext(ctx),
Outbound: session.OutboundFromContext(ctx),
Content: session.ContentFromContext(ctx),
}
func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, error) {
if r.domainStrategy == Config_IpOnDemand {
sessionContext.dnsClient = r.dns
ctx = ContextWithDNSClient(ctx, r.dns)
}
for _, rule := range r.rules {
if rule.Apply(sessionContext) {
if rule.Apply(ctx) {
return rule, nil
}
}
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) {
if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 {
return nil, common.ErrNoClue
}
sessionContext.dnsClient = r.dns
ctx = ContextWithDNSClient(ctx, r.dns)
// Try applying rules again if we have IPs.
for _, rule := range r.rules {
if rule.Apply(sessionContext) {
if rule.Apply(ctx) {
return rule, nil
}
}
@ -135,32 +124,30 @@ func (*Router) Type() interface{} {
return routing.RouterType()
}
type Context struct {
Inbound *session.Inbound
Outbound *session.Outbound
Content *session.Content
dnsClient dns.Client
// ContextWithDNSClient creates a new routing context with domain resolving capability. Resolved domain IPs can be retrieved by GetTargetIPs().
func ContextWithDNSClient(ctx routing.Context, client dns.Client) routing.Context {
return &resolvableContext{Context: ctx, dnsClient: client}
}
func (c *Context) GetTargetIPs() []net.IP {
if c.Outbound == nil || !c.Outbound.Target.IsValid() {
return nil
type resolvableContext struct {
routing.Context
dnsClient dns.Client
resolvedIPs []net.IP
}
func (ctx *resolvableContext) GetTargetIPs() []net.IP {
if ips := ctx.Context.GetTargetIPs(); len(ips) != 0 {
return ips
}
if c.Outbound.Target.Address.Family().IsIP() {
return []net.IP{c.Outbound.Target.Address.IP()}
if len(ctx.resolvedIPs) > 0 {
return ctx.resolvedIPs
}
if len(c.Outbound.ResolvedIPs) > 0 {
return c.Outbound.ResolvedIPs
}
if c.dnsClient != nil {
domain := c.Outbound.Target.Address.Domain()
ips, err := c.dnsClient.LookupIP(domain)
if domain := ctx.GetTargetDomain(); len(domain) != 0 {
ips, err := ctx.dnsClient.LookupIP(domain)
if err == nil {
c.Outbound.ResolvedIPs = ips
ctx.resolvedIPs = ips
return ips
}
newError("resolve ip for ", domain).Base(err).WriteToLog()

View File

@ -10,6 +10,7 @@ import (
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/outbound"
routing_session "v2ray.com/core/features/routing/session"
"v2ray.com/core/testing/mocks"
)
@ -44,7 +45,7 @@ func TestSimpleRouter(t *testing.T) {
}))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
@ -85,7 +86,7 @@ func TestSimpleBalancer(t *testing.T) {
}))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
@ -120,7 +121,7 @@ func TestIPOnDemand(t *testing.T) {
common.Must(r.Init(config, mockDns, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
@ -155,7 +156,7 @@ func TestIPIfNonMatchDomain(t *testing.T) {
common.Must(r.Init(config, mockDns, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
@ -189,7 +190,7 @@ func TestIPIfNonMatchIP(t *testing.T) {
common.Must(r.Init(config, mockDns, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
tag, err := r.PickRoute(ctx)
tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)

View File

@ -51,8 +51,6 @@ type Outbound struct {
Target net.Destination
// Gateway address
Gateway net.Address
// ResolvedIPs is the resolved IP addresses, if the Targe is a domain address.
ResolvedIPs []net.IP
}
type SniffingRequest struct {

View File

@ -0,0 +1,40 @@
package routing
import (
"v2ray.com/core/common/net"
)
// Context is a feature to store connection information for routing.
//
// v2ray:api:beta
type Context interface {
// GetInboundTag returns the tag of the inbound the connection was from.
GetInboundTag() string
// GetSourcesIPs returns the source IPs bound to the connection.
GetSourceIPs() []net.IP
// GetSourcePort returns the source port of the connection.
GetSourcePort() net.Port
// GetTargetIPs returns the target IP of the connection or resolved IPs of target domain.
GetTargetIPs() []net.IP
// GetTargetPort returns the target port of the connection.
GetTargetPort() net.Port
// GetTargetDomain returns the target domain of the connection, if exists.
GetTargetDomain() string
// GetNetwork returns the network type of the connection.
GetNetwork() net.Network
// GetProtocol returns the protocol from the connection content, if sniffed out.
GetProtocol() string
// GetUser returns the user email from the connection content, if exists.
GetUser() string
// GetAttributes returns extra attributes from the conneciont content.
GetAttributes() map[string]interface{}
}

View File

@ -1,20 +1,18 @@
package routing
import (
"context"
"v2ray.com/core/common"
"v2ray.com/core/features"
)
// Router is a feature to choose an outbound tag for the given request.
//
// v2ray:api:stable
// v2ray:api:beta
type Router interface {
features.Feature
// PickRoute returns a tag of an OutboundHandler based on the given context.
PickRoute(ctx context.Context) (string, error)
PickRoute(ctx Context) (string, error)
}
// RouterType return the type of Router interface. Can be used to implement common.HasType.
@ -33,7 +31,7 @@ func (DefaultRouter) Type() interface{} {
}
// PickRoute implements Router.
func (DefaultRouter) PickRoute(ctx context.Context) (string, error) {
func (DefaultRouter) PickRoute(ctx Context) (string, error) {
return "", common.ErrNoClue
}

View File

@ -0,0 +1,119 @@
package session
import (
"context"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/routing"
)
// Context is an implementation of routing.Context, which is a wrapper of context.context with session info.
type Context struct {
Inbound *session.Inbound
Outbound *session.Outbound
Content *session.Content
}
// GetInboundTag implements routing.Context.
func (ctx *Context) GetInboundTag() string {
if ctx.Inbound == nil {
return ""
}
return ctx.Inbound.Tag
}
// GetSourceIPs implements routing.Context.
func (ctx *Context) GetSourceIPs() []net.IP {
if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
return nil
}
dest := ctx.Inbound.Source
if dest.Address.Family().IsDomain() {
return nil
}
return []net.IP{dest.Address.IP()}
}
// GetSourcePort implements routing.Context.
func (ctx *Context) GetSourcePort() net.Port {
if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
return 0
}
return ctx.Inbound.Source.Port
}
// GetTargetIPs implements routing.Context.
func (ctx *Context) GetTargetIPs() []net.IP {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return nil
}
if ctx.Outbound.Target.Address.Family().IsIP() {
return []net.IP{ctx.Outbound.Target.Address.IP()}
}
return nil
}
// GetTargetPort implements routing.Context.
func (ctx *Context) GetTargetPort() net.Port {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return 0
}
return ctx.Outbound.Target.Port
}
// GetTargetDomain implements routing.Context.
func (ctx *Context) GetTargetDomain() string {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return ""
}
dest := ctx.Outbound.Target
if !dest.Address.Family().IsDomain() {
return ""
}
return dest.Address.Domain()
}
// GetNetwork implements routing.Context.
func (ctx *Context) GetNetwork() net.Network {
if ctx.Outbound == nil {
return net.Network_Unknown
}
return ctx.Outbound.Target.Network
}
// GetProtocol implements routing.Context.
func (ctx *Context) GetProtocol() string {
if ctx.Content == nil {
return ""
}
return ctx.Content.Protocol
}
// GetUser implements routing.Context.
func (ctx *Context) GetUser() string {
if ctx.Inbound == nil {
return ""
}
return ctx.Inbound.User.Email
}
// GetAttributes implements routing.Context.
func (ctx *Context) GetAttributes() map[string]interface{} {
if ctx.Content == nil {
return nil
}
return ctx.Content.Attributes
}
// AsRoutingContext creates a context from context.context with session info.
func AsRoutingContext(ctx context.Context) routing.Context {
return &Context{
Inbound: session.InboundFromContext(ctx),
Outbound: session.OutboundFromContext(ctx),
Content: session.ContentFromContext(ctx),
}
}