diff --git a/common/protocol/encoding.go b/common/protocol/encoding.go index 5577b616..b3f4da62 100644 --- a/common/protocol/encoding.go +++ b/common/protocol/encoding.go @@ -10,7 +10,7 @@ type RequestEncoder interface { } type RequestDecoder interface { - DecodeRequestHeader(io.Reader) *RequestHeader + DecodeRequestHeader(io.Reader) (*RequestHeader, error) DecodeRequestBody(io.Reader) io.Reader } @@ -20,6 +20,6 @@ type ResponseEncoder interface { } type ResponseDecoder interface { - DecodeResponseHeader(io.Reader) *ResponseHeader + DecodeResponseHeader(io.Reader) (*ResponseHeader, error) DecodeResponseBody(io.Reader) io.Reader } diff --git a/common/protocol/errors.go b/common/protocol/errors.go new file mode 100644 index 00000000..8f02e270 --- /dev/null +++ b/common/protocol/errors.go @@ -0,0 +1,10 @@ +package protocol + +import ( + "errors" +) + +var ( + ErrorInvalidUser = errors.New("Invalid user.") + ErrorInvalidVersion = errors.New("Invalid version.") +) diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 887fbae4..8e240e1e 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -2,6 +2,8 @@ package protocol import ( v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/common/uuid" ) type RequestCommand byte @@ -31,3 +33,12 @@ type ResponseCommand interface{} type ResponseHeader struct { Command ResponseCommand } + +type CommandSwitchAccount struct { + Host v2net.Address + Port v2net.Port + ID *uuid.UUID + AlterIds serial.Uint16Literal + Level UserLevel + ValidMin byte +} diff --git a/common/protocol/raw/client.go b/common/protocol/raw/client.go index 18ba229a..df8f31d3 100644 --- a/common/protocol/raw/client.go +++ b/common/protocol/raw/client.go @@ -8,7 +8,9 @@ import ( "github.com/v2ray/v2ray-core/common/alloc" "github.com/v2ray/v2ray-core/common/crypto" + "github.com/v2ray/v2ray-core/common/log" "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/transport" ) func hashTimestamp(t protocol.Timestamp) []byte { @@ -27,6 +29,7 @@ type ClientSession struct { responseHeader byte responseBodyKey []byte responseBodyIV []byte + responseReader io.Reader idHash protocol.IDHash } @@ -38,6 +41,10 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { session.requestBodyKey = randomBytes[:16] session.requestBodyIV = randomBytes[16:32] session.responseHeader = randomBytes[32] + responseBodyKey := md5.Sum(session.requestBodyKey) + responseBodyIV := md5.Sum(session.requestBodyIV) + session.responseBodyKey = responseBodyKey[:] + session.responseBodyIV = responseBodyIV[:] session.idHash = idHash return session @@ -97,3 +104,42 @@ func (this *ClientSession) EncodeRequestBody(writer io.Writer) io.Writer { return crypto.NewCryptionWriter(aesStream, writer) } +func (this *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { + aesStream := crypto.NewAesDecryptionStream(this.responseBodyKey, this.responseBodyIV) + this.responseReader = crypto.NewCryptionReader(aesStream, reader) + + buffer := alloc.NewSmallBuffer() + defer buffer.Release() + + _, err := io.ReadFull(this.responseReader, buffer.Value[:4]) + if err != nil { + log.Error("Raw: Failed to read response header: ", err) + return nil, err + } + + if buffer.Value[0] != this.responseHeader { + log.Warning("Raw: Unexpected response header. Expecting %d, but actually %d", this.responseHeader, buffer.Value[0]) + return nil, transport.ErrorCorruptedPacket + } + + header := new(protocol.ResponseHeader) + + if buffer.Value[2] != 0 { + cmdId := buffer.Value[2] + dataLen := int(buffer.Value[3]) + _, err := io.ReadFull(this.responseReader, buffer.Value[:dataLen]) + if err != nil { + log.Error("Raw: Failed to read response command: ", err) + return nil, err + } + data := buffer.Value[:dataLen] + command, err := UnmarshalCommand(cmdId, data) + header.Command = command + } + + return header, nil +} + +func (this *ClientSession) DecodeResponseBody(reader io.Reader) io.Reader { + return this.responseReader +} diff --git a/common/protocol/raw/commands.go b/common/protocol/raw/commands.go new file mode 100644 index 00000000..57e12b7f --- /dev/null +++ b/common/protocol/raw/commands.go @@ -0,0 +1,115 @@ +package raw + +import ( + "errors" + "io" + + v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/common/uuid" + "github.com/v2ray/v2ray-core/transport" +) + +var ( + ErrorCommandTypeMismatch = errors.New("Command type mismatch.") + ErrorUnknownCommand = errors.New("Unknown command.") +) + +func MarshalCommand(command interface{}, writer io.Writer) error { + var factory CommandFactory + switch command.(type) { + case *protocol.CommandSwitchAccount: + factory = new(CommandSwitchAccountFactory) + default: + return ErrorUnknownCommand + } + return factory.Marshal(command, writer) +} + +func UnmarshalCommand(cmdId byte, data []byte) (protocol.ResponseCommand, error) { + var factory CommandFactory + switch cmdId { + case 1: + factory = new(CommandSwitchAccountFactory) + default: + return nil, ErrorUnknownCommand + } + return factory.Unmarshal(data) +} + +type CommandFactory interface { + Marshal(command interface{}, writer io.Writer) error + Unmarshal(data []byte) (interface{}, error) +} + +type CommandSwitchAccountFactory struct { +} + +func (this *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error { + cmd, ok := command.(*protocol.CommandSwitchAccount) + if !ok { + return ErrorCommandTypeMismatch + } + + hostStr := "" + if cmd.Host != nil { + hostStr = cmd.Host.String() + } + writer.Write([]byte{byte(len(hostStr))}) + + if len(hostStr) > 0 { + writer.Write([]byte(hostStr)) + } + + writer.Write(cmd.Port.Bytes()) + + idBytes := cmd.ID.Bytes() + writer.Write(idBytes) + + writer.Write(cmd.AlterIds.Bytes()) + writer.Write([]byte{byte(cmd.Level)}) + + writer.Write([]byte{cmd.ValidMin}) + return nil +} + +func (this *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) { + cmd := new(protocol.CommandSwitchAccount) + if len(data) == 0 { + return nil, transport.ErrorCorruptedPacket + } + lenHost := int(data[0]) + if len(data) < lenHost+1 { + return nil, transport.ErrorCorruptedPacket + } + if lenHost > 0 { + cmd.Host = v2net.ParseAddress(string(data[1 : 1+lenHost])) + } + portStart := 1 + lenHost + if len(data) < portStart+2 { + return nil, transport.ErrorCorruptedPacket + } + cmd.Port = v2net.PortFromBytes(data[portStart : portStart+2]) + idStart := portStart + 2 + if len(data) < idStart+16 { + return nil, transport.ErrorCorruptedPacket + } + cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16]) + alterIdStart := idStart + 16 + if len(data) < alterIdStart+2 { + return nil, transport.ErrorCorruptedPacket + } + cmd.AlterIds = serial.BytesLiteral(data[alterIdStart : alterIdStart+2]).Uint16() + levelStart := alterIdStart + 2 + if len(data) < levelStart+1 { + return nil, transport.ErrorCorruptedPacket + } + cmd.Level = protocol.UserLevel(data[levelStart]) + timeStart := levelStart + 1 + if len(data) < timeStart { + return nil, transport.ErrorCorruptedPacket + } + cmd.ValidMin = data[timeStart] + return cmd, nil +} diff --git a/common/protocol/raw/commands_test.go b/common/protocol/raw/commands_test.go new file mode 100644 index 00000000..d1198319 --- /dev/null +++ b/common/protocol/raw/commands_test.go @@ -0,0 +1,42 @@ +package raw_test + +import ( + "bytes" + "testing" + + netassert "github.com/v2ray/v2ray-core/common/net/testing/assert" + "github.com/v2ray/v2ray-core/common/protocol" + . "github.com/v2ray/v2ray-core/common/protocol/raw" + "github.com/v2ray/v2ray-core/common/uuid" + v2testing "github.com/v2ray/v2ray-core/testing" + "github.com/v2ray/v2ray-core/testing/assert" +) + +func TestSwitchAccount(t *testing.T) { + v2testing.Current(t) + + sa := &protocol.CommandSwitchAccount{ + Port: 1234, + ID: uuid.New(), + AlterIds: 1024, + Level: 128, + ValidMin: 16, + } + + buffer := bytes.NewBuffer(make([]byte, 0, 1024)) + err := MarshalCommand(sa, buffer) + assert.Error(err).IsNil() + + cmd, err := UnmarshalCommand(1, buffer.Bytes()) + assert.Error(err).IsNil() + + sa2, ok := cmd.(*protocol.CommandSwitchAccount) + assert.Bool(ok).IsTrue() + assert.Pointer(sa.Host).IsNil() + assert.Pointer(sa2.Host).IsNil() + netassert.Port(sa.Port).Equals(sa2.Port) + assert.String(sa.ID).Equals(sa2.ID.String()) + assert.Uint16(sa.AlterIds.Value()).Equals(sa2.AlterIds.Value()) + assert.Byte(byte(sa.Level)).Equals(byte(sa2.Level)) + assert.Byte(sa.ValidMin).Equals(sa2.ValidMin) +} diff --git a/common/protocol/raw/server.go b/common/protocol/raw/server.go new file mode 100644 index 00000000..0e72fb46 --- /dev/null +++ b/common/protocol/raw/server.go @@ -0,0 +1,143 @@ +package raw + +import ( + "crypto/md5" + "hash/fnv" + "io" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/common/crypto" + "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/transport" +) + +type ServerSession struct { + userValidator protocol.UserValidator + requestBodyKey []byte + requestBodyIV []byte + responseBodyKey []byte + responseBodyIV []byte + responseHeader byte + responseWriter io.Writer +} + +func (this *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { + buffer := alloc.NewSmallBuffer() + defer buffer.Release() + + _, err := io.ReadFull(reader, buffer.Value[:protocol.IDBytesLen]) + if err != nil { + log.Error("Raw: Failed to read request header: ", err) + return nil, err + } + + user, timestamp, valid := this.userValidator.Get(buffer.Value[:protocol.IDBytesLen]) + if !valid { + return nil, protocol.ErrorInvalidUser + } + + timestampHash := md5.New() + timestampHash.Write(hashTimestamp(timestamp)) + iv := timestampHash.Sum(nil) + aesStream := crypto.NewAesDecryptionStream(user.ID.CmdKey(), iv) + decryptor := crypto.NewCryptionReader(aesStream, reader) + + nBytes, err := io.ReadFull(decryptor, buffer.Value[:41]) + if err != nil { + log.Debug("Raw: Failed to read request header (", nBytes, " bytes): ", err) + return nil, err + } + bufferLen := nBytes + + request := &protocol.RequestHeader{ + User: user, + Version: buffer.Value[0], + } + + if request.Version != Version { + log.Warning("Raw: Invalid protocol version ", request.Version) + return nil, protocol.ErrorInvalidVersion + } + + this.requestBodyIV = append([]byte(nil), buffer.Value[1:17]...) // 16 bytes + this.requestBodyKey = append([]byte(nil), buffer.Value[17:33]...) // 16 bytes + this.responseHeader = buffer.Value[33] // 1 byte + request.Option = protocol.RequestOption(buffer.Value[34]) // 1 byte + 2 bytes reserved + request.Command = protocol.RequestCommand(buffer.Value[37]) + + request.Port = v2net.PortFromBytes(buffer.Value[38:40]) + + switch buffer.Value[40] { + case AddrTypeIPv4: + nBytes, err = io.ReadFull(decryptor, buffer.Value[41:45]) // 4 bytes + bufferLen += 4 + if err != nil { + log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err) + return nil, err + } + request.Address = v2net.IPAddress(buffer.Value[41:45]) + case AddrTypeIPv6: + nBytes, err = io.ReadFull(decryptor, buffer.Value[41:57]) // 16 bytes + bufferLen += 16 + if err != nil { + log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", nBytes, err) + return nil, err + } + request.Address = v2net.IPAddress(buffer.Value[41:57]) + case AddrTypeDomain: + nBytes, err = io.ReadFull(decryptor, buffer.Value[41:42]) + if err != nil { + log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", nBytes, err) + return nil, err + } + domainLength := int(buffer.Value[41]) + if domainLength == 0 { + return nil, transport.ErrorCorruptedPacket + } + nBytes, err = io.ReadFull(decryptor, buffer.Value[42:42+domainLength]) + if err != nil { + log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", nBytes, err) + return nil, err + } + bufferLen += 1 + domainLength + domainBytes := append([]byte(nil), buffer.Value[42:42+domainLength]...) + request.Address = v2net.DomainAddress(string(domainBytes)) + } + + nBytes, err = io.ReadFull(decryptor, buffer.Value[bufferLen:bufferLen+4]) + if err != nil { + log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", nBytes, err) + return nil, err + } + + fnv1a := fnv.New32a() + fnv1a.Write(buffer.Value[:bufferLen]) + actualHash := fnv1a.Sum32() + expectedHash := serial.BytesLiteral(buffer.Value[bufferLen : bufferLen+4]).Uint32Value() + + if actualHash != expectedHash { + return nil, transport.ErrorCorruptedPacket + } + + return request, nil +} + +func (this *ServerSession) DecodeRequestBody(reader io.Reader) io.Reader { + aesStream := crypto.NewAesDecryptionStream(this.requestBodyKey, this.requestBodyIV) + return crypto.NewCryptionReader(aesStream, reader) +} + +func (this *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) { + responseBodyKey := md5.Sum(this.requestBodyKey) + responseBodyIV := md5.Sum(this.requestBodyIV) + this.responseBodyKey = responseBodyKey[:] + this.requestBodyIV = responseBodyIV[:] + + aesStream := crypto.NewAesEncryptionStream(this.responseBodyKey, this.responseBodyIV) + encryptionWriter := crypto.NewCryptionWriter(aesStream, writer) + this.responseWriter = encryptionWriter + +}