diff --git a/proxy/vmess/protocol/vmess.go b/proxy/vmess/protocol/vmess.go index 9aebfcd2..3d981d74 100644 --- a/proxy/vmess/protocol/vmess.go +++ b/proxy/vmess/protocol/vmess.go @@ -4,12 +4,10 @@ package protocol import ( "crypto/aes" "crypto/cipher" - "crypto/rand" "encoding/binary" "errors" - "fmt" + "hash/fnv" "io" - mrand "math/rand" "time" v2io "github.com/v2ray/v2ray-core/common/io" @@ -34,6 +32,7 @@ const ( var ( ErrorInvalidUser = errors.New("Invalid User") ErrorInvalidVerion = errors.New("Invalid Version") + ErrorInvalidHash = errors.New("Invalid Hash") ) // VMessRequest implements the request message of VMess protocol. It only contains the header of a @@ -97,24 +96,11 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, err } - nBytes, err = decryptor.Read(buffer[0:1]) - if err != nil { - return nil, err - } - - randomLength := buffer[0] - if randomLength <= 0 || randomLength > 32 { - return nil, fmt.Errorf("Unexpected random length %d", randomLength) - } - _, err = decryptor.Read(buffer[:randomLength]) - if err != nil { - return nil, err - } - - nBytes, err = decryptor.Read(buffer[0:1]) + nBytes, err = decryptor.Read(buffer[:41]) if err != nil { return nil, err } + bufferLen := nBytes request := &VMessRequest{ UserId: *userId, @@ -126,68 +112,54 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, ErrorInvalidVerion } - // TODO: check number of bytes returned - _, err = decryptor.Read(request.RequestIV[:]) - if err != nil { - return nil, err - } - _, err = decryptor.Read(request.RequestKey[:]) - if err != nil { - return nil, err - } - _, err = decryptor.Read(request.ResponseHeader[:]) - if err != nil { - return nil, err - } - _, err = decryptor.Read(buffer[0:1]) - if err != nil { - return nil, err - } - request.Command = buffer[0] + copy(request.RequestIV[:], buffer[1:17]) // 16 bytes + copy(request.RequestKey[:], buffer[17:33]) // 16 bytes + copy(request.ResponseHeader[:], buffer[33:37]) // 4 bytes + request.Command = buffer[37] - _, err = decryptor.Read(buffer[0:2]) - if err != nil { - return nil, err - } - port := binary.BigEndian.Uint16(buffer[0:2]) + port := binary.BigEndian.Uint16(buffer[38:40]) - _, err = decryptor.Read(buffer[0:1]) - if err != nil { - return nil, err - } - switch buffer[0] { + switch buffer[40] { case addrTypeIPv4: - _, err = decryptor.Read(buffer[1:5]) + _, err = decryptor.Read(buffer[41:45]) // 4 bytes + bufferLen += 4 if err != nil { return nil, err } - request.Address = v2net.IPAddress(buffer[1:5], port) + request.Address = v2net.IPAddress(buffer[41:45], port) case addrTypeIPv6: - _, err = decryptor.Read(buffer[1:17]) + _, err = decryptor.Read(buffer[41:57]) // 16 bytes + bufferLen += 16 if err != nil { return nil, err } - request.Address = v2net.IPAddress(buffer[1:17], port) + request.Address = v2net.IPAddress(buffer[41:57], port) case addrTypeDomain: - _, err = decryptor.Read(buffer[1:2]) + _, err = decryptor.Read(buffer[41:42]) if err != nil { return nil, err } - domainLength := buffer[1] - _, err = decryptor.Read(buffer[2 : 2+domainLength]) + domainLength := int(buffer[41]) + _, err = decryptor.Read(buffer[42 : 42+domainLength]) if err != nil { return nil, err } - request.Address = v2net.DomainAddress(string(buffer[2:2+domainLength]), port) + bufferLen += 1 + domainLength + request.Address = v2net.DomainAddress(string(buffer[42:42+domainLength]), port) } - _, err = decryptor.Read(buffer[0:1]) + + _, err = decryptor.Read(buffer[bufferLen : bufferLen+4]) if err != nil { return nil, err } - randomLength = buffer[0] - _, err = decryptor.Read(buffer[:randomLength]) - if err != nil { - return nil, err + + fnv1a := fnv.New32a() + fnv1a.Write(buffer[:bufferLen]) + actualHash := fnv1a.Sum32() + expectedHash := binary.BigEndian.Uint32(buffer[bufferLen : bufferLen+4]) + + if actualHash != expectedHash { + return nil, ErrorInvalidHash } return request, nil @@ -207,15 +179,6 @@ func (request *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 u encryptionBegin := len(buffer) - randomLength := mrand.Intn(32) + 1 - randomContent := make([]byte, randomLength) - _, err := rand.Read(randomContent) - if err != nil { - return nil, err - } - buffer = append(buffer, byte(randomLength)) - buffer = append(buffer, randomContent...) - buffer = append(buffer, request.Version) buffer = append(buffer, request.RequestIV[:]...) buffer = append(buffer, request.RequestKey[:]...) @@ -236,16 +199,18 @@ func (request *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 u buffer = append(buffer, []byte(request.Address.Domain())...) } - paddingLength := mrand.Intn(32) + 1 - paddingBuffer := make([]byte, paddingLength) - _, err = rand.Read(paddingBuffer) - if err != nil { - return nil, err - } - buffer = append(buffer, byte(paddingLength)) - buffer = append(buffer, paddingBuffer...) encryptionEnd := len(buffer) + fnv1a := fnv.New32a() + fnv1a.Write(buffer[encryptionBegin:encryptionEnd]) + + fnvHash := fnv1a.Sum32() + buffer = append(buffer, byte(fnvHash>>24)) + buffer = append(buffer, byte(fnvHash>>16)) + buffer = append(buffer, byte(fnvHash>>8)) + buffer = append(buffer, byte(fnvHash)) + encryptionEnd += 4 + aesCipher, err := aes.NewCipher(request.UserId.CmdKey()) if err != nil { return nil, err diff --git a/spec/vmess.md b/spec/vmess.md index b9e83d25..b3b4f983 100644 --- a/spec/vmess.md +++ b/spec/vmess.md @@ -7,8 +7,6 @@ * 16 字节:基于时间的 hash(用户 [ID](https://github.com/V2Ray/v2ray-core/blob/master/spec/id.md)),见下文 指令部分: -* 1 字节:随机填充长度 M (0 < M <= 32) -* M 字节:随机填充内容 * 1 字节:版本号,目前为 0x1 * 16 字节:请求数据 IV * 16 字节:请求数据 Key @@ -26,8 +24,7 @@ * 4 字节:IPv4 * 1 字节长度 + 域名 * 16 字节:IPv6 -* 1 字节:随机填充长度 M2 (0 < M2 <= 32) -* M2 字节:随机填充内容 +* 4 字节:指令部分前面所有内容的 FNV1a hash 数据部分 * N 字节:请求数据