diff --git a/common/buf/reader.go b/common/buf/reader.go index 6be917e8..f7f05132 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -199,3 +199,11 @@ func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) { } return nBytes, err } + +// Close implements io.Closer. +func (r *BufferedReader) Close() error { + if !r.leftOver.IsEmpty() { + r.leftOver.Release() + } + return common.Close(r.stream) +} diff --git a/common/buf/writer.go b/common/buf/writer.go index 7d1438bb..73d0538c 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -142,6 +142,14 @@ func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) { return sc.Size, err } +// Close implements io.Closable. +func (w *BufferedWriter) Close() error { + if err := w.Flush(); err != nil { + return err + } + return common.Close(w.writer) +} + type seqWriter struct { writer io.Writer } diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index b4b83530..7e4ab49d 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -3,17 +3,17 @@ package http import ( "context" gotls "crypto/tls" - "io" "net/http" "net/url" "sync" "golang.org/x/net/http2" - "v2ray.com/core/common" + "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" + "v2ray.com/core/transport/pipe" ) var ( @@ -83,11 +83,12 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error return nil, err } - preader, pwriter := io.Pipe() + preader, pwriter := pipe.New(pipe.WithSizeLimit(20 * 1024)) + breader := buf.NewBufferedReader(preader) request := &http.Request{ Method: "PUT", Host: httpSettings.getRandomHost(), - Body: preader, + Body: buf.NewBufferedReader(preader), URL: &url.URL{ Scheme: "https", Host: dest.NetAddr(), @@ -105,10 +106,12 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error return nil, newError("unexpected status", response.StatusCode).AtWarning() } + bwriter := buf.NewBufferedWriter(pwriter) + common.Must(bwriter.SetBuffered(false)) return &Connection{ Reader: response.Body, - Writer: pwriter, - Closer: common.NewChainedClosable(preader, pwriter, response.Body), + Writer: bwriter, + Closer: common.NewChainedClosable(breader, bwriter, response.Body), Local: &net.TCPAddr{ IP: []byte{0, 0, 0, 0}, Port: 0, diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go new file mode 100644 index 00000000..f1f8e8f6 --- /dev/null +++ b/transport/pipe/impl.go @@ -0,0 +1,142 @@ +package pipe + +import ( + "io" + "sync" + "time" + + "v2ray.com/core/common/buf" + "v2ray.com/core/common/errors" + "v2ray.com/core/common/signal" +) + +type state byte + +const ( + open state = iota + closed + errord +) + +type pipe struct { + sync.Mutex + data buf.MultiBuffer + readSignal *signal.Notifier + writeSignal *signal.Notifier + limit int32 + state state +} + +func (p *pipe) getState(forRead bool) error { + switch p.state { + case open: + return nil + case closed: + if forRead { + if !p.data.IsEmpty() { + return nil + } + return io.EOF + } + return io.ErrClosedPipe + case errord: + return io.ErrClosedPipe + default: + panic("impossible case") + } +} + +func (p *pipe) readMultiBufferInternal() (buf.MultiBuffer, error) { + p.Lock() + defer p.Unlock() + + if err := p.getState(true); err != nil { + return nil, err + } + + data := p.data + p.data = nil + return data, nil +} + +func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) { + for { + data, err := p.readMultiBufferInternal() + if data != nil || err != nil { + return data, err + } + + <-p.readSignal.Wait() + } +} + +var ErrTimeout = errors.New("Timeout on reading pipeline.") + +func (p *pipe) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, error) { + timer := time.After(d) + for { + data, err := p.readMultiBufferInternal() + if data != nil || err != nil { + p.writeSignal.Signal() + return data, err + } + + select { + case <-p.readSignal.Wait(): + case <-timer: + return nil, ErrTimeout + } + } +} + +func (p *pipe) writeMultiBufferInternal(mb buf.MultiBuffer) error { + p.Lock() + defer p.Unlock() + + if err := p.getState(false); err != nil { + return err + } + + p.data.AppendMulti(mb) + return nil +} + +func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error { + if mb.IsEmpty() { + return nil + } + + for { + if p.limit < 0 || p.data.Len()+mb.Len() <= p.limit { + defer p.readSignal.Signal() + return p.writeMultiBufferInternal(mb) + } + + <-p.writeSignal.Wait() + } +} + +func (p *pipe) Close() error { + p.Lock() + defer p.Unlock() + + p.state = closed + p.readSignal.Signal() + p.writeSignal.Signal() + return nil +} + +func (p *pipe) CloseError() { + p.Lock() + defer p.Unlock() + + p.state = errord + + if !p.data.IsEmpty() { + p.data.Release() + p.data = nil + } + + p.readSignal.Signal() + p.writeSignal.Signal() +} diff --git a/transport/pipe/pipe.go b/transport/pipe/pipe.go new file mode 100644 index 00000000..e3dc238d --- /dev/null +++ b/transport/pipe/pipe.go @@ -0,0 +1,49 @@ +package pipe + +import ( + "v2ray.com/core/common/platform" + "v2ray.com/core/common/signal" +) + +type Option func(*pipe) + +func WithoutSizeLimit() Option { + return func(p *pipe) { + p.limit = -1 + } +} + +func WithSizeLimit(limit int32) Option { + return func(p *pipe) { + p.limit = limit + } +} + +func New(opts ...Option) (*Reader, *Writer) { + p := &pipe{ + limit: defaultLimit, + readSignal: signal.NewNotifier(), + writeSignal: signal.NewNotifier(), + } + + for _, opt := range opts { + opt(p) + } + + return &Reader{ + pipe: p, + }, &Writer{ + pipe: p, + } +} + +var defaultLimit int32 = 10 * 1024 * 1024 + +func init() { + const raySizeEnvKey = "v2ray.ray.buffer.size" + size := platform.EnvFlag{ + Name: raySizeEnvKey, + AltName: platform.NormalizeEnvName(raySizeEnvKey), + }.GetValueAsInt(10) + defaultLimit = int32(size) * 1024 * 1024 +} diff --git a/transport/pipe/reader.go b/transport/pipe/reader.go new file mode 100644 index 00000000..369baeb9 --- /dev/null +++ b/transport/pipe/reader.go @@ -0,0 +1,23 @@ +package pipe + +import ( + "time" + + "v2ray.com/core/common/buf" +) + +type Reader struct { + pipe *pipe +} + +func (r *Reader) ReadMultiBuffer() (buf.MultiBuffer, error) { + return r.pipe.ReadMultiBuffer() +} + +func (r *Reader) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, error) { + return r.pipe.ReadMultiBufferWithTimeout(d) +} + +func (r *Reader) CloseError() { + r.pipe.CloseError() +} diff --git a/transport/pipe/writer.go b/transport/pipe/writer.go new file mode 100644 index 00000000..c9b2838f --- /dev/null +++ b/transport/pipe/writer.go @@ -0,0 +1,21 @@ +package pipe + +import ( + "v2ray.com/core/common/buf" +) + +type Writer struct { + pipe *pipe +} + +func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error { + return w.pipe.WriteMultiBuffer(mb) +} + +func (w *Writer) Close() error { + return w.pipe.Close() +} + +func (w *Writer) CloseError() { + w.pipe.CloseError() +}