diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 192396b7..2d591bfb 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -90,133 +90,58 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) { type AuthenticationReader struct { auth Authenticator - buffer *buf.Buffer - reader io.Reader + reader *buf.BufferedReader sizeParser ChunkSizeDecoder - size int transferType protocol.TransferType } -const ( - readerBufferSize = 32 * 1024 -) - func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader { return &AuthenticationReader{ auth: auth, - buffer: buf.NewLocal(readerBufferSize), - reader: reader, + reader: buf.NewBufferedReader(buf.NewReader(reader)), sizeParser: sizeParser, - size: -1, transferType: transferType, } } -func (r *AuthenticationReader) readSize() error { - if r.size >= 0 { - return nil - } - - sizeBytes := r.sizeParser.SizeBytes() - if r.buffer.Len() < sizeBytes { - if r.buffer.IsEmpty() { - r.buffer.Clear() - } else { - common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer))) - } - - delta := sizeBytes - r.buffer.Len() - if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil { - return err - } - } - size, err := r.sizeParser.Decode(r.buffer.BytesTo(sizeBytes)) +func (r *AuthenticationReader) readSize() (int, error) { + sizeBytes := make([]byte, r.sizeParser.SizeBytes()) + _, err := io.ReadFull(r.reader, sizeBytes) if err != nil { - return err + return 0, err } - r.size = int(size) - r.buffer.SliceFrom(sizeBytes) - return nil -} - -func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) { - if err := r.readSize(); err != nil { - return nil, err - } - if r.size > readerBufferSize-r.sizeParser.SizeBytes() { - return nil, newError("size too large ", r.size).AtWarning() - } - - if r.size == r.auth.Overhead() { - return nil, io.EOF - } - - if r.buffer.Len() < r.size { - if !waitForData { - return nil, io.ErrNoProgress - } - - if r.buffer.IsEmpty() { - r.buffer.Clear() - } else { - common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer))) - } - - delta := r.size - r.buffer.Len() - if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil { - return nil, err - } - } - - b, err := r.auth.Open(r.buffer.BytesTo(0), r.buffer.BytesTo(r.size)) - if err != nil { - return nil, err - } - r.buffer.SliceFrom(r.size) - r.size = -1 - return b, nil + size, err := r.sizeParser.Decode(sizeBytes) + return int(size), err } func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) { - b, err := r.readChunk(true) + size, err := r.readSize() if err != nil { return nil, err } - var mb buf.MultiBuffer - if r.transferType == protocol.TransferTypeStream { - mb.Write(b) + if size == r.auth.Overhead() { + return nil, io.EOF + } + + var b *buf.Buffer + if size <= buf.Size { + b = buf.New() } else { - var bb *buf.Buffer - if len(b) <= buf.Size { - bb = buf.New() - } else { - bb = buf.NewLocal(len(b)) - } - bb.Append(b) - mb.Append(bb) + b = buf.NewLocal(size) + } + if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil { + b.Release() + return nil, err } - for r.buffer.Len() >= r.sizeParser.SizeBytes() { - b, err := r.readChunk(false) - if err != nil { - break - } - if r.transferType == protocol.TransferTypeStream { - mb.Write(b) - } else { - var bb *buf.Buffer - if len(b) <= buf.Size { - bb = buf.New() - } else { - bb = buf.NewLocal(len(b)) - } - bb.Append(b) - mb.Append(bb) - } + rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size)) + if err != nil { + b.Release() + return nil, err } - - return mb, nil + b.Slice(0, len(rb)) + return buf.NewMultiBufferValue(b), nil } type AuthenticationWriter struct { diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 58d256e4..3bc35fa3 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -122,6 +122,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { b1 := mb.SplitFirst() assert(b1.String(), Equals, "abcd") + assert(mb.IsEmpty(), IsTrue) + + mb, err = reader.ReadMultiBuffer() + assert(err, IsNil) b2 := mb.SplitFirst() assert(b2.String(), Equals, "efgh") assert(mb.IsEmpty(), IsTrue)