mirror of https://github.com/v2ray/v2ray-core
simplify auth reader
parent
6652edfa6f
commit
bcfcba396b
|
@ -90,133 +90,58 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
|
|||
|
||||
type AuthenticationReader struct {
|
||||
auth Authenticator
|
||||
buffer *buf.Buffer
|
||||
reader io.Reader
|
||||
reader *buf.BufferedReader
|
||||
sizeParser ChunkSizeDecoder
|
||||
size int
|
||||
transferType protocol.TransferType
|
||||
}
|
||||
|
||||
const (
|
||||
readerBufferSize = 32 * 1024
|
||||
)
|
||||
|
||||
func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader {
|
||||
return &AuthenticationReader{
|
||||
auth: auth,
|
||||
buffer: buf.NewLocal(readerBufferSize),
|
||||
reader: reader,
|
||||
reader: buf.NewBufferedReader(buf.NewReader(reader)),
|
||||
sizeParser: sizeParser,
|
||||
size: -1,
|
||||
transferType: transferType,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *AuthenticationReader) readSize() error {
|
||||
if r.size >= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sizeBytes := r.sizeParser.SizeBytes()
|
||||
if r.buffer.Len() < sizeBytes {
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Clear()
|
||||
} else {
|
||||
common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
|
||||
}
|
||||
|
||||
delta := sizeBytes - r.buffer.Len()
|
||||
if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
size, err := r.sizeParser.Decode(r.buffer.BytesTo(sizeBytes))
|
||||
func (r *AuthenticationReader) readSize() (int, error) {
|
||||
sizeBytes := make([]byte, r.sizeParser.SizeBytes())
|
||||
_, err := io.ReadFull(r.reader, sizeBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
r.size = int(size)
|
||||
r.buffer.SliceFrom(sizeBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
|
||||
if err := r.readSize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.size > readerBufferSize-r.sizeParser.SizeBytes() {
|
||||
return nil, newError("size too large ", r.size).AtWarning()
|
||||
}
|
||||
|
||||
if r.size == r.auth.Overhead() {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if r.buffer.Len() < r.size {
|
||||
if !waitForData {
|
||||
return nil, io.ErrNoProgress
|
||||
}
|
||||
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Clear()
|
||||
} else {
|
||||
common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
|
||||
}
|
||||
|
||||
delta := r.size - r.buffer.Len()
|
||||
if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := r.auth.Open(r.buffer.BytesTo(0), r.buffer.BytesTo(r.size))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.buffer.SliceFrom(r.size)
|
||||
r.size = -1
|
||||
return b, nil
|
||||
size, err := r.sizeParser.Decode(sizeBytes)
|
||||
return int(size), err
|
||||
}
|
||||
|
||||
func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
b, err := r.readChunk(true)
|
||||
size, err := r.readSize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mb buf.MultiBuffer
|
||||
if r.transferType == protocol.TransferTypeStream {
|
||||
mb.Write(b)
|
||||
if size == r.auth.Overhead() {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
var b *buf.Buffer
|
||||
if size <= buf.Size {
|
||||
b = buf.New()
|
||||
} else {
|
||||
var bb *buf.Buffer
|
||||
if len(b) <= buf.Size {
|
||||
bb = buf.New()
|
||||
} else {
|
||||
bb = buf.NewLocal(len(b))
|
||||
}
|
||||
bb.Append(b)
|
||||
mb.Append(bb)
|
||||
b = buf.NewLocal(size)
|
||||
}
|
||||
if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for r.buffer.Len() >= r.sizeParser.SizeBytes() {
|
||||
b, err := r.readChunk(false)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if r.transferType == protocol.TransferTypeStream {
|
||||
mb.Write(b)
|
||||
} else {
|
||||
var bb *buf.Buffer
|
||||
if len(b) <= buf.Size {
|
||||
bb = buf.New()
|
||||
} else {
|
||||
bb = buf.NewLocal(len(b))
|
||||
}
|
||||
bb.Append(b)
|
||||
mb.Append(bb)
|
||||
}
|
||||
rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size))
|
||||
if err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mb, nil
|
||||
b.Slice(0, len(rb))
|
||||
return buf.NewMultiBufferValue(b), nil
|
||||
}
|
||||
|
||||
type AuthenticationWriter struct {
|
||||
|
|
|
@ -122,6 +122,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
|
|||
|
||||
b1 := mb.SplitFirst()
|
||||
assert(b1.String(), Equals, "abcd")
|
||||
assert(mb.IsEmpty(), IsTrue)
|
||||
|
||||
mb, err = reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
b2 := mb.SplitFirst()
|
||||
assert(b2.String(), Equals, "efgh")
|
||||
assert(mb.IsEmpty(), IsTrue)
|
||||
|
|
Loading…
Reference in New Issue