diff --git a/core/xray.go b/core/xray.go index 0e1f0830..e0364b0f 100644 --- a/core/xray.go +++ b/core/xray.go @@ -44,22 +44,13 @@ func getFeature(allFeatures []features.Feature, t reflect.Type) features.Feature return nil } -func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) { - var fs []features.Feature - for _, d := range r.deps { - f := getFeature(allFeatures, d) - if f == nil { - return false, nil - } - fs = append(fs, f) - } - +func (r *resolution) callbackResolution(allFeatures []features.Feature) error { callback := reflect.ValueOf(r.callback) var input []reflect.Value callbackType := callback.Type() for i := 0; i < callbackType.NumIn(); i++ { pt := callbackType.In(i) - for _, f := range fs { + for _, f := range allFeatures { if reflect.TypeOf(f).AssignableTo(pt) { input = append(input, reflect.ValueOf(f)) break @@ -84,15 +75,16 @@ func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) { } } - return true, err + return err } // Instance combines all Xray features. type Instance struct { - access sync.Mutex + statusLock sync.Mutex features []features.Feature - featureResolutions []resolution + pendingResolutions []resolution running bool + resolveLock sync.Mutex ctx context.Context } @@ -227,9 +219,12 @@ func initInstanceWithConfig(config *Config, server *Instance) (bool, error) { }(), ) - if server.featureResolutions != nil { + server.resolveLock.Lock() + if server.pendingResolutions != nil { + server.resolveLock.Unlock() return true, errors.New("not all dependencies are resolved.") } + server.resolveLock.Unlock() if err := addInboundHandlers(server, config.Inbound); err != nil { return true, err @@ -248,8 +243,8 @@ func (s *Instance) Type() interface{} { // Close shutdown the Xray instance. func (s *Instance) Close() error { - s.access.Lock() - defer s.access.Unlock() + s.statusLock.Lock() + defer s.statusLock.Unlock() s.running = false @@ -283,17 +278,28 @@ func (s *Instance) RequireFeatures(callback interface{}) error { deps: featureTypes, callback: callback, } - if finished, err := r.resolve(s.features); finished { - return err + + s.resolveLock.Lock() + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } + } + if foundAll { + s.resolveLock.Unlock() + return r.callbackResolution(s.features) + } else { + s.pendingResolutions = append(s.pendingResolutions, r) + s.resolveLock.Unlock() + return nil } - s.featureResolutions = append(s.featureResolutions, r) - return nil } // AddFeature registers a feature into current Instance. func (s *Instance) AddFeature(feature features.Feature) error { - s.features = append(s.features, feature) - if s.running { if err := feature.Start(); err != nil { errors.LogInfoInner(s.ctx, err, "failed to start feature") @@ -301,27 +307,37 @@ func (s *Instance) AddFeature(feature features.Feature) error { return nil } - if s.featureResolutions == nil { + s.resolveLock.Lock() + s.features = append(s.features, feature) + if s.pendingResolutions == nil { + s.resolveLock.Unlock() return nil } - - var pendingResolutions []resolution - for _, r := range s.featureResolutions { - finished, err := r.resolve(s.features) - if finished && err != nil { - return err + var pending []resolution + var availableResolution []resolution + for _, r := range s.pendingResolutions { + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } } - if !finished { - pendingResolutions = append(pendingResolutions, r) + if foundAll { + availableResolution = append(availableResolution, r) + } else { + pending = append(pending, r) } } - if len(pendingResolutions) == 0 { - s.featureResolutions = nil - } else if len(pendingResolutions) < len(s.featureResolutions) { - s.featureResolutions = pendingResolutions + s.pendingResolutions = pending + s.resolveLock.Unlock() + + var err error + for _, r := range availableResolution { + err = r.callbackResolution(s.features) // only return the last error for now } - - return nil + return err } // GetFeature returns a feature of the given type, or nil if such feature is not registered. @@ -334,8 +350,8 @@ func (s *Instance) GetFeature(featureType interface{}) features.Feature { // // xray:api:stable func (s *Instance) Start() error { - s.access.Lock() - defer s.access.Unlock() + s.statusLock.Lock() + defer s.statusLock.Unlock() s.running = true for _, f := range s.features {