From 092d73b6b6020afc728905ff8bd5d25a96e80929 Mon Sep 17 00:00:00 2001
From: V2Ray <admin@v2ray.com>
Date: Mon, 14 Sep 2015 21:59:44 +0200
Subject: [PATCH] Update VMess protocol

---
 id.go                  | 23 +++++++++++++++--------
 io/vmess/vmess.go      | 38 ++++++++++++++++++++++++--------------
 io/vmess/vmess_test.go | 25 ++++++++++++++++++++++++-
 net/vmess/vmessin.go   |  1 +
 net/vmess/vmessout.go  |  9 +++++++--
 5 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/id.go b/id.go
index 6d09a6a4..72685814 100644
--- a/id.go
+++ b/id.go
@@ -4,7 +4,6 @@ import (
 	"crypto/hmac"
 	"crypto/md5"
 	"encoding/hex"
-	"hash"
 	mrand "math/rand"
 	"time"
 
@@ -19,7 +18,7 @@ const (
 type ID struct {
 	String string
 	Bytes  []byte
-	hasher hash.Hash
+	cmdKey []byte
 }
 
 func NewID(id string) (ID, error) {
@@ -27,8 +26,13 @@ func NewID(id string) (ID, error) {
 	if err != nil {
 		return ID{}, log.Error("Failed to parse id %s", id)
 	}
-	hasher := hmac.New(md5.New, idBytes)
-	return ID{id, idBytes, hasher}, nil
+  
+  md5hash := md5.New()
+	md5hash.Write(idBytes)
+	md5hash.Write([]byte("c48619fe-8f02-49e0-b9e9-edf763e17e21"))
+	cmdKey := md5.Sum(nil)
+
+	return ID{id, idBytes, cmdKey[:]}, nil
 }
 
 func (v ID) TimeRangeHash(rangeSec int) []byte {
@@ -54,10 +58,13 @@ func (v ID) TimeHash(timeSec int64) []byte {
 }
 
 func (v ID) Hash(data []byte) []byte {
-	v.hasher.Write(data)
-	hash := v.hasher.Sum(nil)
-	v.hasher.Reset()
-	return hash
+  hasher := hmac.New(md5.New, v.Bytes)
+	hasher.Write(data)
+	return hasher.Sum(nil)
+}
+
+func (v ID) CmdKey() []byte {
+	return v.cmdKey
 }
 
 var byteGroups = []int{8, 4, 4, 4, 12}
diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go
index 1aea3d58..df73706a 100644
--- a/io/vmess/vmess.go
+++ b/io/vmess/vmess.go
@@ -13,6 +13,7 @@ import (
 
 	"github.com/v2ray/v2ray-core"
 	v2io "github.com/v2ray/v2ray-core/io"
+	"github.com/v2ray/v2ray-core/log"
 	v2net "github.com/v2ray/v2ray-core/net"
 )
 
@@ -24,12 +25,11 @@ const (
 	Version = byte(0x01)
 
 	blockSize = 16
-
-	CryptoMessage = "c48619fe-8f02-49e0-b9e9-edf763e17e21"
 )
 
 var (
-	ErrorInvalidUser = errors.New("Invalid User")
+	ErrorInvalidUser   = errors.New("Invalid User")
+	ErrorInvalidVerion = errors.New("Invalid Version")
 
 	emptyIV = make([]byte, blockSize)
 )
@@ -62,17 +62,13 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 	request := new(VMessRequest)
 
 	buffer := make([]byte, 256)
-	nBytes, err := reader.Read(buffer[0:1])
-	if err != nil {
-		return nil, err
-	}
-	// TODO: verify version number
-	request.Version = buffer[0]
 
-	nBytes, err = reader.Read(buffer[:core.IDBytesLen])
+	nBytes, err := reader.Read(buffer[:core.IDBytesLen])
 	if err != nil {
 		return nil, err
 	}
+  
+  log.Debug("Read user hash: %v", buffer[:nBytes])
 
 	userId, valid := r.vUserSet.GetUser(buffer[:nBytes])
 	if !valid {
@@ -80,7 +76,7 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 	}
 	request.UserId = *userId
 
-	aesCipher, err := aes.NewCipher(userId.Hash([]byte(CryptoMessage)))
+	aesCipher, err := aes.NewCipher(userId.CmdKey())
 	if err != nil {
 		return nil, err
 	}
@@ -105,6 +101,17 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 		return nil, err
 	}
 
+	nBytes, err = decryptor.Read(buffer[0:1])
+	if err != nil {
+		return nil, err
+	}
+
+	request.Version = buffer[0]
+	if request.Version != Version {
+		log.Error("Unknown VMess version %d", request.Version)
+		return nil, ErrorInvalidVerion
+	}
+
 	// TODO: check number of bytes returned
 	_, err = decryptor.Read(request.RequestIV[:])
 	if err != nil {
@@ -182,8 +189,10 @@ func NewVMessRequestWriter() *VMessRequestWriter {
 
 func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) error {
 	buffer := make([]byte, 0, 300)
-	buffer = append(buffer, request.Version)
-	buffer = append(buffer, request.UserId.TimeRangeHash(30)...)
+  userHash := request.UserId.TimeRangeHash(30)
+  
+  log.Debug("Writing userhash: %v", userHash)
+	buffer = append(buffer, userHash...)
 
 	encryptionBegin := len(buffer)
 
@@ -196,6 +205,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro
 	buffer = append(buffer, byte(randomLength))
 	buffer = append(buffer, randomContent...)
 
+	buffer = append(buffer, request.Version)
 	buffer = append(buffer, request.RequestIV[:]...)
 	buffer = append(buffer, request.RequestKey[:]...)
 	buffer = append(buffer, request.ResponseHeader[:]...)
@@ -231,7 +241,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro
 	buffer = append(buffer, paddingBuffer...)
 	encryptionEnd := len(buffer)
 
-	aesCipher, err := aes.NewCipher(request.UserId.Hash([]byte(CryptoMessage)))
+	aesCipher, err := aes.NewCipher(request.UserId.CmdKey())
 	if err != nil {
 		return err
 	}
diff --git a/io/vmess/vmess_test.go b/io/vmess/vmess_test.go
index 1d1900da..1c1b0aef 100644
--- a/io/vmess/vmess_test.go
+++ b/io/vmess/vmess_test.go
@@ -3,6 +3,7 @@ package vmess
 import (
 	"bytes"
 	"crypto/rand"
+	"io/ioutil"
 	"testing"
 
 	"github.com/v2ray/v2ray-core"
@@ -51,7 +52,7 @@ func TestVMessSerialization(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	userSet.UserHashes[string(buffer.Bytes()[1:17])] = 0
+	userSet.UserHashes[string(buffer.Bytes()[:16])] = 0
 
 	requestReader := NewVMessRequestReader(&userSet)
 	actualRequest, err := requestReader.Read(buffer)
@@ -67,3 +68,25 @@ func TestVMessSerialization(t *testing.T) {
 	assert.Byte(actualRequest.Command).Named("Command").Equals(request.Command)
 	assert.String(actualRequest.Address.String()).Named("Address").Equals(request.Address.String())
 }
+
+func BenchmarkVMessRequestWriting(b *testing.B) {
+	userId, _ := core.NewID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51")
+	userSet := mocks.MockUserSet{[]core.ID{}, make(map[string]int)}
+	userSet.AddUser(core.User{userId})
+
+	request := new(VMessRequest)
+	request.Version = byte(0x01)
+	request.UserId = userId
+
+	rand.Read(request.RequestIV[:])
+	rand.Read(request.RequestKey[:])
+	rand.Read(request.ResponseHeader[:])
+
+	request.Command = byte(0x01)
+	request.Address = v2net.DomainAddress("v2ray.com", 80)
+
+	requestWriter := NewVMessRequestWriter()
+	for i := 0; i < b.N; i++ {
+		requestWriter.Write(ioutil.Discard, request)
+	}
+}
diff --git a/net/vmess/vmessin.go b/net/vmess/vmessin.go
index 09d0bbf5..22e4a451 100644
--- a/net/vmess/vmessin.go
+++ b/net/vmess/vmessin.go
@@ -55,6 +55,7 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error
 
 	request, err := reader.Read(connection)
 	if err != nil {
+    log.Debug("Failed to parse VMess request: %v", err)
 		return err
 	}
 	log.Debug("Received request for %s", request.Address.String())
diff --git a/net/vmess/vmessout.go b/net/vmess/vmessout.go
index e4ef7ebb..9f853366 100644
--- a/net/vmess/vmessout.go
+++ b/net/vmess/vmessout.go
@@ -74,11 +74,15 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ
 		return err
 	}
 	defer conn.Close()
+  
+  input := ray.OutboundInput()
+	output := ray.OutboundOutput()
 
 	requestWriter := vmessio.NewVMessRequestWriter()
 	err = requestWriter.Write(conn, request)
 	if err != nil {
 		log.Error("Failed to write VMess request: %v", err)
+    close(output)
 		return err
 	}
 
@@ -90,6 +94,7 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ
 	response := vmessio.VMessResponse{}
 	nBytes, err := conn.Read(response[:])
 	if err != nil {
+    close(output)
 		log.Error("Failed to read VMess response (%d bytes): %v", nBytes, err)
 		return err
 	}
@@ -98,17 +103,17 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ
 
 	encryptRequestWriter, err := v2io.NewAesEncryptWriter(requestKey, requestIV, conn)
 	if err != nil {
+    close(output)
 		log.Error("Failed to create encrypt writer: %v", err)
 		return err
 	}
 	decryptResponseReader, err := v2io.NewAesDecryptReader(responseKey[:], responseIV[:], conn)
 	if err != nil {
+    close(output)
 		log.Error("Failed to create decrypt reader: %v", err)
 		return err
 	}
 
-	input := ray.OutboundInput()
-	output := ray.OutboundOutput()
 	readFinish := make(chan bool)
 	writeFinish := make(chan bool)