mirror of https://github.com/v2ray/v2ray-core
simplify task execution
parent
cf1705267e
commit
427679e66d
|
@ -2,7 +2,8 @@ package task
|
|||
|
||||
import "v2ray.com/core/common"
|
||||
|
||||
func Close(v interface{}) Task {
|
||||
// Close returns a func() that closes v.
|
||||
func Close(v interface{}) func() error {
|
||||
return func() error {
|
||||
return common.Close(v)
|
||||
}
|
||||
|
|
|
@ -6,121 +6,25 @@ import (
|
|||
"v2ray.com/core/common/signal/semaphore"
|
||||
)
|
||||
|
||||
type Task func() error
|
||||
|
||||
type executionContext struct {
|
||||
ctx context.Context
|
||||
tasks []Task
|
||||
onSuccess Task
|
||||
onFailure Task
|
||||
}
|
||||
|
||||
func (c *executionContext) executeTask() error {
|
||||
if len(c.tasks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reuse current goroutine if we only have one task to run.
|
||||
if len(c.tasks) == 1 && c.ctx == nil {
|
||||
return c.tasks[0]()
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if c.ctx != nil {
|
||||
ctx = c.ctx
|
||||
}
|
||||
|
||||
return executeParallel(ctx, c.tasks)
|
||||
}
|
||||
|
||||
func (c *executionContext) run() error {
|
||||
err := c.executeTask()
|
||||
if err == nil && c.onSuccess != nil {
|
||||
return c.onSuccess()
|
||||
}
|
||||
if err != nil && c.onFailure != nil {
|
||||
return c.onFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type ExecutionOption func(*executionContext)
|
||||
|
||||
func WithContext(ctx context.Context) ExecutionOption {
|
||||
return func(c *executionContext) {
|
||||
c.ctx = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func Parallel(tasks ...Task) ExecutionOption {
|
||||
return func(c *executionContext) {
|
||||
c.tasks = append(c.tasks, tasks...)
|
||||
}
|
||||
}
|
||||
|
||||
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
|
||||
// Once a task returns an error, the following tasks will not run.
|
||||
func Sequential(tasks ...Task) ExecutionOption {
|
||||
return func(c *executionContext) {
|
||||
switch len(tasks) {
|
||||
case 0:
|
||||
return
|
||||
case 1:
|
||||
c.tasks = append(c.tasks, tasks[0])
|
||||
default:
|
||||
c.tasks = append(c.tasks, func() error {
|
||||
return execute(tasks...)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func OnSuccess(task Task) ExecutionOption {
|
||||
return func(c *executionContext) {
|
||||
c.onSuccess = task
|
||||
}
|
||||
}
|
||||
|
||||
func OnFailure(task Task) ExecutionOption {
|
||||
return func(c *executionContext) {
|
||||
c.onFailure = task
|
||||
}
|
||||
}
|
||||
|
||||
func Single(task Task, opts ...ExecutionOption) Task {
|
||||
return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
|
||||
}
|
||||
|
||||
func Run(opts ...ExecutionOption) Task {
|
||||
var c executionContext
|
||||
for _, opt := range opts {
|
||||
opt(&c)
|
||||
}
|
||||
// OnSuccess executes g() after f() returns nil.
|
||||
func OnSuccess(f func() error, g func() error) func() error {
|
||||
return func() error {
|
||||
return c.run()
|
||||
}
|
||||
}
|
||||
|
||||
// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
|
||||
func execute(tasks ...Task) error {
|
||||
for _, task := range tasks {
|
||||
if err := task(); err != nil {
|
||||
if err := f(); err != nil {
|
||||
return err
|
||||
}
|
||||
return g()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
|
||||
func executeParallel(ctx context.Context, tasks []Task) error {
|
||||
// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
|
||||
func Run(ctx context.Context, tasks ...func() error) error {
|
||||
n := len(tasks)
|
||||
s := semaphore.New(n)
|
||||
done := make(chan error, 1)
|
||||
|
||||
for _, task := range tasks {
|
||||
<-s.Wait()
|
||||
go func(f Task) {
|
||||
go func(f func() error) {
|
||||
err := f()
|
||||
if err == nil {
|
||||
s.Signal()
|
||||
|
|
|
@ -14,13 +14,14 @@ import (
|
|||
func TestExecuteParallel(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
err := Run(Parallel(func() error {
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
return errors.New("test")
|
||||
}, func() error {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
return errors.New("test2")
|
||||
}))()
|
||||
err := Run(context.Background(),
|
||||
func() error {
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
return errors.New("test")
|
||||
}, func() error {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
return errors.New("test2")
|
||||
})
|
||||
|
||||
assert(err.Error(), Equals, "test")
|
||||
}
|
||||
|
@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
|
|||
assert := With(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
err := Run(WithContext(ctx), Parallel(func() error {
|
||||
err := Run(ctx, func() error {
|
||||
time.Sleep(time.Millisecond * 2000)
|
||||
return errors.New("test")
|
||||
}, func() error {
|
||||
|
@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
|
|||
}, func() error {
|
||||
cancel()
|
||||
return nil
|
||||
}))()
|
||||
})
|
||||
|
||||
assert(err.Error(), HasSubstring, "canceled")
|
||||
}
|
||||
|
@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) {
|
|||
return nil
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
common.Must(Run(Parallel(noop))())
|
||||
common.Must(Run(context.Background(), noop))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) {
|
|||
return nil
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
common.Must(Run(Parallel(noop, noop))())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecuteContext(b *testing.B) {
|
||||
noop := func() error {
|
||||
return nil
|
||||
}
|
||||
background := context.Background()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
common.Must(Run(WithContext(background), Parallel(noop, noop))())
|
||||
common.Must(Run(context.Background(), noop, noop))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -147,10 +147,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := task.Run(task.WithContext(ctx),
|
||||
task.Parallel(
|
||||
task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
|
||||
responseDone))(); err != nil {
|
||||
if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
|
|
@ -167,7 +167,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
|
||||
if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -210,8 +210,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
|
|||
return nil
|
||||
}
|
||||
|
||||
var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
|
||||
var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
@ -307,7 +307,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
|
||||
if err := task.Run(ctx, requestDone, responseDone); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
|
|
@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||
return buf.Copy(connReader, link.Writer)
|
||||
}
|
||||
|
||||
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
|
||||
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -141,8 +141,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
|
|||
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
|
||||
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
|
|
@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||
return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
|
||||
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
|
@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||
return nil
|
||||
}
|
||||
|
||||
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
|
||||
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -229,8 +229,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
|
|||
return nil
|
||||
}
|
||||
|
||||
var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
|
||||
var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
|
|
@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||
}
|
||||
}
|
||||
|
||||
var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
|
||||
var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -164,8 +164,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
|
|||
return nil
|
||||
}
|
||||
|
||||
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
|
||||
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
|
|
@ -302,8 +302,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
|
|||
return transferResponse(timer, svrSession, request, response, link.Reader, writer)
|
||||
}
|
||||
|
||||
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
|
||||
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||
pipe.CloseError(link.Reader)
|
||||
pipe.CloseError(link.Writer)
|
||||
return newError("connection ends").Base(err)
|
||||
|
|
|
@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
|
||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
|
||||
var responseDonePost = task.OnSuccess(responseDone, task.Close(output))
|
||||
if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
|
||||
return newError("connection ends").Base(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) {
|
|||
}
|
||||
|
||||
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
|
||||
err := task.Run(task.Parallel(func() error {
|
||||
err := task.Run(context.Background(), func() error {
|
||||
defer pWriter.Close() // nolint: errcheck
|
||||
|
||||
for {
|
||||
|
@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
}))()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("failed to transfer data: ", err.Error())
|
||||
|
|
Loading…
Reference in New Issue