diff --git a/common/buf/reader.go b/common/buf/reader.go index a9e20c90..0dafeec7 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -97,6 +97,11 @@ func (r *BufferedReader) IsBuffered() bool { return r.buffered } +// BufferedBytes returns the number of bytes that is cached in this reader. +func (r *BufferedReader) BufferedBytes() int32 { + return int32(r.leftOver.Len()) +} + // ReadByte implements io.ByteReader. func (r *BufferedReader) ReadByte() (byte, error) { var b [1]byte diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 9895bed1..20258f59 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -93,6 +93,7 @@ type AuthenticationReader struct { reader *buf.BufferedReader sizeParser ChunkSizeDecoder transferType protocol.TransferType + size int32 } func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader { @@ -101,42 +102,85 @@ func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, re reader: buf.NewBufferedReader(buf.NewReader(reader)), sizeParser: sizeParser, transferType: transferType, + size: -1, } } -func (r *AuthenticationReader) readSize() (int, error) { +func (r *AuthenticationReader) readSize() (int32, error) { + if r.size != -1 { + s := r.size + r.size = -1 + return s, nil + } sizeBytes := make([]byte, r.sizeParser.SizeBytes()) _, err := io.ReadFull(r.reader, sizeBytes) if err != nil { return 0, err } size, err := r.sizeParser.Decode(sizeBytes) - return int(size), err + return int32(size), err } -func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) { +var errSoft = newError("waiting for more data") + +func (r *AuthenticationReader) readInternal(soft bool) (*buf.Buffer, error) { + if soft && r.reader.BufferedBytes() < 2 { + return nil, errSoft + } + size, err := r.readSize() if err != nil { return nil, err } - if size == r.auth.Overhead() { + if size == -2 || size == int32(r.auth.Overhead()) { + r.size = -2 return nil, io.EOF } + if soft && size > r.reader.BufferedBytes() { + r.size = size + return nil, errSoft + } + b := buf.NewSize(uint32(size)) - if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil { + if err := b.Reset(buf.ReadFullFrom(r.reader, int(size))); err != nil { b.Release() return nil, err } - rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size)) + rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(int(size))) if err != nil { b.Release() return nil, err } b.Slice(0, len(rb)) - return buf.NewMultiBufferValue(b), nil + + return b, nil +} + +func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + b, err := r.readInternal(false) + if err != nil { + return nil, err + } + + mb := buf.NewMultiBufferCap(32) + mb.Append(b) + + for { + b, err := r.readInternal(true) + if err == errSoft || err == io.EOF { + break + } + if err != nil { + mb.Release() + return nil, err + } + mb.Append(b) + } + + return mb, nil } type AuthenticationWriter struct { diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 0986a583..b3cbe4f6 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -125,12 +125,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { b1 := mb.SplitFirst() assert(b1.String(), Equals, "abcd") - assert(mb.IsEmpty(), IsTrue) - mb, err = reader.ReadMultiBuffer() - assert(err, IsNil) b2 := mb.SplitFirst() assert(b2.String(), Equals, "efgh") + assert(mb.IsEmpty(), IsTrue) _, err = reader.ReadMultiBuffer()