diff --git a/io/encryption.go b/io/encryption.go new file mode 100644 index 00000000..3fe17848 --- /dev/null +++ b/io/encryption.go @@ -0,0 +1,65 @@ +package io + +import ( + "crypto/cipher" + "io" +) + +// CryptionReader is a general purpose reader that applies +// block cipher on top of a regular reader. +type CryptionReader struct { + mode cipher.BlockMode + reader io.Reader +} + +func NewCryptionReader(mode cipher.BlockMode, reader io.Reader) *CryptionReader { + this := new(CryptionReader) + this.mode = mode + this.reader = reader + return this +} + +// Read reads blocks from underlying reader, the length of blocks must be +// a multiply of BlockSize() +func (reader CryptionReader) Read(blocks []byte) (int, error) { + nBytes, err := reader.reader.Read(blocks) + if err != nil && err != io.EOF { + return nBytes, err + } + if nBytes < len(blocks) { + for i, _ := range blocks[nBytes:] { + blocks[i] = 0 + } + } + reader.mode.CryptBlocks(blocks, blocks) + return nBytes, err +} + +func (reader CryptionReader) BlockSize() int { + return reader.mode.BlockSize() +} + +// Cryption writer is a general purpose of byte stream writer that applies +// block cipher on top of a regular writer. +type CryptionWriter struct { + mode cipher.BlockMode + writer io.Writer +} + +func NewCryptionWriter(mode cipher.BlockMode, writer io.Writer) *CryptionWriter { + this := new(CryptionWriter) + this.mode = mode + this.writer = writer + return this +} + +// Write writes the give blocks to underlying writer. The length of the blocks +// must be a multiply of BlockSize() +func (writer CryptionWriter) Write(blocks []byte) (int, error) { + writer.mode.CryptBlocks(blocks, blocks) + return writer.writer.Write(blocks) +} + +func (writer CryptionWriter) BlockSize() int { + return writer.mode.BlockSize() +} diff --git a/io/vmess/decryptionreader.go b/io/vmess/decryptionreader.go index 92d97d32..69932c1a 100644 --- a/io/vmess/decryptionreader.go +++ b/io/vmess/decryptionreader.go @@ -6,26 +6,31 @@ import ( "crypto/cipher" "fmt" "io" + + v2io "github.com/v2ray/v2ray-core/io" ) const ( - blockSize = 16 + blockSize = 16 // Decryption block size, inherited from AES ) +// DecryptionReader is a byte stream reader to decrypt AES-128 CBC (for now) +// encrypted content. type DecryptionReader struct { - cipher cipher.Block - reader io.Reader + reader *v2io.CryptionReader buffer *bytes.Buffer } -func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) { +// NewDecryptionReader creates a new DescriptionReader by given byte Reader and +// AES key. +func NewDecryptionReader(reader io.Reader, key []byte, iv []byte) (*DecryptionReader, error) { decryptionReader := new(DecryptionReader) - cipher, err := aes.NewCipher(key) + aesCipher, err := aes.NewCipher(key) if err != nil { return nil, err } - decryptionReader.cipher = cipher - decryptionReader.reader = reader + aesBlockMode := cipher.NewCBCDecrypter(aesCipher, iv) + decryptionReader.reader = v2io.NewCryptionReader(aesBlockMode, reader) decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2*blockSize)) return decryptionReader, nil } @@ -33,26 +38,20 @@ func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error func (reader *DecryptionReader) readBlock() error { buffer := make([]byte, blockSize) nBytes, err := reader.reader.Read(buffer) - if err != nil { + if err != nil && err != io.EOF { return err } if nBytes < blockSize { return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes) } - reader.cipher.Decrypt(buffer, buffer) reader.buffer.Write(buffer) - return nil + return err } +// Read returns decrypted bytes of given length func (reader *DecryptionReader) Read(p []byte) (int, error) { - if reader.buffer.Len() == 0 { - err := reader.readBlock() - if err != nil { - return 0, err - } - } nBytes, err := reader.buffer.Read(p) - if err != nil { + if err != nil && err != io.EOF { return nBytes, err } if nBytes < len(p) { diff --git a/io/vmess/decryptionreader_test.go b/io/vmess/decryptionreader_test.go index 52c483ee..340b549c 100644 --- a/io/vmess/decryptionreader_test.go +++ b/io/vmess/decryptionreader_test.go @@ -3,6 +3,7 @@ package vmess import ( "bytes" "crypto/aes" + "crypto/cipher" "crypto/rand" mrand "math/rand" "testing" @@ -26,21 +27,22 @@ func TestNormalReading(t *testing.T) { keySize := 16 key := make([]byte, keySize) randomBytes(key, t) + iv := make([]byte, keySize) + randomBytes(iv, t) - cipher, err := aes.NewCipher(key) + aesBlock, err := aes.NewCipher(key) if err != nil { t.Fatal(err) } + aesMode := cipher.NewCBCEncrypter(aesBlock, iv) ciphertext := make([]byte, testSize) - for encryptSize := 0; encryptSize < testSize; encryptSize += blockSize { - cipher.Encrypt(ciphertext[encryptSize:], plaintext[encryptSize:]) - } + aesMode.CryptBlocks(ciphertext, plaintext) ciphertextcopy := make([]byte, testSize) copy(ciphertextcopy, ciphertext) - reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key) + reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key, iv) if err != nil { t.Fatal(err) } diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go index 2f795299..eb46c54e 100644 --- a/io/vmess/vmess.go +++ b/io/vmess/vmess.go @@ -12,9 +12,9 @@ import ( type VMessInput struct { version byte userHash [16]byte - randHash [256]byte - respKey [32]byte + respKey [16]byte iv [16]byte + respHead [4]byte command byte port uint16 target [256]byte