mirror of https://github.com/v2ray/v2ray-core
				
				
				
			rewrite vmess encoding using buf
							parent
							
								
									5901192a58
								
							
						
					
					
						commit
						02685094d3
					
				|  | @ -71,57 +71,61 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ | |||
| 	common.Must2(idHash.Write(timestamp.Bytes(nil))) | ||||
| 	common.Must2(writer.Write(idHash.Sum(nil))) | ||||
| 
 | ||||
| 	buffer := make([]byte, 0, 512) | ||||
| 	buffer = append(buffer, Version) | ||||
| 	buffer = append(buffer, c.requestBodyIV...) | ||||
| 	buffer = append(buffer, c.requestBodyKey...) | ||||
| 	buffer = append(buffer, c.responseHeader, byte(header.Option)) | ||||
| 	buffer := buf.New() | ||||
| 	defer buffer.Release() | ||||
| 
 | ||||
| 	buffer.AppendBytes(Version) | ||||
| 	buffer.Append(c.requestBodyIV) | ||||
| 	buffer.Append(c.requestBodyKey) | ||||
| 	buffer.AppendBytes(c.responseHeader, byte(header.Option)) | ||||
| 
 | ||||
| 	padingLen := dice.Roll(16) | ||||
| 	if header.Security.Is(protocol.SecurityType_LEGACY) { | ||||
| 		// Disable padding in legacy mode for a smooth transition.
 | ||||
| 		padingLen = 0 | ||||
| 	} | ||||
| 	security := byte(padingLen<<4) | byte(header.Security) | ||||
| 	buffer = append(buffer, security, byte(0), byte(header.Command)) | ||||
| 	buffer.AppendBytes(security, byte(0), byte(header.Command)) | ||||
| 
 | ||||
| 	if header.Command != protocol.RequestCommandMux { | ||||
| 		buffer = header.Port.Bytes(buffer) | ||||
| 		common.Must(buffer.AppendSupplier(serial.WriteUint16(header.Port.Value()))) | ||||
| 
 | ||||
| 		switch header.Address.Family() { | ||||
| 		case net.AddressFamilyIPv4: | ||||
| 			buffer = append(buffer, byte(protocol.AddressTypeIPv4)) | ||||
| 			buffer = append(buffer, header.Address.IP()...) | ||||
| 			buffer.AppendBytes(byte(protocol.AddressTypeIPv4)) | ||||
| 			buffer.Append(header.Address.IP()) | ||||
| 		case net.AddressFamilyIPv6: | ||||
| 			buffer = append(buffer, byte(protocol.AddressTypeIPv6)) | ||||
| 			buffer = append(buffer, header.Address.IP()...) | ||||
| 			buffer.AppendBytes(byte(protocol.AddressTypeIPv6)) | ||||
| 			buffer.Append(header.Address.IP()) | ||||
| 		case net.AddressFamilyDomain: | ||||
| 			domain := header.Address.Domain() | ||||
| 			if protocol.IsDomainTooLong(domain) { | ||||
| 				return newError("long domain not supported: ", domain) | ||||
| 			} | ||||
| 			nDomain := len(domain) | ||||
| 			buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(nDomain)) | ||||
| 			buffer = append(buffer, domain...) | ||||
| 			buffer.AppendBytes(byte(protocol.AddressTypeDomain), byte(nDomain)) | ||||
| 			common.Must(buffer.AppendSupplier(serial.WriteString(domain))) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if padingLen > 0 { | ||||
| 		pading := make([]byte, padingLen) | ||||
| 		common.Must2(rand.Read(pading)) | ||||
| 		buffer = append(buffer, pading...) | ||||
| 		common.Must(buffer.AppendSupplier(buf.ReadFullFrom(rand.Reader, padingLen))) | ||||
| 	} | ||||
| 
 | ||||
| 	fnv1a := fnv.New32a() | ||||
| 	common.Must2(fnv1a.Write(buffer)) | ||||
| 	common.Must2(fnv1a.Write(buffer.Bytes())) | ||||
| 
 | ||||
| 	buffer = fnv1a.Sum(buffer) | ||||
| 	common.Must(buffer.AppendSupplier(func(b []byte) (int, error) { | ||||
| 		fnv1a.Sum(b[:0]) | ||||
| 		return fnv1a.Size(), nil | ||||
| 	})) | ||||
| 
 | ||||
| 	timestampHash := md5.New() | ||||
| 	common.Must2(timestampHash.Write(hashTimestamp(timestamp))) | ||||
| 	iv := timestampHash.Sum(nil) | ||||
| 	aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv) | ||||
| 	aesStream.XORKeyStream(buffer, buffer) | ||||
| 	common.Must2(writer.Write(buffer)) | ||||
| 	aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) | ||||
| 	common.Must2(writer.Write(buffer.Bytes())) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
|  | @ -197,32 +201,31 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon | |||
| 	aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey, c.responseBodyIV) | ||||
| 	c.responseReader = crypto.NewCryptionReader(aesStream, reader) | ||||
| 
 | ||||
| 	buffer := make([]byte, 256) | ||||
| 	buffer := buf.New() | ||||
| 	defer buffer.Release() | ||||
| 
 | ||||
| 	_, err := io.ReadFull(c.responseReader, buffer[:4]) | ||||
| 	if err != nil { | ||||
| 	if err := buffer.AppendSupplier(buf.ReadFullFrom(c.responseReader, 4)); err != nil { | ||||
| 		log.Trace(newError("failed to read response header").Base(err)) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if buffer[0] != c.responseHeader { | ||||
| 		return nil, newError("unexpected response header. Expecting ", int(c.responseHeader), " but actually ", int(buffer[0])) | ||||
| 	if buffer.Byte(0) != c.responseHeader { | ||||
| 		return nil, newError("unexpected response header. Expecting ", int(c.responseHeader), " but actually ", int(buffer.Byte(0))) | ||||
| 	} | ||||
| 
 | ||||
| 	header := &protocol.ResponseHeader{ | ||||
| 		Option: bitmask.Byte(buffer[1]), | ||||
| 		Option: bitmask.Byte(buffer.Byte(1)), | ||||
| 	} | ||||
| 
 | ||||
| 	if buffer[2] != 0 { | ||||
| 		cmdID := buffer[2] | ||||
| 		dataLen := int(buffer[3]) | ||||
| 		_, err := io.ReadFull(c.responseReader, buffer[:dataLen]) | ||||
| 		if err != nil { | ||||
| 	if buffer.Byte(2) != 0 { | ||||
| 		cmdID := buffer.Byte(2) | ||||
| 		dataLen := int(buffer.Byte(3)) | ||||
| 
 | ||||
| 		if err := buffer.Reset(buf.ReadFullFrom(c.responseReader, dataLen)); err != nil { | ||||
| 			log.Trace(newError("failed to read response command").Base(err)) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		data := buffer[:dataLen] | ||||
| 		command, err := UnmarshalCommand(cmdID, data) | ||||
| 		command, err := UnmarshalCommand(cmdID, buffer.Bytes()) | ||||
| 		if err == nil { | ||||
| 			header.Command = command | ||||
| 		} | ||||
|  |  | |||
|  | @ -115,14 +115,14 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH | |||
| } | ||||
| 
 | ||||
| func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { | ||||
| 	buffer := make([]byte, 512) | ||||
| 	buffer := buf.New() | ||||
| 	defer buffer.Release() | ||||
| 
 | ||||
| 	_, err := io.ReadFull(reader, buffer[:protocol.IDBytesLen]) | ||||
| 	if err != nil { | ||||
| 	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil { | ||||
| 		return nil, newError("failed to read request header").Base(err) | ||||
| 	} | ||||
| 
 | ||||
| 	user, timestamp, valid := s.userValidator.Get(buffer[:protocol.IDBytesLen]) | ||||
| 	user, timestamp, valid := s.userValidator.Get(buffer.Bytes()) | ||||
| 	if !valid { | ||||
| 		return nil, newError("invalid user") | ||||
| 	} | ||||
|  | @ -139,23 +139,21 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request | |||
| 	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv) | ||||
| 	decryptor := crypto.NewCryptionReader(aesStream, reader) | ||||
| 
 | ||||
| 	nBytes, err := io.ReadFull(decryptor, buffer[:41]) | ||||
| 	if err != nil { | ||||
| 	if err := buffer.Reset(buf.ReadFullFrom(decryptor, 41)); err != nil { | ||||
| 		return nil, newError("failed to read request header").Base(err) | ||||
| 	} | ||||
| 	bufferLen := nBytes | ||||
| 
 | ||||
| 	request := &protocol.RequestHeader{ | ||||
| 		User:    user, | ||||
| 		Version: buffer[0], | ||||
| 		Version: buffer.Byte(0), | ||||
| 	} | ||||
| 
 | ||||
| 	if request.Version != Version { | ||||
| 		return nil, newError("invalid protocol version ", request.Version) | ||||
| 	} | ||||
| 
 | ||||
| 	s.requestBodyIV = append([]byte(nil), buffer[1:17]...)   // 16 bytes
 | ||||
| 	s.requestBodyKey = append([]byte(nil), buffer[17:33]...) // 16 bytes
 | ||||
| 	s.requestBodyIV = append([]byte(nil), buffer.BytesRange(1, 17)...)   // 16 bytes
 | ||||
| 	s.requestBodyKey = append([]byte(nil), buffer.BytesRange(17, 33)...) // 16 bytes
 | ||||
| 	var sid sessionId | ||||
| 	copy(sid.user[:], vmessAccount.ID.Bytes()) | ||||
| 	copy(sid.key[:], s.requestBodyKey) | ||||
|  | @ -165,66 +163,56 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request | |||
| 	} | ||||
| 	s.sessionHistory.add(sid) | ||||
| 
 | ||||
| 	s.responseHeader = buffer[33]             // 1 byte
 | ||||
| 	request.Option = bitmask.Byte(buffer[34]) // 1 byte
 | ||||
| 	padingLen := int(buffer[35] >> 4) | ||||
| 	request.Security = protocol.NormSecurity(protocol.Security(buffer[35] & 0x0F)) | ||||
| 	s.responseHeader = buffer.Byte(33)             // 1 byte
 | ||||
| 	request.Option = bitmask.Byte(buffer.Byte(34)) // 1 byte
 | ||||
| 	padingLen := int(buffer.Byte(35) >> 4) | ||||
| 	request.Security = protocol.NormSecurity(protocol.Security(buffer.Byte(35) & 0x0F)) | ||||
| 	// 1 bytes reserved
 | ||||
| 	request.Command = protocol.RequestCommand(buffer[37]) | ||||
| 	request.Command = protocol.RequestCommand(buffer.Byte(37)) | ||||
| 
 | ||||
| 	if request.Command != protocol.RequestCommandMux { | ||||
| 		request.Port = net.PortFromBytes(buffer[38:40]) | ||||
| 		request.Port = net.PortFromBytes(buffer.BytesRange(38, 40)) | ||||
| 
 | ||||
| 		switch protocol.AddressType(buffer[40]) { | ||||
| 		switch protocol.AddressType(buffer.Byte(40)) { | ||||
| 		case protocol.AddressTypeIPv4: | ||||
| 			_, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
 | ||||
| 			bufferLen += 4 | ||||
| 			if err != nil { | ||||
| 			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil { | ||||
| 				return nil, newError("failed to read IPv4 address").Base(err) | ||||
| 			} | ||||
| 			request.Address = net.IPAddress(buffer[41:45]) | ||||
| 			request.Address = net.IPAddress(buffer.BytesFrom(-4)) | ||||
| 		case protocol.AddressTypeIPv6: | ||||
| 			_, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
 | ||||
| 			bufferLen += 16 | ||||
| 			if err != nil { | ||||
| 			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 16)); err != nil { | ||||
| 				return nil, newError("failed to read IPv6 address").Base(err) | ||||
| 			} | ||||
| 			request.Address = net.IPAddress(buffer[41:57]) | ||||
| 			request.Address = net.IPAddress(buffer.BytesFrom(-16)) | ||||
| 		case protocol.AddressTypeDomain: | ||||
| 			_, err = io.ReadFull(decryptor, buffer[41:42]) | ||||
| 			if err != nil { | ||||
| 			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 1)); err != nil { | ||||
| 				return nil, newError("failed to read domain address").Base(err) | ||||
| 			} | ||||
| 			domainLength := int(buffer[41]) | ||||
| 			domainLength := int(buffer.Byte(buffer.Len() - 1)) | ||||
| 			if domainLength == 0 { | ||||
| 				return nil, newError("zero length domain").Base(err) | ||||
| 			} | ||||
| 			_, err = io.ReadFull(decryptor, buffer[42:42+domainLength]) | ||||
| 			if err != nil { | ||||
| 			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, domainLength)); err != nil { | ||||
| 				return nil, newError("failed to read domain address").Base(err) | ||||
| 			} | ||||
| 			bufferLen += 1 + domainLength | ||||
| 			request.Address = net.DomainAddress(string(buffer[42 : 42+domainLength])) | ||||
| 			request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength))) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if padingLen > 0 { | ||||
| 		_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+padingLen]) | ||||
| 		if err != nil { | ||||
| 		if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, padingLen)); err != nil { | ||||
| 			return nil, newError("failed to read padding").Base(err) | ||||
| 		} | ||||
| 		bufferLen += padingLen | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+4]) | ||||
| 	if err != nil { | ||||
| 	if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil { | ||||
| 		return nil, newError("failed to read checksum").Base(err) | ||||
| 	} | ||||
| 
 | ||||
| 	fnv1a := fnv.New32a() | ||||
| 	common.Must2(fnv1a.Write(buffer[:bufferLen])) | ||||
| 	common.Must2(fnv1a.Write(buffer.BytesTo(-4))) | ||||
| 	actualHash := fnv1a.Sum32() | ||||
| 	expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4]) | ||||
| 	expectedHash := serial.BytesToUint32(buffer.BytesFrom(-4)) | ||||
| 
 | ||||
| 	if actualHash != expectedHash { | ||||
| 		return nil, newError("invalid auth") | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Darien Raymond
						Darien Raymond