v2ray-core/common/protocol/raw/commands.go

148 lines
3.6 KiB
Go

package raw
import (
"errors"
"io"
"github.com/v2ray/v2ray-core/common/alloc"
v2net "github.com/v2ray/v2ray-core/common/net"
"github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/common/serial"
"github.com/v2ray/v2ray-core/common/uuid"
"github.com/v2ray/v2ray-core/transport"
)
// ErrorCommandTypeMismatch is returned by a CommandFactory whose Marshal is
// handed a command of a type it does not serialize.
var ErrorCommandTypeMismatch = errors.New("Command type mismatch.")

// ErrorUnknownCommand is returned for a nil command, a command type that
// MarshalCommand does not support, or an unrecognized command ID.
var ErrorUnknownCommand = errors.New("Unknown command.")

// ErrorCommandTooLarge is returned when an encoded command does not fit the
// one-byte length field of the frame.
var ErrorCommandTooLarge = errors.New("Command too large.")
// MarshalCommand serializes command to writer as a raw command frame:
// one byte command ID, one byte length (payload length plus the 4-byte
// authentication tag), the 4-byte tag most-significant byte first, and then
// the payload produced by the command's CommandFactory.
//
// It returns ErrorUnknownCommand if command is nil or of an unsupported type,
// ErrorCommandTooLarge if the length does not fit in one byte, and otherwise
// any error reported by the factory or by writer.
func MarshalCommand(command interface{}, writer io.Writer) error {
	if command == nil {
		return ErrorUnknownCommand
	}

	var cmdId byte
	var factory CommandFactory
	switch command.(type) {
	case *protocol.CommandSwitchAccount:
		factory = new(CommandSwitchAccountFactory)
		cmdId = 1
	default:
		return ErrorUnknownCommand
	}

	buffer := alloc.NewSmallBuffer().Clear()
	// The buffer comes from a pool; hand it back once the payload is written.
	defer buffer.Release()

	if err := factory.Marshal(command, buffer); err != nil {
		return err
	}

	auth := Authenticate(buffer.Value)
	// The length field counts the 4-byte auth tag in addition to the payload.
	length := buffer.Len() + 4
	if length > 255 {
		return ErrorCommandTooLarge
	}

	header := []byte{cmdId, byte(length), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}
	if _, err := writer.Write(header); err != nil {
		return err
	}
	if _, err := writer.Write(buffer.Value); err != nil {
		return err
	}
	return nil
}
// UnmarshalCommand decodes the body of a raw command frame. data holds the
// 4-byte authentication tag followed by the payload. The tag is verified
// against Authenticate(payload) before cmdId is used to pick a decoder.
//
// It returns transport.ErrorCorruptedPacket for a body that is too short or
// fails authentication, and ErrorUnknownCommand for an unrecognized cmdId.
func UnmarshalCommand(cmdId byte, data []byte) (protocol.ResponseCommand, error) {
	const tagSize = 4

	// At least one payload byte must follow the tag.
	if len(data) <= tagSize {
		return nil, transport.ErrorCorruptedPacket
	}

	tag, payload := data[:tagSize], data[tagSize:]
	if Authenticate(payload) != serial.BytesLiteral(tag).Uint32Value() {
		return nil, transport.ErrorCorruptedPacket
	}

	switch cmdId {
	case 1:
		return new(CommandSwitchAccountFactory).Unmarshal(payload)
	}
	return nil, ErrorUnknownCommand
}
// CommandFactory converts one kind of command to and from its raw payload,
// i.e. the bytes that follow the authentication tag in a command frame.
type CommandFactory interface {
	// Unmarshal decodes payload into a command value.
	Unmarshal(payload []byte) (interface{}, error)
	// Marshal writes the payload encoding of cmd to w.
	Marshal(cmd interface{}, w io.Writer) error
}
// CommandSwitchAccountFactory is the CommandFactory for
// *protocol.CommandSwitchAccount payloads. It has no fields, so its zero
// value is ready to use.
type CommandSwitchAccountFactory struct{}
// Marshal writes the payload of a *protocol.CommandSwitchAccount to writer,
// field by field: a one-byte host length, the host string (skipped when
// empty), the port bytes, the ID bytes, the alter-ID bytes, one byte for the
// user level and one byte for ValidMin.
//
// It returns ErrorCommandTypeMismatch when command is not a
// *protocol.CommandSwitchAccount, ErrorCommandTooLarge when the host string
// does not fit the one-byte length prefix, and otherwise the first error
// reported by writer.
func (*CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
	cmd, ok := command.(*protocol.CommandSwitchAccount)
	if !ok {
		return ErrorCommandTypeMismatch
	}

	hostStr := ""
	if cmd.Host != nil {
		hostStr = cmd.Host.String()
	}
	// byte(len(hostStr)) would silently wrap for longer hosts, and a decoder
	// that reads exactly data[0] host bytes would then misparse every field.
	if len(hostStr) > 255 {
		return ErrorCommandTooLarge
	}

	fields := [][]byte{
		{byte(len(hostStr))},
		[]byte(hostStr),
		cmd.Port.Bytes(),
		cmd.ID.Bytes(),
		cmd.AlterIds.Bytes(),
		{byte(cmd.Level)},
		{cmd.ValidMin},
	}
	for _, field := range fields {
		if len(field) == 0 {
			// e.g. an empty host: nothing to write.
			continue
		}
		if _, err := writer.Write(field); err != nil {
			return err
		}
	}
	return nil
}
// Unmarshal decodes a CommandSwitchAccount payload laid out as: a one-byte
// host length, that many host bytes, 2 port bytes, 16 ID bytes, 2 alter-ID
// bytes, one level byte and one ValidMin byte. Any bytes after ValidMin are
// ignored.
//
// It returns transport.ErrorCorruptedPacket when data is shorter than this
// layout requires or when uuid.ParseBytes rejects the ID bytes.
func (*CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
	if len(data) == 0 {
		return nil, transport.ErrorCorruptedPacket
	}

	// Field offsets; everything after the host has a fixed width.
	lenHost := int(data[0])
	portStart := 1 + lenHost
	idStart := portStart + 2
	alterIdStart := idStart + 16
	levelStart := alterIdStart + 2
	timeStart := levelStart + 1
	end := timeStart + 1

	// A single bounds check covers every field, including the trailing
	// ValidMin byte at data[timeStart].
	if len(data) < end {
		return nil, transport.ErrorCorruptedPacket
	}

	cmd := new(protocol.CommandSwitchAccount)
	if lenHost > 0 {
		cmd.Host = v2net.ParseAddress(string(data[1:portStart]))
	}
	cmd.Port = v2net.PortFromBytes(data[portStart:idStart])

	var err error
	if cmd.ID, err = uuid.ParseBytes(data[idStart:alterIdStart]); err != nil {
		return nil, transport.ErrorCorruptedPacket
	}

	cmd.AlterIds = serial.BytesLiteral(data[alterIdStart:levelStart]).Uint16()
	cmd.Level = protocol.UserLevel(data[levelStart])
	cmd.ValidMin = data[timeStart]
	return cmd, nil
}