diff --git a/proxy/vmess/protocol/io/validation.go b/proxy/vmess/protocol/io/validation.go new file mode 100644 index 00000000..ea47cf90 --- /dev/null +++ b/proxy/vmess/protocol/io/validation.go @@ -0,0 +1,75 @@ +package io + +import ( + "errors" + "hash/fnv" + "io" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/transport" +) + +var ( + TruncatedPayload = errors.New("Truncated payload.") +) + +type ValidationReader struct { + reader io.Reader + buffer *alloc.Buffer +} + +func NewValidationReader(reader io.Reader) *ValidationReader { + return &ValidationReader{ + reader: reader, + buffer: alloc.NewLargeBuffer().Clear(), + } +} + +func (this *ValidationReader) Read(data []byte) (int, error) { + nBytes, err := this.reader.Read(data) + if err != nil { + return nBytes, err + } + nBytesActual := 0 + dataActual := data[:] + for { + payload, rest, err := parsePayload(data) + if err != nil { + return nBytesActual, err + } + copy(dataActual, payload) + nBytesActual += len(payload) + dataActual = dataActual[nBytesActual:] + if len(rest) == 0 { + break + } + data = rest + } + return nBytesActual, nil +} + +func parsePayload(data []byte) (payload []byte, rest []byte, err error) { + dataLen := len(data) + if dataLen < 6 { + err = TruncatedPayload + return + } + payloadLen := int(data[0])<<8 + int(data[1]) + if dataLen < payloadLen+6 { + err = TruncatedPayload + return + } + + payload = data[6 : 6+payloadLen] + rest = data[6+payloadLen:] + + fnv1a := fnv.New32a() + fnv1a.Write(payload) + actualHash := fnv1a.Sum32() + expectedHash := uint32(data[2])<<24 + uint32(data[3])<<16 + uint32(data[4])<<8 + uint32(data[5]) + if actualHash != expectedHash { + err = transport.CorruptedPacket + return + } + return +}