cleanup buffer usage

pull/700/head v2.47
Darien Raymond 2017-11-09 22:33:15 +01:00
parent 6e61538b36
commit 594ec15c09
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
41 changed files with 358 additions and 529 deletions

View File

@ -161,7 +161,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
log.Trace(newError("dispatching request to ", dest)) log.Trace(newError("dispatching request to ", dest))
data, _ := s.input.ReadTimeout(time.Millisecond * 500) data, _ := s.input.ReadTimeout(time.Millisecond * 500)
if err := writer.Write(data); err != nil { if err := writer.WriteMultiBuffer(data); err != nil {
log.Trace(newError("failed to write first payload").Base(err)) log.Trace(newError("failed to write first payload").Base(err))
return return
} }
@ -234,7 +234,7 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) fetchOutput() { func (m *Client) fetchOutput() {
defer m.cancel() defer m.cancel()
reader := buf.ToBytesReader(m.inboundRay.InboundOutput()) reader := buf.NewBufferedReader(m.inboundRay.InboundOutput())
for { for {
meta, err := ReadMetadata(reader) meta, err := ReadMetadata(reader)
@ -396,7 +396,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error
func (w *ServerWorker) run(ctx context.Context) { func (w *ServerWorker) run(ctx context.Context) {
input := w.outboundRay.OutboundInput() input := w.outboundRay.OutboundInput()
reader := buf.ToBytesReader(input) reader := buf.NewBufferedReader(input)
defer w.sessionManager.Close() defer w.sessionManager.Close()

View File

@ -16,7 +16,7 @@ import (
func readAll(reader buf.Reader) (buf.MultiBuffer, error) { func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
var mb buf.MultiBuffer var mb buf.MultiBuffer
for { for {
b, err := reader.Read() b, err := reader.ReadMultiBuffer()
if err == io.EOF { if err == io.EOF {
break break
} }
@ -45,7 +45,7 @@ func TestReaderWriter(t *testing.T) {
writePayload := func(writer *Writer, payload ...byte) error { writePayload := func(writer *Writer, payload ...byte) error {
b := buf.New() b := buf.New()
b.Append(payload) b.Append(payload)
return writer.Write(buf.NewMultiBufferValue(b)) return writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
} }
assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil) assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil)
@ -60,7 +60,7 @@ func TestReaderWriter(t *testing.T) {
assert(writePayload(writer2, 'y'), IsNil) assert(writePayload(writer2, 'y'), IsNil)
writer2.Close() writer2.Close()
bytesReader := buf.ToBytesReader(stream) bytesReader := buf.NewBufferedReader(stream)
streamReader := NewStreamReader(bytesReader) streamReader := NewStreamReader(bytesReader)
meta, err := ReadMetadata(bytesReader) meta, err := ReadMetadata(bytesReader)

View File

@ -40,8 +40,8 @@ func NewPacketReader(reader io.Reader) *PacketReader {
} }
} }
// Read implements buf.Reader. // ReadMultiBuffer implements buf.Reader.
func (r *PacketReader) Read() (buf.MultiBuffer, error) { func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.eof { if r.eof {
return nil, io.EOF return nil, io.EOF
} }
@ -79,8 +79,8 @@ func NewStreamReader(reader io.Reader) *StreamReader {
} }
} }
// Read implmenets buf.Reader. // ReadMultiBuffer implmenets buf.Reader.
func (r *StreamReader) Read() (buf.MultiBuffer, error) { func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.leftOver == 0 { if r.leftOver == 0 {
r.leftOver = -1 r.leftOver = -1
return nil, io.EOF return nil, io.EOF

View File

@ -56,7 +56,7 @@ func (w *Writer) writeMetaOnly() error {
if err := b.Reset(meta.AsSupplier()); err != nil { if err := b.Reset(meta.AsSupplier()); err != nil {
return err return err
} }
return w.writer.Write(buf.NewMultiBufferValue(b)) return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
} }
func (w *Writer) writeData(mb buf.MultiBuffer) error { func (w *Writer) writeData(mb buf.MultiBuffer) error {
@ -74,11 +74,11 @@ func (w *Writer) writeData(mb buf.MultiBuffer) error {
mb2 := buf.NewMultiBufferCap(len(mb) + 1) mb2 := buf.NewMultiBufferCap(len(mb) + 1)
mb2.Append(frame) mb2.Append(frame)
mb2.AppendMulti(mb) mb2.AppendMulti(mb)
return w.writer.Write(mb2) return w.writer.WriteMultiBuffer(mb2)
} }
// Write implements buf.MultiBufferWriter. // WriteMultiBuffer implements buf.Writer.
func (w *Writer) Write(mb buf.MultiBuffer) error { func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
if mb.IsEmpty() { if mb.IsEmpty() {
@ -109,5 +109,5 @@ func (w *Writer) Close() {
frame := buf.New() frame := buf.New()
common.Must(frame.Reset(meta.AsSupplier())) common.Must(frame.Reset(meta.AsSupplier()))
w.writer.Write(buf.NewMultiBufferValue(frame)) w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame))
} }

View File

@ -123,8 +123,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
} }
var ( var (
_ buf.MultiBufferReader = (*Connection)(nil) _ buf.Reader = (*Connection)(nil)
_ buf.MultiBufferWriter = (*Connection)(nil) _ buf.Writer = (*Connection)(nil)
) )
type Connection struct { type Connection struct {
@ -133,9 +133,8 @@ type Connection struct {
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
bytesReader io.Reader reader *buf.BufferedReader
reader buf.Reader writer buf.Writer
writer buf.Writer
} }
func NewConnection(stream ray.Ray) *Connection { func NewConnection(stream ray.Ray) *Connection {
@ -149,9 +148,8 @@ func NewConnection(stream ray.Ray) *Connection {
IP: []byte{0, 0, 0, 0}, IP: []byte{0, 0, 0, 0},
Port: 0, Port: 0,
}, },
bytesReader: buf.ToBytesReader(stream.InboundOutput()), reader: buf.NewBufferedReader(stream.InboundOutput()),
reader: stream.InboundOutput(), writer: stream.InboundInput(),
writer: stream.InboundInput(),
} }
} }
@ -160,11 +158,11 @@ func (v *Connection) Read(b []byte) (int, error) {
if v.closed { if v.closed {
return 0, io.EOF return 0, io.EOF
} }
return v.bytesReader.Read(b) return v.reader.Read(b)
} }
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
return v.reader.Read() return v.reader.ReadMultiBuffer()
} }
// Write implements net.Conn.Write(). // Write implements net.Conn.Write().
@ -172,14 +170,19 @@ func (v *Connection) Write(b []byte) (int, error) {
if v.closed { if v.closed {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
return buf.ToBytesWriter(v.writer).Write(b)
l := len(b)
mb := buf.NewMultiBufferCap(l/buf.Size + 1)
mb.Write(b)
return l, v.writer.WriteMultiBuffer(mb)
} }
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if v.closed { if v.closed {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
return v.writer.Write(mb)
return v.writer.WriteMultiBuffer(mb)
} }
// Close implements net.Conn.Close(). // Close implements net.Conn.Close().

View File

@ -1,53 +0,0 @@
package buf
import (
"io"
)
// BufferedReader is a reader with internal cache.
type BufferedReader struct {
reader io.Reader
buffer *Buffer
buffered bool
}
// NewBufferedReader creates a new BufferedReader based on an io.Reader.
func NewBufferedReader(rawReader io.Reader) *BufferedReader {
return &BufferedReader{
reader: rawReader,
buffer: NewLocal(1024),
buffered: true,
}
}
// IsBuffered returns true if the internal cache is effective.
func (r *BufferedReader) IsBuffered() bool {
return r.buffered
}
// SetBuffered is to enable or disable internal cache. If cache is disabled,
// Read() calls will be delegated to the underlying io.Reader directly.
func (r *BufferedReader) SetBuffered(cached bool) {
r.buffered = cached
}
// Read implements io.Reader.Read().
func (r *BufferedReader) Read(b []byte) (int, error) {
if !r.buffered || r.buffer == nil {
if !r.buffer.IsEmpty() {
return r.buffer.Read(b)
}
return r.reader.Read(b)
}
if r.buffer.IsEmpty() {
if err := r.buffer.Reset(ReadFrom(r.reader)); err != nil {
return 0, err
}
}
if r.buffer.IsEmpty() {
return 0, nil
}
return r.buffer.Read(b)
}

View File

@ -1,36 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
. "v2ray.com/core/common/buf"
. "v2ray.com/ext/assert"
)
func TestBufferedReader(t *testing.T) {
assert := With(t)
content := New()
assert(content.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
len := content.Len()
reader := NewBufferedReader(content)
assert(reader.IsBuffered(), IsTrue)
payload := make([]byte, 16)
nBytes, err := reader.Read(payload)
assert(nBytes, Equals, 16)
assert(err, IsNil)
len2 := content.Len()
assert(len-len2, GreaterThan, 16)
nBytes, err = reader.Read(payload)
assert(nBytes, Equals, 16)
assert(err, IsNil)
assert(content.Len(), Equals, len2)
}

View File

@ -1,73 +0,0 @@
package buf
import "io"
// BufferedWriter is an io.Writer with internal buffer. It writes to underlying writer when buffer is full or on demand.
// This type is not thread safe.
type BufferedWriter struct {
writer io.Writer
buffer *Buffer
buffered bool
}
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer io.Writer) *BufferedWriter {
return NewBufferedWriterSize(writer, 1024)
}
// NewBufferedWriterSize creates a BufferedWriter with specified buffer size.
func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: NewLocal(int(size)),
buffered: true,
}
}
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered || w.buffer == nil {
return w.writer.Write(b)
}
bytesWritten := 0
for bytesWritten < len(b) {
nBytes, err := w.buffer.Write(b[bytesWritten:])
if err != nil {
return bytesWritten, err
}
bytesWritten += nBytes
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return bytesWritten, err
}
}
}
return bytesWritten, nil
}
// Flush writes all buffered content into underlying writer, if any.
func (w *BufferedWriter) Flush() error {
defer w.buffer.Clear()
for !w.buffer.IsEmpty() {
nBytes, err := w.writer.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.SliceFrom(nBytes)
}
return nil
}
// IsBuffered returns true if this BufferedWriter holds a buffer.
func (w *BufferedWriter) IsBuffered() bool {
return w.buffered
}
// SetBuffered controls whether the BufferedWriter holds a buffer for writing. If not buffered, any write() calls into underlying writer directly.
func (w *BufferedWriter) SetBuffered(cached bool) error {
w.buffered = cached
if !cached && !w.buffer.IsEmpty() {
return w.Flush()
}
return nil
}

View File

@ -1,54 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
"v2ray.com/core/common"
. "v2ray.com/core/common/buf"
. "v2ray.com/ext/assert"
)
func TestBufferedWriter(t *testing.T) {
assert := With(t)
content := New()
writer := NewBufferedWriter(content)
assert(writer.IsBuffered(), IsTrue)
payload := make([]byte, 16)
nBytes, err := writer.Write(payload)
assert(nBytes, Equals, 16)
assert(err, IsNil)
assert(content.IsEmpty(), IsTrue)
assert(writer.SetBuffered(false), IsNil)
assert(content.Len(), Equals, 16)
}
func TestBufferedWriterLargePayload(t *testing.T) {
assert := With(t)
content := NewLocal(128 * 1024)
writer := NewBufferedWriter(content)
assert(writer.IsBuffered(), IsTrue)
payload := make([]byte, 64*1024)
common.Must2(rand.Read(payload))
nBytes, err := writer.Write(payload[:512])
assert(nBytes, Equals, 512)
assert(err, IsNil)
assert(content.IsEmpty(), IsTrue)
nBytes, err = writer.Write(payload[512:])
assert(err, IsNil)
assert(writer.Flush(), IsNil)
assert(nBytes, Equals, 64*1024-512)
assert(content.Bytes(), Equals, payload)
}

View File

@ -17,7 +17,7 @@ type copyHandler struct {
} }
func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) { func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
mb, err := reader.Read() mb, err := reader.ReadMultiBuffer()
if err != nil { if err != nil {
for _, handler := range h.onReadError { for _, handler := range h.onReadError {
err = handler(err) err = handler(err)
@ -27,7 +27,7 @@ func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
} }
func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error { func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
err := writer.Write(mb) err := writer.WriteMultiBuffer(mb)
if err != nil { if err != nil {
for _, handler := range h.onWriteError { for _, handler := range h.onWriteError {
err = handler(err) err = handler(err)
@ -36,6 +36,10 @@ func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
return err return err
} }
type SizeCounter struct {
Size int64
}
type CopyOption func(*copyHandler) type CopyOption func(*copyHandler)
func IgnoreReaderError() CopyOption { func IgnoreReaderError() CopyOption {
@ -62,6 +66,14 @@ func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
} }
} }
func CountSize(sc *SizeCounter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
sc.Size += int64(b.Len())
})
}
}
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error { func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
for { for {
buffer, err := handler.readFrom(reader) buffer, err := handler.readFrom(reader)

View File

@ -5,10 +5,10 @@ import (
"time" "time"
) )
// Reader extends io.Reader with alloc.Buffer. // Reader extends io.Reader with MultiBuffer.
type Reader interface { type Reader interface {
// Read reads content from underlying reader, and put it into an alloc.Buffer. // ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
Read() (MultiBuffer, error) ReadMultiBuffer() (MultiBuffer, error)
} }
// ErrReadTimeout is an error that happens with IO timeout. // ErrReadTimeout is an error that happens with IO timeout.
@ -19,10 +19,10 @@ type TimeoutReader interface {
ReadTimeout(time.Duration) (MultiBuffer, error) ReadTimeout(time.Duration) (MultiBuffer, error)
} }
// Writer extends io.Writer with alloc.Buffer. // Writer extends io.Writer with MultiBuffer.
type Writer interface { type Writer interface {
// Write writes an alloc.Buffer into underlying writer. // WriteMultiBuffer writes a MultiBuffer into underlying writer.
Write(MultiBuffer) error WriteMultiBuffer(MultiBuffer) error
} }
// ReadFrom creates a Supplier to read from a given io.Reader. // ReadFrom creates a Supplier to read from a given io.Reader.
@ -49,45 +49,21 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
// NewReader creates a new Reader. // NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader. // The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader { func NewReader(reader io.Reader) Reader {
if mr, ok := reader.(MultiBufferReader); ok { if mr, ok := reader.(Reader); ok {
return &readerAdpater{ return mr
MultiBufferReader: mr,
}
} }
return &BytesToBufferReader{ return NewBytesToBufferReader(reader)
reader: reader,
}
}
// ToBytesReader converts a Reaaer to io.Reader.
func ToBytesReader(stream Reader) io.Reader {
return &bufferToBytesReader{
stream: stream,
}
} }
// NewWriter creates a new Writer. // NewWriter creates a new Writer.
func NewWriter(writer io.Writer) Writer { func NewWriter(writer io.Writer) Writer {
if mw, ok := writer.(MultiBufferWriter); ok { if mw, ok := writer.(Writer); ok {
return &writerAdapter{ return mw
writer: mw,
}
} }
return &BufferToBytesWriter{ return &BufferToBytesWriter{
writer: writer, Writer: writer,
}
}
func NewMergingWriter(writer io.Writer) Writer {
return NewMergingWriterSize(writer, 4096)
}
func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
return &mergingWriter{
writer: writer,
buffer: make([]byte, size),
} }
} }
@ -96,10 +72,3 @@ func NewSequentialWriter(writer io.Writer) Writer {
writer: writer, writer: writer,
} }
} }
// ToBytesWriter converts a Writer to io.Writer
func ToBytesWriter(writer Writer) io.Writer {
return &bytesToBufferWriter{
writer: writer,
}
}

View File

@ -8,16 +8,6 @@ import (
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
) )
// MultiBufferWriter is a writer that writes MultiBuffer.
type MultiBufferWriter interface {
WriteMultiBuffer(MultiBuffer) error
}
// MultiBufferReader is a reader that reader payload as MultiBuffer.
type MultiBufferReader interface {
ReadMultiBuffer() (MultiBuffer, error)
}
// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF. // ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF.
func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) { func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
mb := NewMultiBufferCap(128) mb := NewMultiBufferCap(128)

View File

@ -8,19 +8,19 @@ import (
// BytesToBufferReader is a Reader that adjusts its reading speed automatically. // BytesToBufferReader is a Reader that adjusts its reading speed automatically.
type BytesToBufferReader struct { type BytesToBufferReader struct {
reader io.Reader io.Reader
buffer []byte buffer []byte
} }
func NewBytesToBufferReader(reader io.Reader) Reader { func NewBytesToBufferReader(reader io.Reader) Reader {
return &BytesToBufferReader{ return &BytesToBufferReader{
reader: reader, Reader: reader,
} }
} }
func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) { func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
b := New() b := New()
if err := b.Reset(ReadFrom(r.reader)); err != nil { if err := b.Reset(ReadFrom(r.Reader)); err != nil {
b.Release() b.Release()
return nil, err return nil, err
} }
@ -30,13 +30,13 @@ func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
return NewMultiBufferValue(b), nil return NewMultiBufferValue(b), nil
} }
// Read implements Reader.Read(). // ReadMultiBuffer implements Reader.
func (r *BytesToBufferReader) Read() (MultiBuffer, error) { func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.buffer == nil { if r.buffer == nil {
return r.readSmall() return r.readSmall()
} }
nBytes, err := r.reader.Read(r.buffer) nBytes, err := r.Reader.Read(r.buffer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -46,20 +46,33 @@ func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
return mb, nil return mb, nil
} }
type readerAdpater struct { type BufferedReader struct {
MultiBufferReader stream Reader
legacyReader io.Reader
leftOver MultiBuffer
buffered bool
} }
func (r *readerAdpater) Read() (MultiBuffer, error) { func NewBufferedReader(reader Reader) *BufferedReader {
return r.ReadMultiBuffer() r := &BufferedReader{
stream: reader,
buffered: true,
}
if lr, ok := reader.(io.Reader); ok {
r.legacyReader = lr
}
return r
} }
type bufferToBytesReader struct { func (r *BufferedReader) SetBuffered(f bool) {
stream Reader r.buffered = f
leftOver MultiBuffer
} }
func (r *bufferToBytesReader) Read(b []byte) (int, error) { func (r *BufferedReader) IsBuffered() bool {
return r.buffered
}
func (r *BufferedReader) Read(b []byte) (int, error) {
if r.leftOver != nil { if r.leftOver != nil {
nBytes, _ := r.leftOver.Read(b) nBytes, _ := r.leftOver.Read(b)
if r.leftOver.IsEmpty() { if r.leftOver.IsEmpty() {
@ -69,7 +82,11 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
return nBytes, nil return nBytes, nil
} }
mb, err := r.stream.Read() if !r.buffered && r.legacyReader != nil {
return r.legacyReader.Read(b)
}
mb, err := r.stream.ReadMultiBuffer()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -81,39 +98,39 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
return nBytes, nil return nBytes, nil
} }
func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) { func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.leftOver != nil { if r.leftOver != nil {
mb := r.leftOver mb := r.leftOver
r.leftOver = nil r.leftOver = nil
return mb, nil return mb, nil
} }
return r.stream.Read() return r.stream.ReadMultiBuffer()
} }
func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) { func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
mbWriter := NewWriter(writer) mbWriter := NewWriter(writer)
totalBytes := int64(0) totalBytes := int64(0)
if r.leftOver != nil { if r.leftOver != nil {
totalBytes += int64(r.leftOver.Len()) totalBytes += int64(r.leftOver.Len())
if err := mbWriter.Write(r.leftOver); err != nil { if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil {
return 0, err return 0, err
} }
} }
for { for {
mb, err := r.stream.Read() mb, err := r.stream.ReadMultiBuffer()
if err != nil { if err != nil {
return totalBytes, err return totalBytes, err
} }
totalBytes += int64(mb.Len()) totalBytes += int64(mb.Len())
if err := mbWriter.Write(mb); err != nil { if err := mbWriter.WriteMultiBuffer(mb); err != nil {
return totalBytes, err return totalBytes, err
} }
} }
} }
func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) { func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
nBytes, err := r.writeToInternal(writer) nBytes, err := r.writeToInternal(writer)
if errors.Cause(err) == io.EOF { if errors.Cause(err) == io.EOF {
return nBytes, nil return nBytes, nil

View File

@ -15,11 +15,11 @@ func TestAdaptiveReader(t *testing.T) {
assert := With(t) assert := With(t)
reader := NewReader(bytes.NewReader(make([]byte, 1024*1024))) reader := NewReader(bytes.NewReader(make([]byte, 1024*1024)))
b, err := reader.Read() b, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(b.Len(), Equals, 2*1024) assert(b.Len(), Equals, 2*1024)
b, err = reader.Read() b, err = reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(b.Len(), Equals, 32*1024) assert(b.Len(), Equals, 32*1024)
} }
@ -28,22 +28,23 @@ func TestBytesReaderWriteTo(t *testing.T) {
assert := With(t) assert := With(t)
stream := ray.NewStream(context.Background()) stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream) reader := NewBufferedReader(stream)
b1 := New() b1 := New()
b1.AppendBytes('a', 'b', 'c') b1.AppendBytes('a', 'b', 'c')
b2 := New() b2 := New()
b2.AppendBytes('e', 'f', 'g') b2.AppendBytes('e', 'f', 'g')
assert(stream.Write(NewMultiBufferValue(b1, b2)), IsNil) assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close() stream.Close()
stream2 := ray.NewStream(context.Background()) stream2 := ray.NewStream(context.Background())
writer := ToBytesWriter(stream2) writer := NewBufferedWriter(stream2)
writer.SetBuffered(false)
nBytes, err := io.Copy(writer, reader) nBytes, err := io.Copy(writer, reader)
assert(err, IsNil) assert(err, IsNil)
assert(nBytes, Equals, int64(6)) assert(nBytes, Equals, int64(6))
mb, err := stream2.Read() mb, err := stream2.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(len(mb), Equals, 2) assert(len(mb), Equals, 2)
assert(mb[0].String(), Equals, "abc") assert(mb[0].String(), Equals, "abc")
@ -54,16 +55,16 @@ func TestBytesReaderMultiBuffer(t *testing.T) {
assert := With(t) assert := With(t)
stream := ray.NewStream(context.Background()) stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream) reader := NewBufferedReader(stream)
b1 := New() b1 := New()
b1.AppendBytes('a', 'b', 'c') b1.AppendBytes('a', 'b', 'c')
b2 := New() b2 := New()
b2.AppendBytes('e', 'f', 'g') b2.AppendBytes('e', 'f', 'g')
assert(stream.Write(NewMultiBufferValue(b1, b2)), IsNil) assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close() stream.Close()
mbReader := NewReader(reader) mbReader := NewReader(reader)
mb, err := mbReader.Read() mb, err := mbReader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(len(mb), Equals, 2) assert(len(mb), Equals, 2)
assert(mb[0].String(), Equals, "abc") assert(mb[0].String(), Equals, "abc")

View File

@ -8,49 +8,142 @@ import (
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer. // BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
type BufferToBytesWriter struct { type BufferToBytesWriter struct {
writer io.Writer io.Writer
} }
// Write implements Writer.Write(). Write() takes ownership of the given buffer. func NewBufferToBytesWriter(writer io.Writer) *BufferToBytesWriter {
func (w *BufferToBytesWriter) Write(mb MultiBuffer) error { return &BufferToBytesWriter{
Writer: writer,
}
}
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release() defer mb.Release()
bs := mb.ToNetBuffers() bs := mb.ToNetBuffers()
_, err := bs.WriteTo(w.writer) _, err := bs.WriteTo(w)
return err return err
} }
type writerAdapter struct { func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
writer MultiBufferWriter if readerFrom, ok := w.Writer.(io.ReaderFrom); ok {
return readerFrom.ReadFrom(reader)
}
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
} }
// Write implements buf.MultiBufferWriter. type BufferedWriter struct {
func (w *writerAdapter) Write(mb MultiBuffer) error { writer Writer
return w.writer.WriteMultiBuffer(mb) legacyWriter io.Writer
buffer *Buffer
buffered bool
} }
type mergingWriter struct { func NewBufferedWriter(writer Writer) *BufferedWriter {
writer io.Writer w := &BufferedWriter{
buffer []byte writer: writer,
buffer: New(),
buffered: true,
}
if lw, ok := writer.(io.Writer); ok {
w.legacyWriter = lw
}
return w
} }
func (w *mergingWriter) Write(mb MultiBuffer) error { func (w *BufferedWriter) Write(b []byte) (int, error) {
defer mb.Release() if !w.buffered && w.legacyWriter != nil {
return w.legacyWriter.Write(b)
}
for !mb.IsEmpty() { totalBytes := 0
nBytes, _ := mb.Read(w.buffer) for len(b) > 0 {
if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil { nBytes, err := w.buffer.Write(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return totalBytes, err
}
}
b = b[nBytes:]
}
return totalBytes, nil
}
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if !w.buffered {
return w.writer.WriteMultiBuffer(b)
}
defer b.Release()
for !b.IsEmpty() {
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
return err return err
} }
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return err
}
}
}
return nil
}
func (w *BufferedWriter) Flush() error {
if !w.buffer.IsEmpty() {
if err := w.writer.WriteMultiBuffer(NewMultiBufferValue(w.buffer)); err != nil {
return err
}
if w.buffered {
w.buffer = New()
} else {
w.buffer = nil
}
} }
return nil return nil
} }
func (w *BufferedWriter) SetBuffered(f bool) error {
w.buffered = f
if !f {
return w.Flush()
}
return nil
}
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
var sc SizeCounter
if !w.buffer.IsEmpty() {
sc.Size += int64(w.buffer.Len())
if err := w.Flush(); err != nil {
return sc.Size, err
}
}
if readerFrom, ok := w.writer.(io.ReaderFrom); ok {
return readerFrom.ReadFrom(reader)
}
w.buffered = false
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
type seqWriter struct { type seqWriter struct {
writer io.Writer writer io.Writer
} }
func (w *seqWriter) Write(mb MultiBuffer) error { func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release() defer mb.Release()
for _, b := range mb { for _, b := range mb {
@ -65,49 +158,9 @@ func (w *seqWriter) Write(mb MultiBuffer) error {
return nil return nil
} }
var (
_ MultiBufferWriter = (*bytesToBufferWriter)(nil)
)
type bytesToBufferWriter struct {
writer Writer
}
// Write implements io.Writer.
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
mb := NewMultiBufferCap(len(payload)/Size + 1)
mb.Write(payload)
if err := w.writer.Write(mb); err != nil {
return 0, err
}
return len(payload), nil
}
func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) error {
return w.writer.Write(mb)
}
func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) {
mbReader := NewReader(reader)
totalBytes := int64(0)
for {
mb, err := mbReader.Read()
if errors.Cause(err) == io.EOF {
break
} else if err != nil {
return totalBytes, err
}
totalBytes += int64(mb.Len())
if err := w.writer.Write(mb); err != nil {
return totalBytes, err
}
}
return totalBytes, nil
}
type noOpWriter struct{} type noOpWriter struct{}
func (noOpWriter) Write(b MultiBuffer) error { func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
b.Release() b.Release()
return nil return nil
} }

View File

@ -25,9 +25,11 @@ func TestWriter(t *testing.T) {
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024)) writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
writer := NewWriter(NewBufferedWriter(writeBuffer)) writer := NewBufferedWriter(NewWriter(writeBuffer))
err := writer.Write(NewMultiBufferValue(lb)) writer.SetBuffered(false)
err := writer.WriteMultiBuffer(NewMultiBufferValue(lb))
assert(err, IsNil) assert(err, IsNil)
assert(writer.Flush(), IsNil)
assert(expectedBytes, Equals, writeBuffer.Bytes()) assert(expectedBytes, Equals, writeBuffer.Bytes())
} }
@ -36,20 +38,21 @@ func TestBytesWriterReadFrom(t *testing.T) {
cache := ray.NewStream(context.Background()) cache := ray.NewStream(context.Background())
reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192)) reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192))
_, err := reader.WriteTo(ToBytesWriter(cache)) writer := NewBufferedWriter(cache)
writer.SetBuffered(false)
_, err := reader.WriteTo(writer)
assert(err, IsNil) assert(err, IsNil)
mb, err := cache.Read() mb, err := cache.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(mb.Len(), Equals, 8192) assert(mb.Len(), Equals, 8192)
assert(len(mb), Equals, 4)
} }
func TestDiscardBytes(t *testing.T) { func TestDiscardBytes(t *testing.T) {
assert := With(t) assert := With(t)
b := New() b := New()
common.Must(b.Reset(ReadFrom(rand.Reader))) common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
nBytes, err := io.Copy(DiscardBytes, b) nBytes, err := io.Copy(DiscardBytes, b)
assert(nBytes, Equals, int64(Size)) assert(nBytes, Equals, int64(Size))
@ -64,7 +67,7 @@ func TestDiscardBytesMultiBuffer(t *testing.T) {
common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size))) common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
r := NewReader(buffer) r := NewReader(buffer)
nBytes, err := io.Copy(DiscardBytes, ToBytesReader(r)) nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r))
assert(nBytes, Equals, int64(size)) assert(nBytes, Equals, int64(size))
assert(err, IsNil) assert(err, IsNil)
} }

View File

@ -151,7 +151,7 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
return b, nil return b, nil
} }
func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b, err := r.readChunk(true) b, err := r.readChunk(true)
if err != nil { if err != nil {
return nil, err return nil, err
@ -193,81 +193,97 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
return mb, nil return mb, nil
} }
const (
WriteSize = 1024
)
type AuthenticationWriter struct { type AuthenticationWriter struct {
auth Authenticator auth Authenticator
buffer []byte writer buf.Writer
payload []byte
writer *buf.BufferedWriter
sizeParser ChunkSizeEncoder sizeParser ChunkSizeEncoder
transferType protocol.TransferType transferType protocol.TransferType
} }
func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter { func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
const payloadSize = 1024
return &AuthenticationWriter{ return &AuthenticationWriter{
auth: auth, auth: auth,
buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()), writer: buf.NewWriter(writer),
payload: make([]byte, payloadSize),
writer: buf.NewBufferedWriterSize(writer, readerBufferSize),
sizeParser: sizeParser, sizeParser: sizeParser,
transferType: transferType, transferType: transferType,
} }
} }
func (w *AuthenticationWriter) append(b []byte) error { func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
encryptedSize := len(b) + w.auth.Overhead() encryptedSize := b.Len() + w.auth.Overhead()
buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0])
buffer, err := w.auth.Seal(buffer, b) eb := buf.New()
if err != nil { common.Must(eb.Reset(func(bb []byte) (int, error) {
return err w.sizeParser.Encode(uint16(encryptedSize), bb[:0])
return w.sizeParser.SizeBytes(), nil
}))
if err := eb.AppendSupplier(func(bb []byte) (int, error) {
_, err := w.auth.Seal(bb[:0], b.Bytes())
return encryptedSize, err
}); err != nil {
eb.Release()
return nil, err
} }
if _, err := w.writer.Write(buffer); err != nil { return eb, nil
return err
}
return nil
} }
func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
mb2Write := buf.NewMultiBufferCap(len(mb) + 10)
for { for {
n, _ := mb.Read(w.payload) b := buf.New()
if err := w.append(w.payload[:n]); err != nil { common.Must(b.Reset(func(bb []byte) (int, error) {
return mb.Read(bb[:WriteSize])
}))
eb, err := w.seal(b)
b.Release()
if err != nil {
mb2Write.Release()
return err return err
} }
mb2Write.Append(eb)
if mb.IsEmpty() { if mb.IsEmpty() {
break break
} }
} }
return w.writer.Flush() return w.writer.WriteMultiBuffer(mb2Write)
} }
func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error { func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
mb2Write := buf.NewMultiBufferCap(len(mb) * 2)
for { for {
b := mb.SplitFirst() b := mb.SplitFirst()
if b == nil { if b == nil {
b = buf.New() b = buf.New()
} }
if err := w.append(b.Bytes()); err != nil { eb, err := w.seal(b)
b.Release() b.Release()
if err != nil {
mb2Write.Release()
return err return err
} }
b.Release() mb2Write.Append(eb)
if mb.IsEmpty() { if mb.IsEmpty() {
break break
} }
} }
return w.writer.Flush() return w.writer.WriteMultiBuffer(mb2Write)
} }
func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.transferType == protocol.TransferTypeStream { if w.transferType == protocol.TransferTypeStream {
return w.writeStream(mb) return w.writeStream(mb)
} }

View File

@ -42,9 +42,9 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{}, AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream) }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
assert(writer.Write(buf.NewMultiBufferValue(payload)), IsNil) assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
assert(cache.Len(), Equals, 83360) assert(cache.Len(), Equals, 83360)
assert(writer.Write(buf.MultiBuffer{}), IsNil) assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(err, IsNil) assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{ reader := NewAuthenticationReader(&AEADAuthenticator{
@ -58,7 +58,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
var mb buf.MultiBuffer var mb buf.MultiBuffer
for mb.Len() < len(rawPayload) { for mb.Len() < len(rawPayload) {
mb2, err := reader.Read() mb2, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
mb.AppendMulti(mb2) mb.AppendMulti(mb2)
@ -68,7 +68,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
mb.Read(mbContent) mb.Read(mbContent)
assert(mbContent, Equals, rawPayload) assert(mbContent, Equals, rawPayload)
_, err = reader.Read() _, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF) assert(err, Equals, io.EOF)
} }
@ -104,9 +104,9 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
pb2.Append([]byte("efgh")) pb2.Append([]byte("efgh"))
payload.Append(pb2) payload.Append(pb2)
assert(writer.Write(payload), IsNil) assert(writer.WriteMultiBuffer(payload), IsNil)
assert(cache.Len(), GreaterThan, 0) assert(cache.Len(), GreaterThan, 0)
assert(writer.Write(buf.MultiBuffer{}), IsNil) assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(err, IsNil) assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{ reader := NewAuthenticationReader(&AEADAuthenticator{
@ -117,7 +117,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{}, AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket) }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
mb, err := reader.Read() mb, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
b1 := mb.SplitFirst() b1 := mb.SplitFirst()
@ -126,6 +126,6 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
assert(b2.String(), Equals, "efgh") assert(b2.String(), Equals, "efgh")
assert(mb.IsEmpty(), IsTrue) assert(mb.IsEmpty(), IsTrue)
_, err = reader.Read() _, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF) assert(err, Equals, io.EOF)
} }

View File

@ -48,7 +48,7 @@ func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *Chunk
sizeDecoder: sizeDecoder, sizeDecoder: sizeDecoder,
reader: buf.NewReader(reader), reader: buf.NewReader(reader),
buffer: make([]byte, sizeDecoder.SizeBytes()), buffer: make([]byte, sizeDecoder.SizeBytes()),
leftOver: buf.NewMultiBufferCap(16), leftOver: buf.NewMultiBufferCap(16),
} }
} }
@ -56,7 +56,7 @@ func (r *ChunkStreamReader) readAtLeast(size int) error {
mb := r.leftOver mb := r.leftOver
r.leftOver = nil r.leftOver = nil
for mb.Len() < size { for mb.Len() < size {
extra, err := r.reader.Read() extra, err := r.reader.ReadMultiBuffer()
if err != nil { if err != nil {
mb.Release() mb.Release()
return err return err
@ -78,7 +78,7 @@ func (r *ChunkStreamReader) readSize() (uint16, error) {
return r.sizeDecoder.Decode(r.buffer) return r.sizeDecoder.Decode(r.buffer)
} }
func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) { func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
size := r.leftOverSize size := r.leftOverSize
if size == 0 { if size == 0 {
nextSize, err := r.readSize() nextSize, err := r.readSize()
@ -129,10 +129,10 @@ func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *Chunk
} }
} }
func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error { func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
const sliceSize = 8192 const sliceSize = 8192
mbLen := mb.Len() mbLen := mb.Len()
mb2Write := buf.NewMultiBufferCap(mbLen / buf.Size + mbLen / sliceSize + 2) mb2Write := buf.NewMultiBufferCap(mbLen/buf.Size + mbLen/sliceSize + 2)
for { for {
slice := mb.SliceBySize(sliceSize) slice := mb.SliceBySize(sliceSize)
@ -150,5 +150,5 @@ func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
} }
} }
return w.writer.Write(mb2Write) return w.writer.WriteMultiBuffer(mb2Write)
} }

View File

@ -19,26 +19,26 @@ func TestChunkStreamIO(t *testing.T) {
b := buf.New() b := buf.New()
b.AppendBytes('a', 'b', 'c', 'd') b.AppendBytes('a', 'b', 'c', 'd')
assert(writer.Write(buf.NewMultiBufferValue(b)), IsNil) assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
b = buf.New() b = buf.New()
b.AppendBytes('e', 'f', 'g') b.AppendBytes('e', 'f', 'g')
assert(writer.Write(buf.NewMultiBufferValue(b)), IsNil) assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
assert(writer.Write(buf.MultiBuffer{}), IsNil) assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(cache.Len(), Equals, 13) assert(cache.Len(), Equals, 13)
mb, err := reader.Read() mb, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(mb.Len(), Equals, 4) assert(mb.Len(), Equals, 4)
assert(mb[0].Bytes(), Equals, []byte("abcd")) assert(mb[0].Bytes(), Equals, []byte("abcd"))
mb, err = reader.Read() mb, err = reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(mb.Len(), Equals, 3) assert(mb.Len(), Equals, 3)
assert(mb[0].Bytes(), Equals, []byte("efg")) assert(mb[0].Bytes(), Equals, []byte("efg"))
_, err = reader.Read() _, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF) assert(err, Equals, io.EOF)
} }

View File

@ -28,7 +28,7 @@ func (r *CryptionReader) Read(data []byte) (int, error) {
} }
var ( var (
_ buf.MultiBufferWriter = (*CryptionWriter)(nil) _ buf.Writer = (*CryptionWriter)(nil)
) )
type CryptionWriter struct { type CryptionWriter struct {

View File

@ -29,7 +29,7 @@ func (*NoneResponse) WriteTo(buf.Writer) {}
func (*HTTPResponse) WriteTo(writer buf.Writer) { func (*HTTPResponse) WriteTo(writer buf.Writer) {
b := buf.NewLocal(512) b := buf.NewLocal(512)
common.Must(b.AppendSupplier(serial.WriteString(http403response))) common.Must(b.AppendSupplier(serial.WriteString(http403response)))
writer.Write(buf.NewMultiBufferValue(b)) writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
} }
// GetInternalResponse converts response settings from proto to internal data structure. // GetInternalResponse converts response settings from proto to internal data structure.

View File

@ -255,15 +255,18 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea
requestDone := signal.ExecuteAsync(func() error { requestDone := signal.ExecuteAsync(func() error {
request.Header.Set("Connection", "close") request.Header.Set("Connection", "close")
requestWriter := buf.ToBytesWriter(ray.InboundInput()) requestWriter := buf.NewBufferedWriter(ray.InboundInput())
if err := request.Write(requestWriter); err != nil { if err := request.Write(requestWriter); err != nil {
return err return err
} }
if err := requestWriter.Flush(); err != nil {
return err
}
return nil return nil
}) })
responseDone := signal.ExecuteAsync(func() error { responseDone := signal.ExecuteAsync(func() error {
responseReader := bufio.NewReaderSize(buf.ToBytesReader(ray.InboundOutput()), 2048) responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), 2048)
response, err := http.ReadResponse(responseReader, request) response, err := http.ReadResponse(responseReader, request)
if err == nil { if err == nil {
StripHopByHopHeaders(response.Header) StripHopByHopHeaders(response.Header)

View File

@ -93,7 +93,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5) ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
if request.Command == protocol.RequestCommandTCP { if request.Command == protocol.RequestCommandTCP {
bufferedWriter := buf.NewBufferedWriter(conn) bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
bodyWriter, err := WriteTCPRequest(request, bufferedWriter) bodyWriter, err := WriteTCPRequest(request, bufferedWriter)
if err != nil { if err != nil {
return newError("failed to write request").Base(err) return newError("failed to write request").Base(err)

View File

@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader {
} }
} }
func (v *ChunkReader) Read() (buf.MultiBuffer, error) { func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer := buf.New() buffer := buf.New()
if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil { if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil {
buffer.Release() buffer.Release()
@ -117,8 +117,8 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter {
} }
} }
// Write implements buf.MultiBufferWriter. // WriteMultiBuffer implements buf.Writer.
func (w *ChunkWriter) Write(mb buf.MultiBuffer) error { func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
for { for {

View File

@ -16,7 +16,7 @@ func TestNormalChunkReading(t *testing.T) {
0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18) 0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18)
reader := NewChunkReader(buffer, NewAuthenticator(ChunkKeyGenerator( reader := NewChunkReader(buffer, NewAuthenticator(ChunkKeyGenerator(
[]byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}))) []byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36})))
payload, err := reader.Read() payload, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(payload[0].Bytes(), Equals, []byte{11, 12, 13, 14, 15, 16, 17, 18}) assert(payload[0].Bytes(), Equals, []byte{11, 12, 13, 14, 15, 16, 17, 18})
} }
@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) {
b := buf.NewLocal(256) b := buf.NewLocal(256)
b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18}) b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18})
err := writer.Write(buf.NewMultiBufferValue(b)) err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil) assert(err, IsNil)
assert(buffer.Bytes(), Equals, []byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18}) assert(buffer.Bytes(), Equals, []byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18})
} }

View File

@ -362,7 +362,7 @@ type UDPReader struct {
User *protocol.User User *protocol.User
} }
func (v *UDPReader) Read() (buf.MultiBuffer, error) { func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer := buf.New() buffer := buf.New()
err := buffer.AppendSupplier(buf.ReadFrom(v.Reader)) err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
if err != nil { if err != nil {

View File

@ -112,14 +112,14 @@ func TestTCPRequest(t *testing.T) {
writer, err := WriteTCPRequest(request, cache) writer, err := WriteTCPRequest(request, cache)
assert(err, IsNil) assert(err, IsNil)
assert(writer.Write(buf.NewMultiBufferValue(data)), IsNil) assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(data)), IsNil)
decodedRequest, reader, err := ReadTCPSession(request.User, cache) decodedRequest, reader, err := ReadTCPSession(request.User, cache)
assert(err, IsNil) assert(err, IsNil)
assert(decodedRequest.Address, Equals, request.Address) assert(decodedRequest.Address, Equals, request.Address)
assert(decodedRequest.Port, Equals, request.Port) assert(decodedRequest.Port, Equals, request.Port)
decodedData, err := reader.Read() decodedData, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(decodedData[0].String(), Equals, string(payload)) assert(decodedData[0].String(), Equals, string(payload))
} }
@ -158,19 +158,19 @@ func TestUDPReaderWriter(t *testing.T) {
b := buf.New() b := buf.New()
b.AppendSupplier(serial.WriteString("test payload")) b.AppendSupplier(serial.WriteString("test payload"))
err := writer.Write(buf.NewMultiBufferValue(b)) err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil) assert(err, IsNil)
payload, err := reader.Read() payload, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(payload[0].String(), Equals, "test payload") assert(payload[0].String(), Equals, "test payload")
b = buf.New() b = buf.New()
b.AppendSupplier(serial.WriteString("test payload 2")) b.AppendSupplier(serial.WriteString("test payload 2"))
err = writer.Write(buf.NewMultiBufferValue(b)) err = writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil) assert(err, IsNil)
payload, err = reader.Read() payload, err = reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(payload[0].String(), Equals, "test payload 2") assert(payload[0].String(), Equals, "test payload 2")
} }

View File

@ -74,7 +74,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
reader := buf.NewReader(conn) reader := buf.NewReader(conn)
for { for {
mpayload, err := reader.Read() mpayload, err := reader.ReadMultiBuffer()
if err != nil { if err != nil {
break break
} }
@ -129,7 +129,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error { func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error {
conn.SetReadDeadline(time.Now().Add(time.Second * 8)) conn.SetReadDeadline(time.Now().Add(time.Second * 8))
bufferedReader := buf.NewBufferedReader(conn) bufferedReader := buf.NewBufferedReader(buf.NewReader(conn))
request, bodyReader, err := ReadTCPSession(s.user, bufferedReader) request, bodyReader, err := ReadTCPSession(s.user, bufferedReader)
if err != nil { if err != nil {
log.Access(conn.RemoteAddr(), "", log.AccessRejected, err) log.Access(conn.RemoteAddr(), "", log.AccessRejected, err)
@ -153,17 +153,17 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
} }
responseDone := signal.ExecuteAsync(func() error { responseDone := signal.ExecuteAsync(func() error {
bufferedWriter := buf.NewBufferedWriter(conn) bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
responseWriter, err := WriteTCPResponse(request, bufferedWriter) responseWriter, err := WriteTCPResponse(request, bufferedWriter)
if err != nil { if err != nil {
return newError("failed to write response").Base(err) return newError("failed to write response").Base(err)
} }
payload, err := ray.InboundOutput().Read() payload, err := ray.InboundOutput().ReadMultiBuffer()
if err != nil { if err != nil {
return err return err
} }
if err := responseWriter.Write(payload); err != nil { if err := responseWriter.WriteMultiBuffer(payload); err != nil {
return err return err
} }
payload.Release() payload.Release()

View File

@ -352,7 +352,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
return &UDPReader{reader: reader} return &UDPReader{reader: reader}
} }
func (r *UDPReader) Read() (buf.MultiBuffer, error) { func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b := buf.New() b := buf.New()
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil { if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
return nil, err return nil, err

View File

@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) {
content := []byte{'a'} content := []byte{'a'}
payload := buf.New() payload := buf.New()
payload.Append(content) payload.Append(content)
assert(writer.Write(buf.NewMultiBufferValue(payload)), IsNil) assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
reader := NewUDPReader(b) reader := NewUDPReader(b)
decodedPayload, err := reader.Read() decodedPayload, err := reader.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
assert(decodedPayload[0].Bytes(), Equals, content) assert(decodedPayload[0].Bytes(), Equals, content)
} }

View File

@ -58,7 +58,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error { func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error {
conn.SetReadDeadline(time.Now().Add(time.Second * 8)) conn.SetReadDeadline(time.Now().Add(time.Second * 8))
reader := buf.NewBufferedReader(conn) reader := buf.NewBufferedReader(buf.NewReader(conn))
inboundDest, ok := proxy.InboundEntryPointFromContext(ctx) inboundDest, ok := proxy.InboundEntryPointFromContext(ctx)
if !ok { if !ok {
@ -154,7 +154,7 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
reader := buf.NewReader(conn) reader := buf.NewReader(conn)
for { for {
mpayload, err := reader.Read() mpayload, err := reader.ReadMultiBuffer()
if err != nil { if err != nil {
return err return err
} }

View File

@ -142,12 +142,12 @@ func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSess
bodyWriter := session.EncodeResponseBody(request, output) bodyWriter := session.EncodeResponseBody(request, output)
// Optimize for small response packet // Optimize for small response packet
data, err := input.Read() data, err := input.ReadMultiBuffer()
if err != nil { if err != nil {
return err return err
} }
if err := bodyWriter.Write(data); err != nil { if err := bodyWriter.WriteMultiBuffer(data); err != nil {
return err return err
} }
data.Release() data.Release()
@ -163,7 +163,7 @@ func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSess
} }
if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Option.Has(protocol.RequestOptionChunkStream) {
if err := bodyWriter.Write(buf.MultiBuffer{}); err != nil { if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
return err return err
} }
} }
@ -177,7 +177,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
return err return err
} }
reader := buf.NewBufferedReader(connection) reader := buf.NewBufferedReader(buf.NewReader(connection))
session := encoding.NewServerSession(v.clients, v.sessionHistory) session := encoding.NewServerSession(v.clients, v.sessionHistory)
request, err := session.DecodeRequestHeader(reader) request, err := session.DecodeRequestHeader(reader)
@ -213,14 +213,12 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
input := ray.InboundInput() input := ray.InboundInput()
output := ray.InboundOutput() output := ray.InboundOutput()
reader.SetBuffered(false)
requestDone := signal.ExecuteAsync(func() error { requestDone := signal.ExecuteAsync(func() error {
return transferRequest(timer, session, request, reader, input) return transferRequest(timer, session, request, reader, input)
}) })
responseDone := signal.ExecuteAsync(func() error { responseDone := signal.ExecuteAsync(func() error {
writer := buf.NewBufferedWriter(connection) writer := buf.NewBufferedWriter(buf.NewWriter(connection))
defer writer.Flush() defer writer.Flush()
response := &protocol.ResponseHeader{ response := &protocol.ResponseHeader{

View File

@ -106,7 +106,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5) ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
requestDone := signal.ExecuteAsync(func() error { requestDone := signal.ExecuteAsync(func() error {
writer := buf.NewBufferedWriter(conn) writer := buf.NewBufferedWriter(buf.NewWriter(conn))
if err := session.EncodeRequestHeader(request, writer); err != nil { if err := session.EncodeRequestHeader(request, writer); err != nil {
return newError("failed to encode request").Base(err).AtWarning() return newError("failed to encode request").Base(err).AtWarning()
} }
@ -117,7 +117,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
return newError("failed to get first payload").Base(err) return newError("failed to get first payload").Base(err)
} }
if !firstPayload.IsEmpty() { if !firstPayload.IsEmpty() {
if err := bodyWriter.Write(firstPayload); err != nil { if err := bodyWriter.WriteMultiBuffer(firstPayload); err != nil {
return newError("failed to write first payload").Base(err) return newError("failed to write first payload").Base(err)
} }
firstPayload.Release() firstPayload.Release()
@ -132,7 +132,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
} }
if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Option.Has(protocol.RequestOptionChunkStream) {
if err := bodyWriter.Write(buf.MultiBuffer{}); err != nil { if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
return err return err
} }
} }
@ -142,7 +142,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
responseDone := signal.ExecuteAsync(func() error { responseDone := signal.ExecuteAsync(func() error {
defer output.Close() defer output.Close()
reader := buf.NewBufferedReader(conn) reader := buf.NewBufferedReader(buf.NewReader(conn))
header, err := session.DecodeResponseHeader(reader) header, err := session.DecodeResponseHeader(reader)
if err != nil { if err != nil {
return err return err

View File

@ -169,8 +169,7 @@ type SystemConnection interface {
} }
var ( var (
_ buf.MultiBufferReader = (*Connection)(nil) _ buf.Reader = (*Connection)(nil)
_ buf.MultiBufferWriter = (*Connection)(nil)
) )
// Connection is a KCP connection over UDP. // Connection is a KCP connection over UDP.
@ -265,7 +264,7 @@ func (v *Connection) OnDataOutput() {
} }
} }
// ReadMultiBuffer implements buf.MultiBufferReader. // ReadMultiBuffer implements buf.Reader.
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if v == nil { if v == nil {
return nil, io.EOF return nil, io.EOF
@ -375,13 +374,6 @@ func (v *Connection) Write(b []byte) (int, error) {
} }
} }
func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil {
c.mergingWriter = buf.NewMergingWriterSize(c, c.mss)
}
return c.mergingWriter.Write(mb)
}
func (v *Connection) SetState(state State) { func (v *Connection) SetState(state State) {
current := v.Elapsed() current := v.Elapsed()
atomic.StoreInt32((*int32)(&v.state), int32(state)) atomic.StoreInt32((*int32)(&v.state), int32(state))

View File

@ -10,29 +10,23 @@ import (
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg tls -path Transport,Internet,TLS //go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg tls -path Transport,Internet,TLS
var ( var (
_ buf.MultiBufferReader = (*conn)(nil) _ buf.Writer = (*conn)(nil)
_ buf.MultiBufferWriter = (*conn)(nil)
) )
type conn struct { type conn struct {
net.Conn net.Conn
mergingReader buf.Reader mergingWriter *buf.BufferedWriter
mergingWriter buf.Writer
}
func (c *conn) ReadMultiBuffer() (buf.MultiBuffer, error) {
if c.mergingReader == nil {
c.mergingReader = buf.NewBytesToBufferReader(c.Conn)
}
return c.mergingReader.Read()
} }
func (c *conn) WriteMultiBuffer(mb buf.MultiBuffer) error { func (c *conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil { if c.mergingWriter == nil {
c.mergingWriter = buf.NewMergingWriter(c.Conn) c.mergingWriter = buf.NewBufferedWriter(buf.NewWriter(c.Conn))
} }
return c.mergingWriter.Write(mb) if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil {
return err
}
return c.mergingWriter.Flush()
} }
func Client(c net.Conn, config *tls.Config) net.Conn { func Client(c net.Conn, config *tls.Config) net.Conn {

View File

@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
inboundRay, existing := v.getInboundRay(ctx, destination) inboundRay, existing := v.getInboundRay(ctx, destination)
outputStream := inboundRay.InboundInput() outputStream := inboundRay.InboundInput()
if outputStream != nil { if outputStream != nil {
if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil { if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
v.RemoveRay(destination) v.RemoveRay(destination)
} }
} }
@ -71,7 +71,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
func handleInput(input ray.InputStream, callback ResponseCallback) { func handleInput(input ray.InputStream, callback ResponseCallback) {
for { for {
mb, err := input.Read() mb, err := input.ReadMultiBuffer()
if err != nil { if err != nil {
break break
} }

View File

@ -28,11 +28,11 @@ func TestSameDestinationDispatching(t *testing.T) {
link := ray.NewRay(ctx) link := ray.NewRay(ctx)
go func() { go func() {
for { for {
data, err := link.OutboundInput().Read() data, err := link.OutboundInput().ReadMultiBuffer()
if err != nil { if err != nil {
break break
} }
err = link.OutboundOutput().Write(data) err = link.OutboundOutput().WriteMultiBuffer(data)
assert(err, IsNil) assert(err, IsNil)
} }
}() }()

View File

@ -11,8 +11,7 @@ import (
) )
var ( var (
_ buf.MultiBufferReader = (*connection)(nil) _ buf.Writer = (*connection)(nil)
_ buf.MultiBufferWriter = (*connection)(nil)
) )
// connection is a wrapper for net.Conn over WebSocket connection. // connection is a wrapper for net.Conn over WebSocket connection.
@ -20,8 +19,7 @@ type connection struct {
conn *websocket.Conn conn *websocket.Conn
reader io.Reader reader io.Reader
mergingReader buf.Reader mergingWriter *buf.BufferedWriter
mergingWriter buf.Writer
} }
func newConnection(conn *websocket.Conn) *connection { func newConnection(conn *websocket.Conn) *connection {
@ -47,13 +45,6 @@ func (c *connection) Read(b []byte) (int, error) {
} }
} }
func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if c.mergingReader == nil {
c.mergingReader = buf.NewBytesToBufferReader(c)
}
return c.mergingReader.Read()
}
func (c *connection) getReader() (io.Reader, error) { func (c *connection) getReader() (io.Reader, error) {
if c.reader != nil { if c.reader != nil {
return c.reader, nil return c.reader, nil
@ -77,9 +68,12 @@ func (c *connection) Write(b []byte) (int, error) {
func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil { if c.mergingWriter == nil {
c.mergingWriter = buf.NewMergingWriter(c) c.mergingWriter = buf.NewBufferedWriter(buf.NewBufferToBytesWriter(c))
} }
return c.mergingWriter.Write(mb) if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil {
return err
}
return c.mergingWriter.Flush()
} }
func (c *connection) Close() error { func (c *connection) Close() error {

View File

@ -106,7 +106,7 @@ func (s *Stream) Peek(b *buf.Buffer) {
} }
// Read reads data from the Stream. // Read reads data from the Stream.
func (s *Stream) Read() (buf.MultiBuffer, error) { func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) {
for { for {
mb, err := s.getData() mb, err := s.getData()
if err != nil { if err != nil {
@ -178,7 +178,7 @@ func (s *Stream) waitForStreamSize() error {
} }
// Write writes more data into the Stream. // Write writes more data into the Stream.
func (s *Stream) Write(data buf.MultiBuffer) error { func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error {
if data.IsEmpty() { if data.IsEmpty() {
return nil return nil
} }

View File

@ -16,18 +16,18 @@ func TestStreamIO(t *testing.T) {
stream := NewStream(context.Background()) stream := NewStream(context.Background())
b1 := buf.New() b1 := buf.New()
b1.AppendBytes('a') b1.AppendBytes('a')
assert(stream.Write(buf.NewMultiBufferValue(b1)), IsNil) assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil)
_, err := stream.Read() _, err := stream.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
stream.Close() stream.Close()
_, err = stream.Read() _, err = stream.ReadMultiBuffer()
assert(err, Equals, io.EOF) assert(err, Equals, io.EOF)
b2 := buf.New() b2 := buf.New()
b2.AppendBytes('b') b2.AppendBytes('b')
err = stream.Write(buf.NewMultiBufferValue(b2)) err = stream.WriteMultiBuffer(buf.NewMultiBufferValue(b2))
assert(err, Equals, io.ErrClosedPipe) assert(err, Equals, io.ErrClosedPipe)
} }
@ -37,13 +37,13 @@ func TestStreamClose(t *testing.T) {
stream := NewStream(context.Background()) stream := NewStream(context.Background())
b1 := buf.New() b1 := buf.New()
b1.AppendBytes('a') b1.AppendBytes('a')
assert(stream.Write(buf.NewMultiBufferValue(b1)), IsNil) assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil)
stream.Close() stream.Close()
_, err := stream.Read() _, err := stream.ReadMultiBuffer()
assert(err, IsNil) assert(err, IsNil)
_, err = stream.Read() _, err = stream.ReadMultiBuffer()
assert(err, Equals, io.EOF) assert(err, Equals, io.EOF)
} }