From 7efa7ee632f1859b99e1b1f8f2d7f3b818f815f6 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 9 Jul 2018 22:27:24 +0200 Subject: [PATCH] prepare to remove constructor of AuthenticationReader --- common/crypto/auth.go | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 26af6f7a..26bdade4 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -87,8 +87,10 @@ type AuthenticationReader struct { sizeParser ChunkSizeDecoder transferType protocol.TransferType padding PaddingLengthGenerator - size int32 - paddingLen int32 + size uint16 + paddingLen uint16 + hasSize bool + done bool } func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType, paddingLen PaddingLengthGenerator) *AuthenticationReader { @@ -98,26 +100,24 @@ func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, re sizeParser: sizeParser, transferType: transferType, padding: paddingLen, - size: -1, } } -func (r *AuthenticationReader) readSize() (int32, int32, error) { - if r.size != -1 { - s := r.size - r.size = -1 - return s, r.paddingLen, nil +func (r *AuthenticationReader) readSize() (uint16, uint16, error) { + if r.hasSize { + r.hasSize = false + return r.size, r.paddingLen, nil } sizeBytes := make([]byte, r.sizeParser.SizeBytes()) if _, err := io.ReadFull(r.reader, sizeBytes); err != nil { return 0, 0, err } - var padding int32 + var padding uint16 if r.padding != nil { - padding = int32(r.padding.NextPaddingLen()) + padding = r.padding.NextPaddingLen() } size, err := r.sizeParser.Decode(sizeBytes) - return int32(size), padding, err + return size, padding, err } var errSoft = newError("waiting for more data") @@ -127,31 +127,36 @@ func (r *AuthenticationReader) readInternal(soft bool) (*buf.Buffer, error) { return nil, errSoft } + if r.done { + return nil, io.EOF + } + size, padding, err := r.readSize() if err != nil { return nil, err } - if size == -2 || size == int32(r.auth.Overhead())+padding { - r.size = -2 + if size == uint16(r.auth.Overhead())+padding { + r.done = true return nil, io.EOF } - if soft && size > r.reader.BufferedBytes() { + if soft && int32(size) > r.reader.BufferedBytes() { r.size = size r.paddingLen = padding + r.hasSize = true return nil, errSoft } - b := buf.NewSize(size) - if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil { + b := buf.NewSize(int32(size)) + if err := b.Reset(buf.ReadFullFrom(r.reader, int32(size))); err != nil { b.Release() return nil, err } size -= padding - rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size)) + rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(int32(size))) if err != nil { b.Release() return nil, err