fix(proxy): removed the udp payload length check when encryption is disabled

pull/2464/head
cty123 1 year ago committed by yuhan6665
parent f67167bb3b
commit a343d68944

@ -4,6 +4,7 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors"
"hash/crc32" "hash/crc32"
"io" "io"
@ -236,19 +237,26 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
} }
func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
bs := payload.Bytes() rawPayload := payload.Bytes()
if len(bs) <= 32 { user, _, d, _, err := validator.Get(rawPayload, protocol.RequestCommandUDP)
return nil, nil, newError("len(bs) <= 32")
}
user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP) if errors.Is(err, ErrIVNotUnique) {
switch err {
case ErrIVNotUnique:
return nil, nil, newError("failed iv check").Base(err) return nil, nil, newError("failed iv check").Base(err)
case ErrNotFound: }
if errors.Is(err, ErrNotFound) {
return nil, nil, newError("failed to match an user").Base(err) return nil, nil, newError("failed to match an user").Base(err)
default: }
account := user.Account.(*MemoryAccount)
if err != nil {
return nil, nil, newError("unexpected error").Base(err)
}
account, ok := user.Account.(*MemoryAccount)
if !ok {
return nil, nil, newError("expected MemoryAccount returned from validator")
}
if account.Cipher.IsAEAD() { if account.Cipher.IsAEAD() {
payload.Clear() payload.Clear()
payload.Write(d) payload.Write(d)
@ -261,13 +269,6 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
return nil, nil, newError("failed to decrypt UDP payload").Base(err) return nil, nil, newError("failed to decrypt UDP payload").Base(err)
} }
} }
}
request := &protocol.RequestHeader{
Version: Version,
User: user,
Command: protocol.RequestCommandUDP,
}
payload.SetByte(0, payload.Byte(0)&0x0F) payload.SetByte(0, payload.Byte(0)&0x0F)
@ -276,8 +277,13 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
return nil, nil, newError("failed to parse address").Base(err) return nil, nil, newError("failed to parse address").Base(err)
} }
request.Address = addr request := &protocol.RequestHeader{
request.Port = port Version: Version,
User: user,
Command: protocol.RequestCommandUDP,
Address: addr,
Port: port,
}
return request, payload, nil return request, payload, nil
} }

@ -23,8 +23,9 @@ func equalRequestHeader(x, y *protocol.RequestHeader) bool {
})) }))
} }
func TestUDPEncoding(t *testing.T) { func TestUDPEncodingDecoding(t *testing.T) {
request := &protocol.RequestHeader{ testRequests := []protocol.RequestHeader{
{
Version: Version, Version: Version,
Command: protocol.RequestCommandUDP, Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP, Address: net.LocalHostIP,
@ -36,11 +37,26 @@ func TestUDPEncoding(t *testing.T) {
CipherType: CipherType_AES_128_GCM, CipherType: CipherType_AES_128_GCM,
}), }),
}, },
},
{
Version: Version,
Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP,
Port: 1234,
User: &protocol.MemoryUser{
Email: "love@example.com",
Account: toAccount(&Account{
Password: "123",
CipherType: CipherType_NONE,
}),
},
},
} }
for _, request := range testRequests {
data := buf.New() data := buf.New()
common.Must2(data.WriteString("test string")) common.Must2(data.WriteString("test string"))
encodedData, err := EncodeUDPPacket(request, data.Bytes()) encodedData, err := EncodeUDPPacket(&request, data.Bytes())
common.Must(err) common.Must(err)
validator := new(Validator) validator := new(Validator)
@ -52,10 +68,37 @@ func TestUDPEncoding(t *testing.T) {
t.Error("data: ", r) t.Error("data: ", r)
} }
if equalRequestHeader(decodedRequest, request) == false { if equalRequestHeader(decodedRequest, &request) == false {
t.Error("different request") t.Error("different request")
} }
} }
}
func TestUDPDecodingWithPayloadTooShort(t *testing.T) {
testAccounts := []protocol.Account{
toAccount(&Account{
Password: "password",
CipherType: CipherType_AES_128_GCM,
}),
toAccount(&Account{
Password: "password",
CipherType: CipherType_NONE,
}),
}
for _, account := range testAccounts {
data := buf.New()
data.WriteString("short payload")
validator := new(Validator)
validator.Add(&protocol.MemoryUser{
Account: account,
})
_, _, err := DecodeUDPPacket(validator, data)
if err == nil {
t.Fatal("expected error")
}
}
}
func TestTCPRequest(t *testing.T) { func TestTCPRequest(t *testing.T) {
cases := []struct { cases := []struct {

@ -80,6 +80,11 @@ func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol
for _, user := range v.users { for _, user := range v.users {
if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() { if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() {
// AEAD payload decoding requires the payload to be over 32 bytes
if len(bs) < 32 {
continue
}
aeadCipher := account.Cipher.(*AEADCipher) aeadCipher := account.Cipher.(*AEADCipher)
ivLen = aeadCipher.IVSize() ivLen = aeadCipher.IVSize()
iv := bs[:ivLen] iv := bs[:ivLen]

Loading…
Cancel
Save