diff --git a/common/task/task.go b/common/task/task.go index 4c3a566c..72f327ef 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -10,29 +10,23 @@ type Task func() error type executionContext struct { ctx context.Context - task Task + tasks []Task onSuccess Task onFailure Task } func (c *executionContext) executeTask() error { - if c.ctx == nil && c.task == nil { + if len(c.tasks) == 0 { return nil } - if c.ctx == nil { - return c.task() + ctx := context.Background() + + if c.ctx != nil { + ctx = c.ctx } - if c.task == nil { - <-c.ctx.Done() - return c.ctx.Err() - } - - return executeParallel(func() error { - <-c.ctx.Done() - return c.ctx.Err() - }, c.task) + return executeParallel(ctx, c.tasks) } func (c *executionContext) run() error { @@ -56,17 +50,15 @@ func WithContext(ctx context.Context) ExecutionOption { func Parallel(tasks ...Task) ExecutionOption { return func(c *executionContext) { - c.task = func() error { - return executeParallel(tasks...) - } + c.tasks = append(c.tasks, tasks...) } } func Sequential(tasks ...Task) ExecutionOption { return func(c *executionContext) { - c.task = func() error { + c.tasks = append(c.tasks, func() error { return execute(tasks...) - } + }) } } @@ -107,7 +99,7 @@ func execute(tasks ...Task) error { } // executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass. -func executeParallel(tasks ...Task) error { +func executeParallel(ctx context.Context, tasks []Task) error { n := len(tasks) s := semaphore.New(n) done := make(chan error, 1) @@ -129,6 +121,8 @@ func executeParallel(tasks ...Task) error { select { case err := <-done: return err + case <-ctx.Done(): + return ctx.Err() case <-s.Wait(): } }