v2ray-core/io/vmess/decryptionreader.go

73 lines
1.7 KiB
Go
Raw Normal View History

package vmess
import (
2015-09-08 13:39:49 +00:00
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"io"
2015-09-08 16:21:33 +00:00
v2io "github.com/v2ray/v2ray-core/io"
)
// blockSize is the unit, in bytes, in which ciphertext is pulled and
// decrypted. It equals the AES block size (aes.BlockSize).
const blockSize = 16
2015-09-08 16:21:15 +00:00
// DecryptionReader is a byte stream reader to decrypt AES-128 CBC (for now)
// encrypted content.
type DecryptionReader struct {
2015-09-08 16:21:33 +00:00
reader *v2io.CryptionReader
2015-09-08 13:39:49 +00:00
buffer *bytes.Buffer
}
2015-09-08 16:21:15 +00:00
// NewDecryptionReader creates a new DescriptionReader by given byte Reader and
// AES key.
func NewDecryptionReader(reader io.Reader, key []byte, iv []byte) (*DecryptionReader, error) {
decryptionReader := new(DecryptionReader)
2015-09-08 16:21:15 +00:00
aesCipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
2015-09-08 16:21:33 +00:00
aesBlockMode := cipher.NewCBCDecrypter(aesCipher, iv)
2015-09-08 16:21:15 +00:00
decryptionReader.reader = v2io.NewCryptionReader(aesBlockMode, reader)
2015-09-08 13:39:49 +00:00
decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2*blockSize))
return decryptionReader, nil
}
func (reader *DecryptionReader) readBlock() error {
2015-09-08 13:39:49 +00:00
buffer := make([]byte, blockSize)
nBytes, err := reader.reader.Read(buffer)
2015-09-08 16:21:15 +00:00
if err != nil && err != io.EOF {
return err
}
if nBytes < blockSize {
return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes)
}
2015-09-08 13:39:49 +00:00
reader.buffer.Write(buffer)
2015-09-08 16:21:15 +00:00
return err
}
2015-09-08 16:21:15 +00:00
// Read returns decrypted bytes of given length
func (reader *DecryptionReader) Read(p []byte) (int, error) {
nBytes, err := reader.buffer.Read(p)
2015-09-08 16:21:15 +00:00
if err != nil && err != io.EOF {
2015-09-08 13:39:49 +00:00
return nBytes, err
}
if nBytes < len(p) {
err = reader.readBlock()
if err != nil {
return nBytes, err
}
moreBytes, err := reader.buffer.Read(p[nBytes:])
if err != nil {
return nBytes, err
}
nBytes += moreBytes
if nBytes != len(p) {
return nBytes, fmt.Errorf("Unable to read %d bytes", len(p))
}
}
return nBytes, err
}