diff --git a/common/crypto/auth.go b/common/crypto/auth.go index bf1c6664..f860577e 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -86,6 +86,7 @@ type AuthenticationReader struct { auth Authenticator reader *buf.BufferedReader sizeParser ChunkSizeDecoder + sizeBytes []byte transferType protocol.TransferType padding PaddingLengthGenerator size uint16 @@ -95,13 +96,20 @@ type AuthenticationReader struct { } func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType, paddingLen PaddingLengthGenerator) *AuthenticationReader { - return &AuthenticationReader{ + r := &AuthenticationReader{ auth: auth, - reader: &buf.BufferedReader{Reader: buf.NewReader(reader)}, sizeParser: sizeParser, transferType: transferType, padding: paddingLen, + sizeBytes: make([]byte, sizeParser.SizeBytes()), } + if breader, ok := reader.(*buf.BufferedReader); ok { + breader.Direct = false + r.reader = breader + } else { + r.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)} + } + return r } func (r *AuthenticationReader) readSize() (uint16, uint16, error) { @@ -109,15 +117,14 @@ func (r *AuthenticationReader) readSize() (uint16, uint16, error) { r.hasSize = false return r.size, r.paddingLen, nil } - sizeBytes := make([]byte, r.sizeParser.SizeBytes()) - if _, err := io.ReadFull(r.reader, sizeBytes); err != nil { + if _, err := io.ReadFull(r.reader, r.sizeBytes); err != nil { return 0, 0, err } var padding uint16 if r.padding != nil { padding = r.padding.NextPaddingLen() } - size, err := r.sizeParser.Decode(sizeBytes) + size, err := r.sizeParser.Decode(r.sizeBytes) return size, padding, err } diff --git a/common/crypto/chunk.go b/common/crypto/chunk.go index 36e6900f..b8855fd0 100755 --- a/common/crypto/chunk.go +++ b/common/crypto/chunk.go @@ -71,11 +71,17 @@ type ChunkStreamReader struct { } func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader { - return &ChunkStreamReader{ + r := &ChunkStreamReader{ sizeDecoder: sizeDecoder, - reader: &buf.BufferedReader{Reader: buf.NewReader(reader)}, buffer: make([]byte, sizeDecoder.SizeBytes()), } + if breader, ok := reader.(*buf.BufferedReader); ok { + r.reader = breader + } else { + r.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)} + } + + return r } func (r *ChunkStreamReader) readSize() (uint16, error) {