diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index ca2e6f0a..683a7c95 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -28,7 +28,7 @@ type cachedReader struct { } func (r *cachedReader) Cache(b *buf.Buffer) { - mb, _ := r.reader.ReadMultiBufferWithTimeout(time.Millisecond * 100) + mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) if !mb.IsEmpty() { common.Must(r.cache.WriteMultiBuffer(mb)) } @@ -47,6 +47,16 @@ func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { return r.reader.ReadMultiBuffer() } +func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { + if !r.cache.IsEmpty() { + mb := r.cache + r.cache = nil + return mb, nil + } + + return r.reader.ReadMultiBufferTimeout(timeout) +} + func (r *cachedReader) CloseError() { r.cache.Release() r.reader.CloseError() diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index 02c4b5c4..d1350f59 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -147,17 +147,17 @@ func (m *Client) monitor() { } } -func copyFirstPayload(reader *pipe.Reader, writer *Writer) error { - data, err := reader.ReadMultiBufferWithTimeout(time.Millisecond * 200) - if err == buf.ErrReadTimeout { - return writer.writeMetaOnly() +func writeFirstPayload(reader buf.Reader, writer *Writer) error { + err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*200) + if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout { + return writer.WriteMultiBuffer(buf.MultiBuffer{}) } if err != nil { return err } - return writer.WriteMultiBuffer(data) + return nil } func fetchInput(ctx context.Context, s *Session, output buf.Writer) { @@ -172,13 +172,11 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) { defer writer.Close() // nolint: errcheck newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx)) - if pReader, ok := s.input.(*pipe.Reader); ok { - if err := copyFirstPayload(pReader, writer); err != nil { - newError("failed to fetch first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) - writer.hasError = true - pipe.CloseError(s.input) - return - } + if err := writeFirstPayload(s.input, writer); err != nil { + newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) + writer.hasError = true + pipe.CloseError(s.input) + return } if err := buf.Copy(s.input, writer); err != nil { diff --git a/common/buf/copy.go b/common/buf/copy.go index 7d487c17..e6118179 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -2,6 +2,7 @@ package buf import ( "io" + "time" "v2ray.com/core/common/errors" "v2ray.com/core/common/signal" @@ -112,3 +113,17 @@ func Copy(reader Reader, writer Writer, options ...CopyOption) error { } return nil } + +var ErrNotTimeoutReader = newError("not a TimeoutReader") + +func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error { + timeoutReader, ok := reader.(TimeoutReader) + if !ok { + return ErrNotTimeoutReader + } + mb, err := timeoutReader.ReadMultiBufferTimeout(timeout) + if err != nil { + return err + } + return writer.WriteMultiBuffer(mb) +} diff --git a/common/buf/io.go b/common/buf/io.go index c3b9538b..889fdc34 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -16,7 +16,7 @@ var ErrReadTimeout = newError("IO timeout") // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout. type TimeoutReader interface { - ReadTimeout(time.Duration) (MultiBuffer, error) + ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error) } // Writer extends io.Writer with MultiBuffer. diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 5b7e7d4a..5371b060 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -9,8 +9,6 @@ import ( "v2ray.com/core/common/session" "v2ray.com/core/common/task" - "v2ray.com/core/transport/pipe" - "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" @@ -118,16 +116,8 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia } bodyWriter := session.EncodeRequestBody(request, writer) - if tReader, ok := input.(*pipe.Reader); ok { - firstPayload, err := tReader.ReadMultiBufferWithTimeout(time.Millisecond * 500) - if err != nil && err != buf.ErrReadTimeout { - return newError("failed to get first payload").Base(err) - } - if !firstPayload.IsEmpty() { - if err := bodyWriter.WriteMultiBuffer(firstPayload); err != nil { - return newError("failed to write first payload").Base(err) - } - } + if err := buf.CopyOnceTimeout(input, bodyWriter, time.Millisecond*500); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { + return newError("failed to write first payload").Base(err) } if err := writer.SetBuffered(false); err != nil { diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index 328f531b..d58f8d54 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -81,7 +81,7 @@ func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) { } } -func (p *pipe) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, error) { +func (p *pipe) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error) { timer := time.After(d) for { data, err := p.readMultiBufferInternal() diff --git a/transport/pipe/pipe_test.go b/transport/pipe/pipe_test.go index 15d0d345..3a27c27d 100644 --- a/transport/pipe/pipe_test.go +++ b/transport/pipe/pipe_test.go @@ -118,3 +118,10 @@ func TestPipeWriteMultiThread(t *testing.T) { assert(err, IsNil) assert(b[0].Bytes(), Equals, []byte{'a', 'b', 'c', 'd'}) } + +func TestInterfaces(t *testing.T) { + assert := With(t) + + assert((*Reader)(nil), Implements, (*buf.Reader)(nil)) + assert((*Reader)(nil), Implements, (*buf.TimeoutReader)(nil)) +} diff --git a/transport/pipe/reader.go b/transport/pipe/reader.go index d0c97fb7..a06af49b 100644 --- a/transport/pipe/reader.go +++ b/transport/pipe/reader.go @@ -16,9 +16,9 @@ func (r *Reader) ReadMultiBuffer() (buf.MultiBuffer, error) { return r.pipe.ReadMultiBuffer() } -// ReadMultiBufferWithTimeout reads content from a pipe within the given duration, or returns buf.ErrTimeout otherwise. -func (r *Reader) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, error) { - return r.pipe.ReadMultiBufferWithTimeout(d) +// ReadMultiBufferTimeout reads content from a pipe within the given duration, or returns buf.ErrTimeout otherwise. +func (r *Reader) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error) { + return r.pipe.ReadMultiBufferTimeout(d) } // CloseError sets the pipe to error state. Both reading and writing from/to the pipe will return io.ErrClosedPipe.