mirror of https://github.com/XTLS/Xray-core
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
5.9 KiB
255 lines
5.9 KiB
package router |
|
|
|
import ( |
|
"context" |
|
sync "sync" |
|
|
|
"github.com/xtls/xray-core/common" |
|
"github.com/xtls/xray-core/common/errors" |
|
"github.com/xtls/xray-core/common/serial" |
|
"github.com/xtls/xray-core/core" |
|
"github.com/xtls/xray-core/features/dns" |
|
"github.com/xtls/xray-core/features/outbound" |
|
"github.com/xtls/xray-core/features/routing" |
|
routing_dns "github.com/xtls/xray-core/features/routing/dns" |
|
) |
|
|
|
// Router is an implementation of routing.Router. |
|
type Router struct { |
|
domainStrategy Config_DomainStrategy |
|
rules []*Rule |
|
balancers map[string]*Balancer |
|
dns dns.Client |
|
|
|
ctx context.Context |
|
ohm outbound.Manager |
|
dispatcher routing.Dispatcher |
|
mu sync.Mutex |
|
} |
|
|
|
// Route is an implementation of routing.Route. |
|
type Route struct { |
|
routing.Context |
|
outboundGroupTags []string |
|
outboundTag string |
|
ruleTag string |
|
} |
|
|
|
// Init initializes the Router. |
|
func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { |
|
r.domainStrategy = config.DomainStrategy |
|
r.dns = d |
|
r.ctx = ctx |
|
r.ohm = ohm |
|
r.dispatcher = dispatcher |
|
|
|
r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) |
|
for _, rule := range config.BalancingRule { |
|
balancer, err := rule.Build(ohm, dispatcher) |
|
if err != nil { |
|
return err |
|
} |
|
balancer.InjectContext(ctx) |
|
r.balancers[rule.Tag] = balancer |
|
} |
|
|
|
r.rules = make([]*Rule, 0, len(config.Rule)) |
|
for _, rule := range config.Rule { |
|
cond, err := rule.BuildCondition() |
|
if err != nil { |
|
return err |
|
} |
|
rr := &Rule{ |
|
Condition: cond, |
|
Tag: rule.GetTag(), |
|
RuleTag: rule.GetRuleTag(), |
|
} |
|
btag := rule.GetBalancingTag() |
|
if len(btag) > 0 { |
|
brule, found := r.balancers[btag] |
|
if !found { |
|
return errors.New("balancer ", btag, " not found") |
|
} |
|
rr.Balancer = brule |
|
} |
|
r.rules = append(r.rules, rr) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// PickRoute implements routing.Router. |
|
func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) { |
|
rule, ctx, err := r.pickRouteInternal(ctx) |
|
if err != nil { |
|
return nil, err |
|
} |
|
tag, err := rule.GetTag() |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &Route{Context: ctx, outboundTag: tag, ruleTag: rule.RuleTag}, nil |
|
} |
|
|
|
// AddRule implements routing.Router. |
|
func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error { |
|
|
|
inst, err := config.GetInstance() |
|
if err != nil { |
|
return err |
|
} |
|
if c, ok := inst.(*Config); ok { |
|
return r.ReloadRules(c, shouldAppend) |
|
} |
|
return errors.New("AddRule: config type error") |
|
} |
|
|
|
func (r *Router) ReloadRules(config *Config, shouldAppend bool) error { |
|
r.mu.Lock() |
|
defer r.mu.Unlock() |
|
|
|
if !shouldAppend { |
|
r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) |
|
r.rules = make([]*Rule, 0, len(config.Rule)) |
|
} |
|
for _, rule := range config.BalancingRule { |
|
_, found := r.balancers[rule.Tag] |
|
if found { |
|
return errors.New("duplicate balancer tag") |
|
} |
|
balancer, err := rule.Build(r.ohm, r.dispatcher) |
|
if err != nil { |
|
return err |
|
} |
|
balancer.InjectContext(r.ctx) |
|
r.balancers[rule.Tag] = balancer |
|
} |
|
|
|
for _, rule := range config.Rule { |
|
if r.RuleExists(rule.GetRuleTag()) { |
|
return errors.New("duplicate ruleTag ", rule.GetRuleTag()) |
|
} |
|
cond, err := rule.BuildCondition() |
|
if err != nil { |
|
return err |
|
} |
|
rr := &Rule{ |
|
Condition: cond, |
|
Tag: rule.GetTag(), |
|
RuleTag: rule.GetRuleTag(), |
|
} |
|
btag := rule.GetBalancingTag() |
|
if len(btag) > 0 { |
|
brule, found := r.balancers[btag] |
|
if !found { |
|
return errors.New("balancer ", btag, " not found") |
|
} |
|
rr.Balancer = brule |
|
} |
|
r.rules = append(r.rules, rr) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (r *Router) RuleExists(tag string) bool { |
|
if tag != "" { |
|
for _, rule := range r.rules { |
|
if rule.RuleTag == tag { |
|
return true |
|
} |
|
} |
|
} |
|
return false |
|
} |
|
|
|
// RemoveRule implements routing.Router. |
|
func (r *Router) RemoveRule(tag string) error { |
|
r.mu.Lock() |
|
defer r.mu.Unlock() |
|
|
|
newRules := []*Rule{} |
|
if tag != "" { |
|
for _, rule := range r.rules { |
|
if rule.RuleTag != tag { |
|
newRules = append(newRules, rule) |
|
} |
|
} |
|
r.rules = newRules |
|
return nil |
|
} |
|
return errors.New("empty tag name!") |
|
|
|
} |
|
func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) { |
|
// SkipDNSResolve is set from DNS module. |
|
// the DOH remote server maybe a domain name, |
|
// this prevents cycle resolving dead loop |
|
skipDNSResolve := ctx.GetSkipDNSResolve() |
|
|
|
if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve { |
|
ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) |
|
} |
|
|
|
for _, rule := range r.rules { |
|
if rule.Apply(ctx) { |
|
return rule, ctx, nil |
|
} |
|
} |
|
|
|
if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve { |
|
return nil, ctx, common.ErrNoClue |
|
} |
|
|
|
ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) |
|
|
|
// Try applying rules again if we have IPs. |
|
for _, rule := range r.rules { |
|
if rule.Apply(ctx) { |
|
return rule, ctx, nil |
|
} |
|
} |
|
|
|
return nil, ctx, common.ErrNoClue |
|
} |
|
|
|
// Start implements common.Runnable. |
|
func (r *Router) Start() error { |
|
return nil |
|
} |
|
|
|
// Close implements common.Closable. |
|
func (r *Router) Close() error { |
|
return nil |
|
} |
|
|
|
// Type implements common.HasType. |
|
func (*Router) Type() interface{} { |
|
return routing.RouterType() |
|
} |
|
|
|
// GetOutboundGroupTags implements routing.Route. |
|
func (r *Route) GetOutboundGroupTags() []string { |
|
return r.outboundGroupTags |
|
} |
|
|
|
// GetOutboundTag implements routing.Route. |
|
func (r *Route) GetOutboundTag() string { |
|
return r.outboundTag |
|
} |
|
|
|
func (r *Route) GetRuleTag() string { |
|
return r.ruleTag |
|
} |
|
|
|
func init() { |
|
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { |
|
r := new(Router) |
|
if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { |
|
return r.Init(ctx, config.(*Config), d, ohm, dispatcher) |
|
}); err != nil { |
|
return nil, err |
|
} |
|
return r, nil |
|
})) |
|
}
|
|
|