diff --git a/common/net/address.go b/common/net/address.go index d3eb9b85..e36ab6a7 100644 --- a/common/net/address.go +++ b/common/net/address.go @@ -19,6 +19,14 @@ type Address interface { String() string // String representation of this Address } +func ParseAddress(addr string) Address { + ip := net.ParseIP(addr) + if ip != nil { + return IPAddress(ip) + } + return DomainAddress(addr) +} + func allZeros(data []byte) bool { for _, v := range data { if v != 0 { diff --git a/common/net/address_json.go b/common/net/address_json.go index 60ba8e21..2434e059 100644 --- a/common/net/address_json.go +++ b/common/net/address_json.go @@ -16,11 +16,6 @@ func (this *AddressJson) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &rawStr); err != nil { return err } - ip := net.ParseIP(rawStr) - if ip != nil { - this.Address = IPAddress(ip) - } else { - this.Address = DomainAddress(rawStr) - } + this.Address = ParseAddress(rawStr) return nil } diff --git a/common/net/port.go b/common/net/port.go index e22d58df..fc791e02 100644 --- a/common/net/port.go +++ b/common/net/port.go @@ -7,7 +7,7 @@ import ( type Port serial.Uint16Literal func PortFromBytes(port []byte) Port { - return Port(uint16(port[0])<<8 + uint16(port[1])) + return Port(serial.ParseUint16(port)) } func (this Port) Value() uint16 { @@ -15,7 +15,7 @@ func (this Port) Value() uint16 { } func (this Port) Bytes() []byte { - return []byte{byte(this >> 8), byte(this)} + return serial.Uint16Literal(this).Bytes() } func (this Port) String() string { diff --git a/common/serial/numbers.go b/common/serial/numbers.go index fd62914e..235dab18 100644 --- a/common/serial/numbers.go +++ b/common/serial/numbers.go @@ -10,6 +10,17 @@ type Uint16 interface { type Uint16Literal uint16 +func ParseUint16(data []byte) Uint16Literal { + switch len(data) { + case 0: + return Uint16Literal(0) + case 1: + return Uint16Literal(uint16(data[0])) + default: + return Uint16Literal(uint16(data[0])<<8 + uint16(data[1])) + } +} + func (this Uint16Literal) String() string { return strconv.Itoa(int(this)) } @@ -18,6 +29,10 @@ func (this Uint16Literal) Value() uint16 { return uint16(this) } +func (this Uint16Literal) Bytes() []byte { + return []byte{byte(this >> 8), byte(this)} +} + type Int interface { Value() int } diff --git a/proxy/vmess/command/accounts.go b/proxy/vmess/command/accounts.go index 99af276c..29a94375 100644 --- a/proxy/vmess/command/accounts.go +++ b/proxy/vmess/command/accounts.go @@ -4,6 +4,7 @@ import ( "io" "time" + v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/serial" "github.com/v2ray/v2ray-core/common/uuid" "github.com/v2ray/v2ray-core/transport" @@ -13,28 +14,79 @@ func init() { RegisterResponseCommand(1, func() Command { return new(SwitchAccount) }) } -// Size: 16 + 8 = 24 +// Structure +// 1 byte: host len N +// N bytes: host +// 2 bytes: port +// 16 bytes: uuid +// 2 bytes: alterid +// 8 bytes: time type SwitchAccount struct { + Host v2net.Address + Port v2net.Port ID *uuid.UUID + AlterIds serial.Uint16Literal ValidUntil time.Time } func (this *SwitchAccount) Marshal(writer io.Writer) (int, error) { + outBytes := 0 + hostStr := "" + if this.Host != nil { + hostStr = this.Host.String() + } + writer.Write([]byte{byte(len(hostStr))}) + outBytes++ + + if len(hostStr) > 0 { + writer.Write([]byte(hostStr)) + outBytes += len(hostStr) + } + + writer.Write(this.Port.Bytes()) + outBytes += 2 + idBytes := this.ID.Bytes() + writer.Write(idBytes) + outBytes += len(idBytes) + + writer.Write(this.AlterIds.Bytes()) + outBytes += 2 + timestamp := this.ValidUntil.Unix() timeBytes := serial.Int64Literal(timestamp).Bytes() - writer.Write(idBytes) writer.Write(timeBytes) + outBytes += len(timeBytes) - return 24, nil + return outBytes, nil } func (this *SwitchAccount) Unmarshal(data []byte) error { - if len(data) != 24 { + lenHost := int(data[0]) + if len(data) < lenHost+1 { return transport.CorruptedPacket } - this.ID, _ = uuid.ParseBytes(data[0:16]) - this.ValidUntil = time.Unix(serial.BytesLiteral(data[16:24]).Int64Value(), 0) + this.Host = v2net.ParseAddress(string(data[1 : 1+lenHost])) + portStart := 1 + lenHost + if len(data) < portStart+2 { + return transport.CorruptedPacket + } + this.Port = v2net.PortFromBytes(data[portStart : portStart+2]) + idStart := portStart + 2 + if len(data) < idStart+16 { + return transport.CorruptedPacket + } + this.ID, _ = uuid.ParseBytes(data[idStart : idStart+16]) + alterIdStart := idStart + 16 + if len(data) < alterIdStart+2 { + return transport.CorruptedPacket + } + this.AlterIds = serial.ParseUint16(data[alterIdStart : alterIdStart+2]) + timeStart := alterIdStart + 2 + if len(data) < timeStart+8 { + return transport.CorruptedPacket + } + this.ValidUntil = time.Unix(serial.BytesLiteral(data[timeStart:timeStart+8]).Int64Value(), 0) return nil } diff --git a/proxy/vmess/command/accounts_test.go b/proxy/vmess/command/accounts_test.go index 9aee1a04..a699e5f9 100644 --- a/proxy/vmess/command/accounts_test.go +++ b/proxy/vmess/command/accounts_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + netassert "github.com/v2ray/v2ray-core/common/net/testing/assert" "github.com/v2ray/v2ray-core/common/uuid" . "github.com/v2ray/v2ray-core/proxy/vmess/command" v2testing "github.com/v2ray/v2ray-core/testing" @@ -15,7 +16,9 @@ func TestSwitchAccount(t *testing.T) { v2testing.Current(t) sa := &SwitchAccount{ + Port: 1234, ID: uuid.New(), + AlterIds: 1024, ValidUntil: time.Now(), } @@ -30,6 +33,8 @@ func TestSwitchAccount(t *testing.T) { cmd.Unmarshal(buffer.Bytes()) sa2, ok := cmd.(*SwitchAccount) assert.Bool(ok).IsTrue() + netassert.Port(sa.Port).Equals(sa2.Port) assert.String(sa.ID).Equals(sa2.ID.String()) + assert.Uint16(sa.AlterIds.Value()).Equals(sa2.AlterIds.Value()) assert.Int64(sa.ValidUntil.Unix()).Equals(sa2.ValidUntil.Unix()) } diff --git a/proxy/vmess/inbound/config.go b/proxy/vmess/inbound/config.go index 1d085f95..dcdbccea 100644 --- a/proxy/vmess/inbound/config.go +++ b/proxy/vmess/inbound/config.go @@ -4,6 +4,15 @@ import ( "github.com/v2ray/v2ray-core/proxy/vmess" ) +type DetourConfig struct { + ToTag string +} + +type FeaturesConfig struct { + Detour *DetourConfig +} + type Config struct { AllowedUsers []*vmess.User + Features *FeaturesConfig } diff --git a/proxy/vmess/inbound/config_json.go b/proxy/vmess/inbound/config_json.go index e62ab1a5..dfcc52bb 100644 --- a/proxy/vmess/inbound/config_json.go +++ b/proxy/vmess/inbound/config_json.go @@ -9,18 +9,49 @@ import ( "github.com/v2ray/v2ray-core/proxy/vmess" ) +func (this *DetourConfig) UnmarshalJSON(data []byte) error { + type JsonDetourConfig struct { + ToTag string `json:"to"` + } + jsonConfig := new(JsonDetourConfig) + if err := json.Unmarshal(data, jsonConfig); err != nil { + return err + } + this.ToTag = jsonConfig.ToTag + return nil +} + +func (this *FeaturesConfig) UnmarshalJSON(data []byte) error { + type JsonFeaturesConfig struct { + Detour *DetourConfig `json:"detour"` + } + jsonConfig := new(JsonFeaturesConfig) + if err := json.Unmarshal(data, jsonConfig); err != nil { + return err + } + this.Detour = jsonConfig.Detour + return nil +} + +func (this *Config) UnmarshalJSON(data []byte) error { + type JsonConfig struct { + Users []*vmess.User `json:"clients"` + Features *FeaturesConfig `json:"features"` + } + jsonConfig := new(JsonConfig) + if err := json.Unmarshal(data, jsonConfig); err != nil { + return err + } + this.AllowedUsers = jsonConfig.Users + this.Features = jsonConfig.Features + return nil +} + func init() { config.RegisterInboundConnectionConfig("vmess", func(data []byte) (interface{}, error) { - type JsonConfig struct { - Users []*vmess.User `json:"clients"` - } - jsonConfig := new(JsonConfig) - if err := json.Unmarshal(data, jsonConfig); err != nil { - return nil, err - } - return &Config{ - AllowedUsers: jsonConfig.Users, - }, nil + config := new(Config) + err := json.Unmarshal(data, config) + return config, err }) } diff --git a/proxy/vmess/user_json.go b/proxy/vmess/user_json.go index 3d538225..23038332 100644 --- a/proxy/vmess/user_json.go +++ b/proxy/vmess/user_json.go @@ -13,7 +13,7 @@ func (u *User) UnmarshalJSON(data []byte) error { IdString string `json:"id"` EmailString string `json:"email"` LevelInt int `json:"level"` - AlterIdCount int `json:"alterId"` + AlterIdCount uint16 `json:"alterId"` } var rawUserValue rawUser if err := json.Unmarshal(data, &rawUserValue); err != nil {