From 148a7d064d5fcdc289d56f71411365c2b28d1c3a Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 21 Apr 2018 00:54:53 +0200 Subject: [PATCH] simplify buf.BufferedReader --- app/proxyman/mux/mux.go | 4 +- app/proxyman/mux/mux_test.go | 2 +- common/buf/reader.go | 83 +++++++++++++------------------ common/buf/reader_test.go | 4 +- common/buf/writer_test.go | 2 +- common/crypto/auth.go | 2 +- common/crypto/chunk.go | 2 +- common/net/connection.go | 4 +- proxy/http/server.go | 2 +- proxy/shadowsocks/protocol.go | 4 +- proxy/shadowsocks/server.go | 6 +-- proxy/socks/server.go | 2 +- proxy/vmess/inbound/inbound.go | 2 +- proxy/vmess/outbound/outbound.go | 4 +- transport/internet/http/dialer.go | 4 +- 15 files changed, 56 insertions(+), 71 deletions(-) diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index bb902ec1..1571a916 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -258,7 +258,7 @@ func (m *Client) fetchOutput() { common.Must(m.done.Close()) }() - reader := buf.NewBufferedReader(m.link.Reader) + reader := &buf.BufferedReader{Reader: m.link.Reader} for { meta, err := ReadMetadata(reader) @@ -456,7 +456,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead func (w *ServerWorker) run(ctx context.Context) { input := w.link.Reader - reader := buf.NewBufferedReader(input) + reader := &buf.BufferedReader{Reader: input} defer w.sessionManager.Close() diff --git a/app/proxyman/mux/mux_test.go b/app/proxyman/mux/mux_test.go index f0d495ad..c8508dc2 100644 --- a/app/proxyman/mux/mux_test.go +++ b/app/proxyman/mux/mux_test.go @@ -59,7 +59,7 @@ func TestReaderWriter(t *testing.T) { assert(writePayload(writer2, 'y'), IsNil) writer2.Close() - bytesReader := buf.NewBufferedReader(pReader) + bytesReader := &buf.BufferedReader{Reader: pReader} meta, err := ReadMetadata(bytesReader) assert(err, IsNil) diff --git a/common/buf/reader.go b/common/buf/reader.go index f7f05132..5f3ac868 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -75,32 +75,17 @@ func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) { // BufferedReader is a Reader that keeps its internal buffer. type BufferedReader struct { - stream Reader - leftOver MultiBuffer - buffered bool -} - -// NewBufferedReader returns a new BufferedReader. -func NewBufferedReader(reader Reader) *BufferedReader { - return &BufferedReader{ - stream: reader, - buffered: true, - } -} - -// SetBuffered sets whether to keep the interal buffer. -func (r *BufferedReader) SetBuffered(f bool) { - r.buffered = f -} - -// IsBuffered returns true if internal buffer is used. -func (r *BufferedReader) IsBuffered() bool { - return r.buffered + // Reader is the underlying reader to be read from + Reader Reader + // Buffer is the internal buffer to be read from first + Buffer MultiBuffer + // Direct indicates whether or not to use the internal buffer + Direct bool } // BufferedBytes returns the number of bytes that is cached in this reader. func (r *BufferedReader) BufferedBytes() int32 { - return r.leftOver.Len() + return r.Buffer.Len() } // ReadByte implements io.ByteReader. @@ -112,26 +97,26 @@ func (r *BufferedReader) ReadByte() (byte, error) { // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader. func (r *BufferedReader) Read(b []byte) (int, error) { - if r.leftOver != nil { - nBytes, _ := r.leftOver.Read(b) - if r.leftOver.IsEmpty() { - r.leftOver.Release() - r.leftOver = nil + if r.Buffer != nil { + nBytes, _ := r.Buffer.Read(b) + if r.Buffer.IsEmpty() { + r.Buffer.Release() + r.Buffer = nil } return nBytes, nil } - if !r.buffered { - if reader, ok := r.stream.(io.Reader); ok { + if r.Direct { + if reader, ok := r.Reader.(io.Reader); ok { return reader.Read(b) } } - mb, err := r.stream.ReadMultiBuffer() + mb, err := r.Reader.ReadMultiBuffer() if mb != nil { nBytes, _ := mb.Read(b) if !mb.IsEmpty() { - r.leftOver = mb + r.Buffer = mb } return nBytes, err } @@ -140,28 +125,28 @@ func (r *BufferedReader) Read(b []byte) (int, error) { // ReadMultiBuffer implements Reader. func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) { - if r.leftOver != nil { - mb := r.leftOver - r.leftOver = nil + if r.Buffer != nil { + mb := r.Buffer + r.Buffer = nil return mb, nil } - return r.stream.ReadMultiBuffer() + return r.Reader.ReadMultiBuffer() } // ReadAtMost returns a MultiBuffer with at most size. func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) { - if r.leftOver == nil { - mb, err := r.stream.ReadMultiBuffer() + if r.Buffer == nil { + mb, err := r.Reader.ReadMultiBuffer() if mb.IsEmpty() && err != nil { return nil, err } - r.leftOver = mb + r.Buffer = mb } - mb := r.leftOver.SliceBySize(size) - if r.leftOver.IsEmpty() { - r.leftOver = nil + mb := r.Buffer.SliceBySize(size) + if r.Buffer.IsEmpty() { + r.Buffer = nil } return mb, nil } @@ -169,16 +154,16 @@ func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) { func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) { mbWriter := NewWriter(writer) totalBytes := int64(0) - if r.leftOver != nil { - totalBytes += int64(r.leftOver.Len()) - if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil { + if r.Buffer != nil { + totalBytes += int64(r.Buffer.Len()) + if err := mbWriter.WriteMultiBuffer(r.Buffer); err != nil { return 0, err } - r.leftOver = nil + r.Buffer = nil } for { - mb, err := r.stream.ReadMultiBuffer() + mb, err := r.Reader.ReadMultiBuffer() if mb != nil { totalBytes += int64(mb.Len()) if werr := mbWriter.WriteMultiBuffer(mb); werr != nil { @@ -202,8 +187,8 @@ func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) { // Close implements io.Closer. func (r *BufferedReader) Close() error { - if !r.leftOver.IsEmpty() { - r.leftOver.Release() + if !r.Buffer.IsEmpty() { + r.Buffer.Release() } - return common.Close(r.stream) + return common.Close(r.Reader) } diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index 1932b1bc..f5ea1e10 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -39,7 +39,7 @@ func TestBytesReaderWriteTo(t *testing.T) { assert := With(t) pReader, pWriter := pipe.New() - reader := NewBufferedReader(pReader) + reader := &BufferedReader{Reader: pReader} b1 := New() b1.AppendBytes('a', 'b', 'c') b2 := New() @@ -66,7 +66,7 @@ func TestBytesReaderMultiBuffer(t *testing.T) { assert := With(t) pReader, pWriter := pipe.New() - reader := NewBufferedReader(pReader) + reader := &BufferedReader{Reader: pReader} b1 := New() b1.AppendBytes('a', 'b', 'c') b2 := New() diff --git a/common/buf/writer_test.go b/common/buf/writer_test.go index 9b266413..dbd324ab 100644 --- a/common/buf/writer_test.go +++ b/common/buf/writer_test.go @@ -67,7 +67,7 @@ func TestDiscardBytesMultiBuffer(t *testing.T) { common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size))) r := NewReader(buffer) - nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r)) + nBytes, err := io.Copy(DiscardBytes, &BufferedReader{Reader: r}) assert(nBytes, Equals, int64(size)) assert(err, IsNil) } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 9661c861..5121ff34 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -91,7 +91,7 @@ type AuthenticationReader struct { func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader { return &AuthenticationReader{ auth: auth, - reader: buf.NewBufferedReader(buf.NewReader(reader)), + reader: &buf.BufferedReader{Reader: buf.NewReader(reader)}, sizeParser: sizeParser, transferType: transferType, size: -1, diff --git a/common/crypto/chunk.go b/common/crypto/chunk.go index e6ea63d1..7960aa97 100755 --- a/common/crypto/chunk.go +++ b/common/crypto/chunk.go @@ -68,7 +68,7 @@ type ChunkStreamReader struct { func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader { return &ChunkStreamReader{ sizeDecoder: sizeDecoder, - reader: buf.NewBufferedReader(buf.NewReader(reader)), + reader: &buf.BufferedReader{Reader: buf.NewReader(reader)}, buffer: make([]byte, sizeDecoder.SizeBytes()), } } diff --git a/common/net/connection.go b/common/net/connection.go index cd3dfb07..1b85bb6a 100644 --- a/common/net/connection.go +++ b/common/net/connection.go @@ -38,13 +38,13 @@ func ConnectionInputMulti(writer buf.Writer) ConnectionOption { func ConnectionOutput(reader io.Reader) ConnectionOption { return func(c *connection) { - c.reader = buf.NewBufferedReader(buf.NewReader(reader)) + c.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)} } } func ConnectionOutputMulti(reader buf.Reader) ConnectionOption { return func(c *connection) { - c.reader = buf.NewBufferedReader(reader) + c.reader = &buf.BufferedReader{Reader: reader} } } diff --git a/proxy/http/server.go b/proxy/http/server.go index 50e7fdd8..39ae609c 100755 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -268,7 +268,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri } responseDone := func() error { - responseReader := bufio.NewReaderSize(buf.NewBufferedReader(link.Reader), buf.Size) + responseReader := bufio.NewReaderSize(&buf.BufferedReader{Reader: link.Reader}, buf.Size) response, err := http.ReadResponse(responseReader, request) if err == nil { http_proto.RemoveHopByHopHeaders(response.Header) diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 00d76efc..e3f582df 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -52,7 +52,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea if err != nil { return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError() } - br := buf.NewBufferedReader(r) + br := &buf.BufferedReader{Reader: r} reader = nil authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) @@ -109,7 +109,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea return nil, nil, newError("invalid remote address.") } - br.SetBuffered(false) + br.Direct = true var chunkReader buf.Reader if request.Option.Has(RequestOptionOneTimeAuth) { diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 96ac7eaf..0aa80784 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -140,8 +140,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error { sessionPolicy := s.v.PolicyManager().ForLevel(s.user.Level) conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)) - bufferedReader := buf.NewBufferedReader(buf.NewReader(conn)) - request, bodyReader, err := ReadTCPSession(s.user, bufferedReader) + bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)} + request, bodyReader, err := ReadTCPSession(s.user, &bufferedReader) if err != nil { log.Record(&log.AccessMessage{ From: conn.RemoteAddr(), @@ -153,7 +153,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, } conn.SetReadDeadline(time.Time{}) - bufferedReader.SetBuffered(false) + bufferedReader.Direct = true dest := request.Destination() log.Record(&log.AccessMessage{ diff --git a/proxy/socks/server.go b/proxy/socks/server.go index d1a9308b..e1cd9916 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -70,7 +70,7 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa newError("failed to set deadline").Base(err).WithContext(ctx).WriteToLog() } - reader := buf.NewBufferedReader(buf.NewReader(conn)) + reader := &buf.BufferedReader{Reader: buf.NewReader(conn)} inboundDest, ok := proxy.InboundEntryPointFromContext(ctx) if !ok { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index c88ee2da..eff75d28 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -224,7 +224,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return newError("unable to set read deadline").Base(err).AtWarning() } - reader := buf.NewBufferedReader(buf.NewReader(connection)) + reader := &buf.BufferedReader{Reader: buf.NewReader(connection)} session := encoding.NewServerSession(h.clients, h.sessionHistory) request, err := session.DecodeRequestHeader(reader) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 6e42a1d0..e9a52b11 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -146,14 +146,14 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia responseDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) - reader := buf.NewBufferedReader(buf.NewReader(conn)) + reader := &buf.BufferedReader{Reader: buf.NewReader(conn)} header, err := session.DecodeResponseHeader(reader) if err != nil { return newError("failed to read header").Base(err) } v.handleCommand(rec.Destination(), header.Command) - reader.SetBuffered(false) + reader.Direct = true bodyReader := session.DecodeResponseBody(request, reader) return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index a1daf4b4..6cbe336e 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -84,11 +84,11 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error } preader, pwriter := pipe.New(pipe.WithSizeLimit(20 * 1024)) - breader := buf.NewBufferedReader(preader) + breader := &buf.BufferedReader{Reader: preader} request := &http.Request{ Method: "PUT", Host: httpSettings.getRandomHost(), - Body: buf.NewBufferedReader(preader), + Body: breader, URL: &url.URL{ Scheme: "https", Host: dest.NetAddr(),