diff --git a/common/functions/functions.go b/common/functions/functions.go index 76b0d47b..1b790bae 100644 --- a/common/functions/functions.go +++ b/common/functions/functions.go @@ -5,8 +5,8 @@ import "v2ray.com/core/common" // Task is a function that may return an error. type Task func() error -// CloseOnSuccess returns a Task to run a follow task if pre-condition passes, otherwise the error in pre-condition is returned. -func CloseOnSuccess(pre func() error, followup Task) Task { +// OnSuccess returns a Task to run a follow task if pre-condition passes, otherwise the error in pre-condition is returned. +func OnSuccess(pre func() error, followup Task) Task { return func() error { if err := pre(); err != nil { return err diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 2ea04217..cba00c2c 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -9,6 +9,7 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" + "v2ray.com/core/common/functions" "v2ray.com/core/common/net" "v2ray.com/core/common/signal" "v2ray.com/core/proxy" @@ -79,7 +80,6 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in } requestDone := func() error { - defer common.Close(link.Writer) defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) chunkReader := buf.NewReader(conn) @@ -118,7 +118,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in return nil } - if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index e8b4a270..ddf888d5 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -136,7 +136,7 @@ func (h *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia return nil } - if err := signal.ExecuteParallel(ctx, requestDone, functions.CloseOnSuccess(responseDone, functions.Close(output))); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/http/server.go b/proxy/http/server.go index f8ef71a2..4ff9dedf 100755 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -16,6 +16,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/errors" + "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" http_proto "v2ray.com/core/common/protocol/http" @@ -192,7 +193,6 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade } requestDone := func() error { - defer common.Close(link.Writer) defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly) v2reader := buf.NewReader(conn) @@ -210,7 +210,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade return nil } - if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index fa89de3b..84427cab 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -158,7 +158,7 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial return nil } - if err := signal.ExecuteParallel(ctx, requestDone, functions.CloseOnSuccess(responseDone, functions.Close(link.Writer))); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(link.Writer))); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index ad8f2f37..bc6653f4 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -7,6 +7,7 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" + "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -207,7 +208,6 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) - defer common.Close(link.Writer) if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil { return newError("failed to transport all TCP request").Base(err) @@ -216,7 +216,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return nil } - if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 9e90c833..4078345c 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -130,7 +130,7 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial } } - if err := signal.ExecuteParallel(ctx, requestFunc, functions.CloseOnSuccess(responseFunc, functions.Close(link.Writer))); err != nil { + if err := signal.ExecuteParallel(ctx, requestFunc, functions.OnSuccess(responseFunc, functions.Close(link.Writer))); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 4c74f69c..928dd820 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -8,6 +8,7 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" + "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -139,7 +140,6 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ requestDone := func() error { defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) - defer common.Close(link.Writer) // nolint: errcheck v2reader := buf.NewReader(reader) if err := buf.Copy(v2reader, link.Writer, buf.UpdateActivity(timer)); err != nil { @@ -160,7 +160,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ return nil } - if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 936d70f9..ffc2a3de 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -13,6 +13,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/errors" + "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -168,8 +169,6 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error { } func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output buf.Writer) error { - defer common.Close(output) - bodyReader := session.DecodeRequestBody(request, input) if err := buf.Copy(bodyReader, output, buf.UpdateActivity(timer)); err != nil { return newError("failed to transfer request").Base(err) @@ -295,7 +294,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return transferResponse(timer, session, request, response, link.Reader, writer) } - if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 217d733a..5629cf51 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -161,7 +161,7 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) } - if err := signal.ExecuteParallel(ctx, requestDone, functions.CloseOnSuccess(responseDone, functions.Close(output))); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil { return newError("connection ends").Base(err) }