|
|
|
@ -9,6 +9,8 @@ import (
|
|
|
|
|
"sync" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"v2ray.com/core/common/dice" |
|
|
|
|
|
|
|
|
|
"golang.org/x/crypto/chacha20poly1305" |
|
|
|
|
"v2ray.com/core/common" |
|
|
|
|
"v2ray.com/core/common/bitmask" |
|
|
|
@ -103,6 +105,44 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func readAddress(buffer *buf.Buffer, reader io.Reader) (net.Address, net.Port, error) { |
|
|
|
|
var address net.Address |
|
|
|
|
var port net.Port |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil { |
|
|
|
|
return address, port, newError("failed to read port and address type").Base(err) |
|
|
|
|
} |
|
|
|
|
port = net.PortFromBytes(buffer.BytesRange(-3, -1)) |
|
|
|
|
|
|
|
|
|
addressType := protocol.AddressType(buffer.Byte(buffer.Len() - 1)) |
|
|
|
|
switch addressType { |
|
|
|
|
case protocol.AddressTypeIPv4: |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { |
|
|
|
|
return address, port, newError("failed to read IPv4 address").Base(err) |
|
|
|
|
} |
|
|
|
|
address = net.IPAddress(buffer.BytesFrom(-4)) |
|
|
|
|
case protocol.AddressTypeIPv6: |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { |
|
|
|
|
return address, port, newError("failed to read IPv6 address").Base(err) |
|
|
|
|
} |
|
|
|
|
address = net.IPAddress(buffer.BytesFrom(-16)) |
|
|
|
|
case protocol.AddressTypeDomain: |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { |
|
|
|
|
return address, port, newError("failed to read domain address").Base(err) |
|
|
|
|
} |
|
|
|
|
domainLength := int(buffer.Byte(buffer.Len() - 1)) |
|
|
|
|
if domainLength == 0 { |
|
|
|
|
return address, port, newError("zero length domain") |
|
|
|
|
} |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { |
|
|
|
|
return address, port, newError("failed to read domain address").Base(err) |
|
|
|
|
} |
|
|
|
|
address = net.DomainAddress(string(buffer.BytesFrom(-domainLength))) |
|
|
|
|
default: |
|
|
|
|
return address, port, newError("invalid address type", addressType) |
|
|
|
|
} |
|
|
|
|
return address, port, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { |
|
|
|
|
buffer := buf.New() |
|
|
|
|
defer buffer.Release() |
|
|
|
@ -128,7 +168,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
|
|
|
|
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv) |
|
|
|
|
decryptor := crypto.NewCryptionReader(aesStream, reader) |
|
|
|
|
|
|
|
|
|
if err := buffer.Reset(buf.ReadFullFrom(decryptor, 41)); err != nil { |
|
|
|
|
if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil { |
|
|
|
|
return nil, newError("failed to read request header").Base(err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -137,10 +177,6 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
|
|
|
|
Version: buffer.Byte(0), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if request.Version != Version { |
|
|
|
|
return nil, newError("invalid protocol version ", request.Version) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
s.requestBodyIV = append([]byte(nil), buffer.BytesRange(1, 17)...) // 16 bytes
|
|
|
|
|
s.requestBodyKey = append([]byte(nil), buffer.BytesRange(17, 33)...) // 16 bytes
|
|
|
|
|
var sid sessionId |
|
|
|
@ -159,33 +195,28 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
|
|
|
|
// 1 bytes reserved
|
|
|
|
|
request.Command = protocol.RequestCommand(buffer.Byte(37)) |
|
|
|
|
|
|
|
|
|
if request.Command != protocol.RequestCommandMux { |
|
|
|
|
request.Port = net.PortFromBytes(buffer.BytesRange(38, 40)) |
|
|
|
|
|
|
|
|
|
switch protocol.AddressType(buffer.Byte(40)) { |
|
|
|
|
case protocol.AddressTypeIPv4: |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil { |
|
|
|
|
return nil, newError("failed to read IPv4 address").Base(err) |
|
|
|
|
} |
|
|
|
|
request.Address = net.IPAddress(buffer.BytesFrom(-4)) |
|
|
|
|
case protocol.AddressTypeIPv6: |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 16)); err != nil { |
|
|
|
|
return nil, newError("failed to read IPv6 address").Base(err) |
|
|
|
|
} |
|
|
|
|
request.Address = net.IPAddress(buffer.BytesFrom(-16)) |
|
|
|
|
case protocol.AddressTypeDomain: |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 1)); err != nil { |
|
|
|
|
return nil, newError("failed to read domain address").Base(err) |
|
|
|
|
} |
|
|
|
|
domainLength := int(buffer.Byte(buffer.Len() - 1)) |
|
|
|
|
if domainLength == 0 { |
|
|
|
|
return nil, newError("zero length domain").Base(err) |
|
|
|
|
} |
|
|
|
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, domainLength)); err != nil { |
|
|
|
|
return nil, newError("failed to read domain address").Base(err) |
|
|
|
|
} |
|
|
|
|
request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength))) |
|
|
|
|
invalidRequest := false |
|
|
|
|
switch request.Command { |
|
|
|
|
case protocol.RequestCommandMux: |
|
|
|
|
request.Address = net.DomainAddress("v1.mux.cool") |
|
|
|
|
request.Port = 0 |
|
|
|
|
case protocol.RequestCommandTCP, protocol.RequestCommandUDP: |
|
|
|
|
if addr, port, err := readAddress(buffer, decryptor); err == nil { |
|
|
|
|
request.Address = addr |
|
|
|
|
request.Port = port |
|
|
|
|
} else { |
|
|
|
|
invalidRequest = true |
|
|
|
|
newError("failed to read address").Base(err).WriteToLog() |
|
|
|
|
} |
|
|
|
|
default: |
|
|
|
|
invalidRequest = true |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if invalidRequest { |
|
|
|
|
randomLen := dice.Roll(32) |
|
|
|
|
// Read random number of bytes for prevent detection.
|
|
|
|
|
buffer.AppendSupplier(buf.ReadFullFrom(decryptor, randomLen)) |
|
|
|
|
return nil, newError("invalid request") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if padingLen > 0 { |
|
|
|
|