diff --git a/common/crypto/auth.go b/common/crypto/auth.go index c3d431b4..b73ebfe9 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -4,34 +4,86 @@ import ( "crypto/cipher" "errors" "io" + "v2ray.com/core/common/alloc" "v2ray.com/core/common/serial" ) var ( ErrAuthenticationFailed = errors.New("Authentication failed.") - errInsufficientBuffer = errors.New("Insufficient buffer.") + + errInsufficientBuffer = errors.New("Insufficient buffer.") + errInvalidNonce = errors.New("Invalid nonce.") ) -type BytesGenerator func() []byte - -type AuthenticationReader struct { - aead cipher.AEAD - buffer *alloc.Buffer - reader io.Reader - ivGen BytesGenerator - extraGen BytesGenerator - - chunk []byte +type BytesGenerator interface { + Next() []byte } -func NewAuthenticationReader(aead cipher.AEAD, reader io.Reader, ivGen BytesGenerator, extraGen BytesGenerator) *AuthenticationReader { +type NoOpBytesGenerator struct { + buffer [1]byte +} + +func (v NoOpBytesGenerator) Next() []byte { + return v.buffer[:0] +} + +type StaticBytesGenerator struct { + Content []byte +} + +func (v StaticBytesGenerator) Next() []byte { + return v.Content +} + +type Authenticator interface { + NonceSize() int + Overhead() int + Open(dst, cipherText []byte) ([]byte, error) + Seal(dst, plainText []byte) ([]byte, error) +} + +type AEADAuthenticator struct { + cipher.AEAD + NonceGenerator BytesGenerator + AdditionalDataGenerator BytesGenerator +} + +func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) { + iv := v.NonceGenerator.Next() + if len(iv) != v.AEAD.NonceSize() { + return nil, errInvalidNonce + } + + additionalData := v.AdditionalDataGenerator.Next() + return v.AEAD.Open(dst, iv, cipherText, additionalData) +} + +func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) { + iv := v.NonceGenerator.Next() + if len(iv) != v.AEAD.NonceSize() { + return nil, errInvalidNonce + } + + additionalData := v.AdditionalDataGenerator.Next() + return v.AEAD.Seal(dst, iv, plainText, additionalData), nil +} + +type AuthenticationReader struct { + auth Authenticator + buffer *alloc.Buffer + reader io.Reader + + chunk []byte + aggressive bool +} + +func NewAuthenticationReader(auth Authenticator, reader io.Reader, aggressive bool) *AuthenticationReader { return &AuthenticationReader{ - aead: aead, - buffer: alloc.NewLocalBuffer(32 * 1024), - reader: reader, - ivGen: ivGen, - extraGen: extraGen, + auth: auth, + buffer: alloc.NewLocalBuffer(32 * 1024), + reader: reader, + aggressive: aggressive, } } @@ -43,11 +95,11 @@ func (v *AuthenticationReader) NextChunk() error { if size > v.buffer.Len()-2 { return errInsufficientBuffer } - if size == v.aead.Overhead() { + if size == v.auth.Overhead() { return io.EOF } cipherChunk := v.buffer.BytesRange(2, size+2) - plainChunk, err := v.aead.Open(cipherChunk, v.ivGen(), cipherChunk, v.extraGen()) + plainChunk, err := v.auth.Open(cipherChunk, cipherChunk) if err != nil { return err } @@ -57,6 +109,9 @@ func (v *AuthenticationReader) NextChunk() error { } func (v *AuthenticationReader) CopyChunk(b []byte) int { + if len(v.chunk) == 0 { + return 0 + } nBytes := copy(b, v.chunk) if nBytes == len(v.chunk) { v.chunk = nil @@ -72,49 +127,56 @@ func (v *AuthenticationReader) Read(b []byte) (int, error) { return nBytes, nil } - err := v.NextChunk() - if err == errInsufficientBuffer { - _, err = v.buffer.FillFrom(v.reader) - } - - if err != nil { - return 0, err - } - totalBytes := 0 for { - totalBytes += v.CopyChunk(b) - if len(b) == 0 { - break + err := v.NextChunk() + if err == errInsufficientBuffer { + if totalBytes > 0 { + return totalBytes, nil + } + leftover := v.buffer.Bytes() + v.buffer.SetBytesFunc(func(b []byte) int { + return copy(b, leftover) + }) + _, err = v.buffer.FillFrom(v.reader) } - if err := v.NextChunk(); err != nil { - break + + if err != nil { + return 0, err + } + + nBytes := v.CopyChunk(b) + b = b[nBytes:] + totalBytes += nBytes + + if !v.aggressive { + return totalBytes, nil } } - return totalBytes, nil } type AuthenticationWriter struct { - aead cipher.AEAD + auth Authenticator buffer []byte writer io.Writer ivGen BytesGenerator extraGen BytesGenerator } -func NewAuthenticationWriter(aead cipher.AEAD, writer io.Writer, ivGen BytesGenerator, extraGen BytesGenerator) *AuthenticationWriter { +func NewAuthenticationWriter(auth Authenticator, writer io.Writer) *AuthenticationWriter { return &AuthenticationWriter{ - aead: aead, - buffer: make([]byte, 32*1024), - writer: writer, - ivGen: ivGen, - extraGen: extraGen, + auth: auth, + buffer: make([]byte, 32*1024), + writer: writer, } } func (v *AuthenticationWriter) Write(b []byte) (int, error) { - cipherChunk := v.aead.Seal(v.buffer[2:], v.ivGen(), b, v.extraGen()) + cipherChunk, err := v.auth.Seal(v.buffer[2:], b) + if err != nil { + return 0, err + } serial.Uint16ToBytes(uint16(len(cipherChunk)), b[:0]) - _, err := v.writer.Write(v.buffer[:2+len(cipherChunk)]) + _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)]) return len(b), err } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 0122451d..968cbbd1 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -190,7 +190,12 @@ func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) { bodyReader := session.DecodeRequestBody(reader) var requestReader v2io.Reader if request.Option.Has(protocol.RequestOptionChunkStream) { - authReader := crypto.NewAuthenticationReader(new(encoding.FnvAuthenticator), bodyReader, func() []byte { return nil }, func() []byte { return nil }) + auth := &crypto.AEADAuthenticator{ + AEAD: new(encoding.FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader := crypto.NewAuthenticationReader(auth, bodyReader, request.Command == protocol.RequestCommandTCP) requestReader = v2io.NewAdaptiveReader(authReader) } else { requestReader = v2io.NewAdaptiveReader(bodyReader)