diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 8e240e1e..ab4c6439 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -13,12 +13,16 @@ const ( RequestCommandUDP = RequestCommand(0x02) ) -type RequestOption byte - const ( RequestOptionChunkStream = RequestOption(0x01) ) +type RequestOption byte + +func (this RequestOption) IsChunkStream() bool { + return (this & RequestOptionChunkStream) == RequestOptionChunkStream +} + type RequestHeader struct { Version byte User *User diff --git a/common/protocol/raw/auth.go b/common/protocol/raw/auth.go new file mode 100644 index 00000000..1e3725fd --- /dev/null +++ b/common/protocol/raw/auth.go @@ -0,0 +1,11 @@ +package raw + +import ( + "hash/fnv" +) + +func Authenticate(b []byte) uint32 { + fnv1hash := fnv.New32a() + fnv1hash.Write(b) + return fnv1hash.Sum32() +} diff --git a/common/protocol/raw/client.go b/common/protocol/raw/client.go index df8f31d3..5dd4c250 100644 --- a/common/protocol/raw/client.go +++ b/common/protocol/raw/client.go @@ -51,15 +51,13 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { } func (this *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) { - buffer := alloc.NewSmallBuffer().Clear() - defer buffer.Release() - timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() idHash := this.idHash(header.User.AnyValidID().Bytes()) idHash.Write(timestamp.Bytes()) - idHash.Sum(buffer.Value) + writer.Write(idHash.Sum(nil)) - encryptionBegin := buffer.Len() + buffer := alloc.NewSmallBuffer().Clear() + defer buffer.Release() buffer.AppendBytes(Version) buffer.Append(this.requestBodyIV) @@ -80,20 +78,17 @@ func (this *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, w buffer.Append([]byte(header.Address.Domain())) } - encryptionEnd := buffer.Len() - fnv1a := fnv.New32a() - fnv1a.Write(buffer.Value[encryptionBegin:encryptionEnd]) + fnv1a.Write(buffer.Value) fnvHash := fnv1a.Sum32() buffer.AppendBytes(byte(fnvHash>>24), byte(fnvHash>>16), byte(fnvHash>>8), byte(fnvHash)) - encryptionEnd += 4 timestampHash := md5.New() timestampHash.Write(hashTimestamp(timestamp)) iv := timestampHash.Sum(nil) aesStream := crypto.NewAesEncryptionStream(header.User.ID.CmdKey(), iv) - aesStream.XORKeyStream(buffer.Value[encryptionBegin:encryptionEnd], buffer.Value[encryptionBegin:encryptionEnd]) + aesStream.XORKeyStream(buffer.Value, buffer.Value) writer.Write(buffer.Value) return diff --git a/common/protocol/raw/commands.go b/common/protocol/raw/commands.go index 57e12b7f..e4e2fb2c 100644 --- a/common/protocol/raw/commands.go +++ b/common/protocol/raw/commands.go @@ -4,6 +4,7 @@ import ( "errors" "io" + "github.com/v2ray/v2ray-core/common/alloc" v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/serial" @@ -14,20 +15,47 @@ import ( var ( ErrorCommandTypeMismatch = errors.New("Command type mismatch.") ErrorUnknownCommand = errors.New("Unknown command.") + ErrorCommandTooLarge = errors.New("Command too large.") ) func MarshalCommand(command interface{}, writer io.Writer) error { + var cmdId byte var factory CommandFactory switch command.(type) { case *protocol.CommandSwitchAccount: factory = new(CommandSwitchAccountFactory) + cmdId = 1 default: return ErrorUnknownCommand } - return factory.Marshal(command, writer) + + buffer := alloc.NewSmallBuffer() + err := factory.Marshal(command, buffer) + if err != nil { + return err + } + + auth := Authenticate(buffer.Value) + len := buffer.Len() + 4 + if len > 255 { + return ErrorCommandTooLarge + } + + writer.Write([]byte{cmdId, byte(len), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}) + writer.Write(buffer.Value) + return nil } func UnmarshalCommand(cmdId byte, data []byte) (protocol.ResponseCommand, error) { + if len(data) <= 4 { + return nil, transport.ErrorCorruptedPacket + } + expectedAuth := Authenticate(data[4:]) + actualAuth := serial.BytesLiteral(data[:4]).Uint32Value() + if expectedAuth != actualAuth { + return nil, transport.ErrorCorruptedPacket + } + var factory CommandFactory switch cmdId { case 1: @@ -35,7 +63,7 @@ func UnmarshalCommand(cmdId byte, data []byte) (protocol.ResponseCommand, error) default: return nil, ErrorUnknownCommand } - return factory.Unmarshal(data) + return factory.Unmarshal(data[4:]) } type CommandFactory interface { diff --git a/common/protocol/raw/server.go b/common/protocol/raw/server.go index 0e72fb46..ad3ee409 100644 --- a/common/protocol/raw/server.go +++ b/common/protocol/raw/server.go @@ -140,4 +140,10 @@ func (this *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, encryptionWriter := crypto.NewCryptionWriter(aesStream, writer) this.responseWriter = encryptionWriter + encryptionWriter.Write([]byte{this.responseHeader, 0x00}) + MarshalCommand(header.Command, encryptionWriter) +} + +func (this *ServerSession) EncodeResponseBody(writer io.Writer) io.Writer { + return this.responseWriter } diff --git a/proxy/vmess/outbound/command.go b/proxy/vmess/outbound/command.go index b269e3c1..b9c7f6b8 100644 --- a/proxy/vmess/outbound/command.go +++ b/proxy/vmess/outbound/command.go @@ -1,44 +1,20 @@ package outbound import ( - "hash/fnv" - - "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/common/protocol" proto "github.com/v2ray/v2ray-core/common/protocol" - "github.com/v2ray/v2ray-core/common/serial" - "github.com/v2ray/v2ray-core/proxy/vmess/command" ) -func (this *VMessOutboundHandler) handleSwitchAccount(cmd *command.SwitchAccount) { +func (this *VMessOutboundHandler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { user := proto.NewUser(proto.NewID(cmd.ID), cmd.Level, cmd.AlterIds.Value(), "") dest := v2net.TCPDestination(cmd.Host, cmd.Port) this.receiverManager.AddDetour(NewReceiver(dest, user), cmd.ValidMin) } -func (this *VMessOutboundHandler) handleCommand(dest v2net.Destination, cmdId byte, data []byte) { - if len(data) < 4 { - return - } - fnv1hash := fnv.New32a() - fnv1hash.Write(data[4:]) - actualHashValue := fnv1hash.Sum32() - expectedHashValue := serial.BytesLiteral(data[:4]).Uint32Value() - if actualHashValue != expectedHashValue { - return - } - data = data[4:] - cmd, err := command.CreateResponseCommand(cmdId) - if err != nil { - log.Warning("VMessOut: Unknown response command (", cmdId, "): ", err) - return - } - if err := cmd.Unmarshal(data); err != nil { - log.Warning("VMessOut: Failed to parse response command: ", err) - return - } +func (this *VMessOutboundHandler) handleCommand(dest v2net.Destination, cmd protocol.ResponseCommand) { switch typedCommand := cmd.(type) { - case *command.SwitchAccount: + case *protocol.CommandSwitchAccount: if typedCommand.Host == nil { typedCommand.Host = dest.Address() } diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 08fe50b2..0f01a713 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -1,20 +1,16 @@ package outbound import ( - "crypto/md5" - "crypto/rand" - "io" "net" "sync" - "time" "github.com/v2ray/v2ray-core/app" "github.com/v2ray/v2ray-core/common/alloc" - v2crypto "github.com/v2ray/v2ray-core/common/crypto" v2io "github.com/v2ray/v2ray-core/common/io" "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" proto "github.com/v2ray/v2ray-core/common/protocol" + raw "github.com/v2ray/v2ray-core/common/protocol/raw" "github.com/v2ray/v2ray-core/proxy" "github.com/v2ray/v2ray-core/proxy/internal" vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io" @@ -29,32 +25,25 @@ type VMessOutboundHandler struct { func (this *VMessOutboundHandler) Dispatch(firstPacket v2net.Packet, ray ray.OutboundRay) error { vNextAddress, vNextUser := this.receiverManager.PickReceiver() - command := protocol.CmdTCP + command := proto.RequestCommandTCP if firstPacket.Destination().IsUDP() { - command = protocol.CmdUDP + command = proto.RequestCommandUDP } - request := &protocol.VMessRequest{ + request := &proto.RequestHeader{ Version: protocol.Version, User: vNextUser, Command: command, Address: firstPacket.Destination().Address(), Port: firstPacket.Destination().Port(), } - if command == protocol.CmdUDP { - request.Option |= protocol.OptionChunk + if command == proto.RequestCommandUDP { + request.Option |= proto.RequestOptionChunkStream } - buffer := alloc.NewSmallBuffer() - defer buffer.Release() // Buffer is released after communication finishes. - io.ReadFull(rand.Reader, buffer.Value[:33]) // 16 + 16 + 1 - request.RequestIV = buffer.Value[:16] - request.RequestKey = buffer.Value[16:32] - request.ResponseHeader = buffer.Value[32] - return this.startCommunicate(request, vNextAddress, ray, firstPacket) } -func (this *VMessOutboundHandler) startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ray ray.OutboundRay, firstPacket v2net.Packet) error { +func (this *VMessOutboundHandler) startCommunicate(request *proto.RequestHeader, dest v2net.Destination, ray ray.OutboundRay, firstPacket v2net.Packet) error { var destIP net.IP if dest.Address().IsIPv4() || dest.Address().IsIPv6() { destIP = dest.Address().IP() @@ -87,8 +76,10 @@ func (this *VMessOutboundHandler) startCommunicate(request *protocol.VMessReques requestFinish.Lock() responseFinish.Lock() - go this.handleRequest(conn, request, firstPacket, input, &requestFinish) - go this.handleResponse(conn, request, dest, output, &responseFinish) + session := raw.NewClientSession(proto.DefaultIDHash) + + go this.handleRequest(session, conn, request, firstPacket, input, &requestFinish) + go this.handleResponse(session, conn, request, dest, output, &responseFinish) requestFinish.Lock() conn.CloseWrite() @@ -96,18 +87,11 @@ func (this *VMessOutboundHandler) startCommunicate(request *protocol.VMessReques return nil } -func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol.VMessRequest, firstPacket v2net.Packet, input <-chan *alloc.Buffer, finish *sync.Mutex) { +func (this *VMessOutboundHandler) handleRequest(session *raw.ClientSession, conn net.Conn, request *proto.RequestHeader, firstPacket v2net.Packet, input <-chan *alloc.Buffer, finish *sync.Mutex) { defer finish.Unlock() - aesStream := v2crypto.NewAesEncryptionStream(request.RequestKey[:], request.RequestIV[:]) - encryptRequestWriter := v2crypto.NewCryptionWriter(aesStream, conn) - buffer := alloc.NewBuffer().Clear() - defer buffer.Release() - buffer, err := request.ToBytes(proto.NewTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), buffer) - if err != nil { - log.Error("VMessOut: Failed to serialize VMess request: ", err) - return - } + writer := v2io.NewBufferedWriter(conn) + session.EncodeRequestHeader(request, writer) // Send first packet of payload together with request, in favor of small requests. firstChunk := firstPacket.Chunk() @@ -122,23 +106,19 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol return } - if request.IsChunkStream() { + if request.Option.IsChunkStream() { vmessio.Authenticate(firstChunk) } - aesStream.XORKeyStream(firstChunk.Value, firstChunk.Value) - buffer.Append(firstChunk.Value) + bodyWriter := session.EncodeRequestBody(writer) + bodyWriter.Write(firstChunk.Value) firstChunk.Release() - _, err = conn.Write(buffer.Value) - if err != nil { - log.Error("VMessOut: Failed to write VMess request: ", err) - return - } + writer.SetCached(false) if moreChunks { - var streamWriter v2io.Writer = v2io.NewAdaptiveWriter(encryptRequestWriter) - if request.IsChunkStream() { + var streamWriter v2io.Writer = v2io.NewAdaptiveWriter(bodyWriter) + if request.Option.IsChunkStream() { streamWriter = vmessio.NewAuthChunkWriter(streamWriter) } v2io.ChanToWriter(streamWriter, input) @@ -150,48 +130,30 @@ func headerMatch(request *protocol.VMessRequest, responseHeader byte) bool { return request.ResponseHeader == responseHeader } -func (this *VMessOutboundHandler) handleResponse(conn net.Conn, request *protocol.VMessRequest, dest v2net.Destination, output chan<- *alloc.Buffer, finish *sync.Mutex) { +func (this *VMessOutboundHandler) handleResponse(session *raw.ClientSession, conn net.Conn, request *proto.RequestHeader, dest v2net.Destination, output chan<- *alloc.Buffer, finish *sync.Mutex) { defer finish.Unlock() defer close(output) - responseKey := md5.Sum(request.RequestKey[:]) - responseIV := md5.Sum(request.RequestIV[:]) - aesStream := v2crypto.NewAesDecryptionStream(responseKey[:], responseIV[:]) - decryptResponseReader := v2crypto.NewCryptionReader(aesStream, conn) - - buffer := alloc.NewSmallBuffer() - defer buffer.Release() - _, err := io.ReadFull(decryptResponseReader, buffer.Value[:4]) + reader := v2io.NewBufferedReader(conn) + header, err := session.DecodeResponseHeader(reader) if err != nil { - log.Error("VMessOut: Failed to read VMess response (", buffer.Len(), " bytes): ", err) - return - } - if !headerMatch(request, buffer.Value[0]) { - log.Warning("VMessOut: unexepcted response header. The connection is probably hijacked.") + log.Warning("VMessOut: Failed to read response: ", err) return } + go this.handleCommand(dest, header.Command) - if buffer.Value[2] != 0 { - command := buffer.Value[2] - dataLen := int(buffer.Value[3]) - _, err := io.ReadFull(decryptResponseReader, buffer.Value[:dataLen]) - if err != nil { - log.Error("VMessOut: Failed to read response command: ", err) - return - } - data := buffer.Value[:dataLen] - go this.handleCommand(dest, command, data) - } + reader.SetCached(false) + decryptReader := session.DecodeResponseBody(conn) - var reader v2io.Reader - if request.IsChunkStream() { - reader = vmessio.NewAuthChunkReader(decryptResponseReader) + var bodyReader v2io.Reader + if request.Option.IsChunkStream() { + bodyReader = vmessio.NewAuthChunkReader(decryptReader) } else { - reader = v2io.NewAdaptiveReader(decryptResponseReader) + bodyReader = v2io.NewAdaptiveReader(decryptReader) } - v2io.ReaderToChan(output, reader) + v2io.ReaderToChan(output, bodyReader) return }