diff --git a/common/serial/bytes.go b/common/serial/bytes.go new file mode 100644 index 00000000..28818768 --- /dev/null +++ b/common/serial/bytes.go @@ -0,0 +1,23 @@ +package serial + +type Bytes interface { + Bytes() []byte +} + +type BytesLiteral []byte + +func (this BytesLiteral) Value() []byte { + return []byte(this) +} + +func (this BytesLiteral) Int64Value() int64 { + value := this.Value() + return int64(value[0])<<56 + + int64(value[1])<<48 + + int64(value[2])<<40 + + int64(value[3])<<32 + + int64(value[4])<<24 + + int64(value[5])<<16 + + int64(value[6])<<8 + + int64(value[7]) +} diff --git a/common/serial/numbers.go b/common/serial/numbers.go index 5e0d8389..fd62914e 100644 --- a/common/serial/numbers.go +++ b/common/serial/numbers.go @@ -31,3 +31,27 @@ func (this IntLiteral) String() string { func (this IntLiteral) Value() int { return int(this) } + +type Int64Literal int64 + +func (this Int64Literal) String() string { + return strconv.FormatInt(this.Value(), 10) +} + +func (this Int64Literal) Value() int64 { + return int64(this) +} + +func (this Int64Literal) Bytes() []byte { + value := this.Value() + return []byte{ + byte(value >> 56), + byte(value >> 48), + byte(value >> 40), + byte(value >> 32), + byte(value >> 24), + byte(value >> 16), + byte(value >> 8), + byte(value), + } +} diff --git a/common/uuid/uuid.go b/common/uuid/uuid.go index 39da16e9..9ac34086 100644 --- a/common/uuid/uuid.go +++ b/common/uuid/uuid.go @@ -13,7 +13,7 @@ var ( ) type UUID struct { - byteValue [16]byte + byteValue []byte stringValue string } @@ -25,7 +25,7 @@ func (this *UUID) Bytes() []byte { return this.byteValue[:] } -func bytesToString(bytes [16]byte) string { +func bytesToString(bytes []byte) string { result := hex.EncodeToString(bytes[0 : byteGroups[0]/2]) start := byteGroups[0] / 2 for i := 1; i < len(byteGroups); i++ { @@ -38,12 +38,20 @@ func bytesToString(bytes [16]byte) string { } func New() *UUID { - var bytes [16]byte - rand.Read(bytes[:]) + bytes := make([]byte, 16) + rand.Read(bytes) + uuid, _ := ParseBytes(bytes) + return uuid +} + +func ParseBytes(bytes []byte) (*UUID, error) { + if len(bytes) != 16 { + return nil, InvalidID + } return &UUID{ byteValue: bytes, stringValue: bytesToString(bytes), - } + }, nil } func ParseString(str string) (*UUID, error) { @@ -52,8 +60,10 @@ func ParseString(str string) (*UUID, error) { return nil, InvalidID } - var uuid UUID - uuid.stringValue = str + uuid := &UUID{ + byteValue: make([]byte, 16), + stringValue: str, + } b := uuid.byteValue[:] for _, byteGroup := range byteGroups { @@ -71,5 +81,5 @@ func ParseString(str string) (*UUID, error) { b = b[byteGroup/2:] } - return &uuid, nil + return uuid, nil } diff --git a/common/uuid/uuid_test.go b/common/uuid/uuid_test.go index f1a4b787..04bd9f6f 100644 --- a/common/uuid/uuid_test.go +++ b/common/uuid/uuid_test.go @@ -8,6 +8,17 @@ import ( "github.com/v2ray/v2ray-core/testing/assert" ) +func TestParseBytes(t *testing.T) { + v2testing.Current(t) + + str := "2418d087-648d-4990-86e8-19dca1d006d3" + bytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3} + + uuid, err := ParseBytes(bytes) + assert.Error(err).IsNil() + assert.String(uuid).Equals(str) +} + func TestParseString(t *testing.T) { v2testing.Current(t) diff --git a/proxy/vmess/command/accounts.go b/proxy/vmess/command/accounts.go new file mode 100644 index 00000000..99af276c --- /dev/null +++ b/proxy/vmess/command/accounts.go @@ -0,0 +1,40 @@ +package command + +import ( + "io" + "time" + + "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/common/uuid" + "github.com/v2ray/v2ray-core/transport" +) + +func init() { + RegisterResponseCommand(1, func() Command { return new(SwitchAccount) }) +} + +// Size: 16 + 8 = 24 +type SwitchAccount struct { + ID *uuid.UUID + ValidUntil time.Time +} + +func (this *SwitchAccount) Marshal(writer io.Writer) (int, error) { + idBytes := this.ID.Bytes() + timestamp := this.ValidUntil.Unix() + timeBytes := serial.Int64Literal(timestamp).Bytes() + + writer.Write(idBytes) + writer.Write(timeBytes) + + return 24, nil +} + +func (this *SwitchAccount) Unmarshal(data []byte) error { + if len(data) != 24 { + return transport.CorruptedPacket + } + this.ID, _ = uuid.ParseBytes(data[0:16]) + this.ValidUntil = time.Unix(serial.BytesLiteral(data[16:24]).Int64Value(), 0) + return nil +} diff --git a/proxy/vmess/command/accounts_test.go b/proxy/vmess/command/accounts_test.go new file mode 100644 index 00000000..9aee1a04 --- /dev/null +++ b/proxy/vmess/command/accounts_test.go @@ -0,0 +1,35 @@ +package command_test + +import ( + "bytes" + "testing" + "time" + + "github.com/v2ray/v2ray-core/common/uuid" + . "github.com/v2ray/v2ray-core/proxy/vmess/command" + v2testing "github.com/v2ray/v2ray-core/testing" + "github.com/v2ray/v2ray-core/testing/assert" +) + +func TestSwitchAccount(t *testing.T) { + v2testing.Current(t) + + sa := &SwitchAccount{ + ID: uuid.New(), + ValidUntil: time.Now(), + } + + cmd, err := CreateResponseCommand(1) + assert.Error(err).IsNil() + + buffer := bytes.NewBuffer(make([]byte, 0, 1024)) + nBytes, err := sa.Marshal(buffer) + assert.Error(err).IsNil() + assert.Int(nBytes).Equals(buffer.Len()) + + cmd.Unmarshal(buffer.Bytes()) + sa2, ok := cmd.(*SwitchAccount) + assert.Bool(ok).IsTrue() + assert.String(sa.ID).Equals(sa2.ID.String()) + assert.Int64(sa.ValidUntil.Unix()).Equals(sa2.ValidUntil.Unix()) +} diff --git a/proxy/vmess/command/command.go b/proxy/vmess/command/command.go new file mode 100644 index 00000000..08ced488 --- /dev/null +++ b/proxy/vmess/command/command.go @@ -0,0 +1,17 @@ +package command + +import ( + "errors" + "io" +) + +var ( + ErrorNoSuchCommand = errors.New("No such command.") +) + +type Command interface { + Marshal(io.Writer) (int, error) + Unmarshal([]byte) error +} + +type CommandCreator func() Command diff --git a/proxy/vmess/command/response.go b/proxy/vmess/command/response.go index 594cffd3..03d8a509 100644 --- a/proxy/vmess/command/response.go +++ b/proxy/vmess/command/response.go @@ -1,3 +1,18 @@ package command -type ResponseCmd byte +var ( + cmdCache = make(map[byte]CommandCreator) +) + +func RegisterResponseCommand(id byte, cmdFactory CommandCreator) error { + cmdCache[id] = cmdFactory + return nil +} + +func CreateResponseCommand(id byte) (Command, error) { + creator, found := cmdCache[id] + if !found { + return nil, ErrorNoSuchCommand + } + return creator(), nil +} diff --git a/testing/assert/bytessubject.go b/testing/assert/bytessubject.go index 07c86aaf..28940e0e 100644 --- a/testing/assert/bytessubject.go +++ b/testing/assert/bytessubject.go @@ -33,3 +33,9 @@ func (subject *BytesSubject) Equals(expectation []byte) { subject.Fail("is equal to", expectation) } } + +func (subject *BytesSubject) NotEquals(expectation []byte) { + if bytes.Equal(subject.value, expectation) { + subject.Fail("is not equal to", expectation) + } +} diff --git a/testing/assert/stringsubject.go b/testing/assert/stringsubject.go index 1a8f3e48..8be4cec2 100644 --- a/testing/assert/stringsubject.go +++ b/testing/assert/stringsubject.go @@ -34,6 +34,12 @@ func (subject *StringSubject) Equals(expectation string) { } } +func (subject *StringSubject) NotEquals(expectation string) { + if subject.value.String() == expectation { + subject.Fail(subject.DisplayString(), "is not equal to ", serial.StringLiteral(expectation)) + } +} + func (subject *StringSubject) Contains(substring serial.String) { if !strings.Contains(subject.value.String(), substring.String()) { subject.Fail(subject.DisplayString(), "contains", substring)