Browse Source

format code

pull/298/head
V2Ray 9 years ago
parent
commit
9371b3d080
  1. 64
      io/vmess/decryptionreader.go
  2. 116
      io/vmess/decryptionreader_test.go

64
io/vmess/decryptionreader.go

@ -1,7 +1,7 @@
package vmess package vmess
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"fmt" "fmt"
@ -13,9 +13,9 @@ const (
) )
type DecryptionReader struct { type DecryptionReader struct {
cipher cipher.Block cipher cipher.Block
reader io.Reader reader io.Reader
buffer *bytes.Buffer buffer *bytes.Buffer
} }
func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) { func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) {
@ -26,13 +26,13 @@ func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error
} }
decryptionReader.cipher = cipher decryptionReader.cipher = cipher
decryptionReader.reader = reader decryptionReader.reader = reader
decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2 * blockSize)) decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2*blockSize))
return decryptionReader, nil return decryptionReader, nil
} }
func (reader *DecryptionReader) readBlock() error { func (reader *DecryptionReader) readBlock() error {
buffer := make([]byte, blockSize) buffer := make([]byte, blockSize)
nBytes, err := reader.reader.Read(buffer) nBytes, err := reader.reader.Read(buffer)
if err != nil { if err != nil {
return err return err
} }
@ -40,34 +40,34 @@ func (reader *DecryptionReader) readBlock() error {
return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes) return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes)
} }
reader.cipher.Decrypt(buffer, buffer) reader.cipher.Decrypt(buffer, buffer)
reader.buffer.Write(buffer) reader.buffer.Write(buffer)
return nil return nil
} }
func (reader *DecryptionReader) Read(p []byte) (int, error) { func (reader *DecryptionReader) Read(p []byte) (int, error) {
if reader.buffer.Len() == 0 { if reader.buffer.Len() == 0 {
err := reader.readBlock() err := reader.readBlock()
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
nBytes, err := reader.buffer.Read(p) nBytes, err := reader.buffer.Read(p)
if err != nil { if err != nil {
return nBytes, err return nBytes, err
} }
if nBytes < len(p) { if nBytes < len(p) {
err = reader.readBlock() err = reader.readBlock()
if err != nil { if err != nil {
return nBytes, err return nBytes, err
} }
moreBytes, err := reader.buffer.Read(p[nBytes:]) moreBytes, err := reader.buffer.Read(p[nBytes:])
if err != nil { if err != nil {
return nBytes, err return nBytes, err
} }
nBytes += moreBytes nBytes += moreBytes
if nBytes != len(p) { if nBytes != len(p) {
return nBytes, fmt.Errorf("Unable to read %d bytes", len(p)) return nBytes, fmt.Errorf("Unable to read %d bytes", len(p))
} }
} }
return nBytes, err return nBytes, err
} }

116
io/vmess/decryptionreader_test.go

@ -1,67 +1,67 @@
package vmess package vmess
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/rand" "crypto/rand"
mrand "math/rand" mrand "math/rand"
"testing" "testing"
) )
func randomBytes(p []byte, t *testing.T) { func randomBytes(p []byte, t *testing.T) {
nBytes, err := rand.Read(p) nBytes, err := rand.Read(p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if nBytes != len(p) { if nBytes != len(p) {
t.Error("Unable to generate %d bytes of random buffer", len(p)) t.Error("Unable to generate %d bytes of random buffer", len(p))
} }
} }
func TestNormalReading(t *testing.T) { func TestNormalReading(t *testing.T) {
testSize := 256 testSize := 256
plaintext := make([]byte, testSize) plaintext := make([]byte, testSize)
randomBytes(plaintext, t) randomBytes(plaintext, t)
keySize := 16 keySize := 16
key := make([]byte, keySize) key := make([]byte, keySize)
randomBytes(key, t) randomBytes(key, t)
cipher, err := aes.NewCipher(key) cipher, err := aes.NewCipher(key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ciphertext := make([]byte, testSize) ciphertext := make([]byte, testSize)
for encryptSize := 0; encryptSize < testSize; encryptSize += blockSize { for encryptSize := 0; encryptSize < testSize; encryptSize += blockSize {
cipher.Encrypt(ciphertext[encryptSize:], plaintext[encryptSize:]) cipher.Encrypt(ciphertext[encryptSize:], plaintext[encryptSize:])
} }
ciphertextcopy := make([]byte, testSize) ciphertextcopy := make([]byte, testSize)
copy(ciphertextcopy, ciphertext) copy(ciphertextcopy, ciphertext)
reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key) reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
readtext := make([]byte, testSize) readtext := make([]byte, testSize)
readSize := 0 readSize := 0
for readSize < testSize { for readSize < testSize {
nBytes := mrand.Intn(16) + 1 nBytes := mrand.Intn(16) + 1
if nBytes > testSize - readSize { if nBytes > testSize-readSize {
nBytes = testSize - readSize nBytes = testSize - readSize
} }
bytesRead, err := reader.Read(readtext[readSize:readSize + nBytes]) bytesRead, err := reader.Read(readtext[readSize : readSize+nBytes])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if bytesRead != nBytes { if bytesRead != nBytes {
t.Errorf("Expected to read %d bytes, but only read %d bytes", nBytes, bytesRead) t.Errorf("Expected to read %d bytes, but only read %d bytes", nBytes, bytesRead)
} }
readSize += nBytes readSize += nBytes
} }
if ! bytes.Equal(readtext, plaintext) { if !bytes.Equal(readtext, plaintext) {
t.Errorf("Expected plaintext %v, but got %v", plaintext, readtext) t.Errorf("Expected plaintext %v, but got %v", plaintext, readtext)
} }
} }

Loading…
Cancel
Save