From b1107b9810a32623afff491068494dd9a825f220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=8E=E6=89=87=E6=BB=91=E7=BF=94=E7=BF=BC?= Date: Mon, 11 Aug 2025 09:37:46 +0800 Subject: [PATCH] Refine must2 and apply NewAesGcm() to all usage (#5011) * Refine must2 and apply NewAesGcm() to all usage * Remove unused package * Fix test --- app/dns/dnscommon_test.go | 26 ++++++++-------- common/common.go | 4 ++- common/crypto/aes.go | 6 ++-- common/crypto/auth_test.go | 14 ++------- proxy/dokodemo/dokodemo.go | 4 +-- proxy/shadowsocks/config.go | 7 +---- proxy/vmess/aead/encrypt.go | 43 +++----------------------- proxy/vmess/encoding/client.go | 8 ++--- proxy/vmess/encoding/server.go | 8 ++--- transport/internet/kcp/cryptreal.go | 6 ++-- transport/internet/reality/reality.go | 5 +-- transport/internet/splithttp/dialer.go | 2 +- 12 files changed, 37 insertions(+), 96 deletions(-) diff --git a/app/dns/dnscommon_test.go b/app/dns/dnscommon_test.go index bbaa9a21..7e06baaf 100644 --- a/app/dns/dnscommon_test.go +++ b/app/dns/dnscommon_test.go @@ -18,31 +18,31 @@ func Test_parseResponse(t *testing.T) { ans := new(dns.Msg) ans.Id = 0 - p = append(p, common.Must2(ans.Pack()).([]byte)) + p = append(p, common.Must2(ans.Pack())) p = append(p, []byte{}) ans = new(dns.Msg) ans.Id = 1 ans.Answer = append(ans.Answer, - common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")), + common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")), + common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")), + common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")), ) - p = append(p, common.Must2(ans.Pack()).([]byte)) + p = append(p, common.Must2(ans.Pack())) ans = new(dns.Msg) ans.Id = 2 ans.Answer = append(ans.Answer, - common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR), - common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR), + common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")), + common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")), + common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")), + common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")), + common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")), + common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")), ) - p = append(p, common.Must2(ans.Pack()).([]byte)) + p = append(p, common.Must2(ans.Pack())) tests := []struct { name string diff --git a/common/common.go b/common/common.go index a09f6fbe..c3bfa944 100644 --- a/common/common.go +++ b/common/common.go @@ -23,7 +23,9 @@ func Must(err error) { } // Must2 panics if the second parameter is not nil, otherwise returns the first parameter. -func Must2(v interface{}, err error) interface{} { +// This is useful when function returned "sth, err" and avoid many "if err != nil" +// Internal usage only, if user input can cause err, it must be handled +func Must2[T any](v T, err error) T { Must(err) return v } diff --git a/common/crypto/aes.go b/common/crypto/aes.go index 3205a207..bbc974d9 100644 --- a/common/crypto/aes.go +++ b/common/crypto/aes.go @@ -32,9 +32,7 @@ func NewAesCTRStream(key []byte, iv []byte) cipher.Stream { // NewAesGcm creates a AEAD cipher based on AES-GCM. func NewAesGcm(key []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) + block := common.Must2(aes.NewCipher(key)) + aead := common.Must2(cipher.NewGCM(block)) return aead } diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 6af8e0ad..7dc5509e 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -2,8 +2,6 @@ package crypto_test import ( "bytes" - "crypto/aes" - "crypto/cipher" "crypto/rand" "io" "testing" @@ -18,11 +16,8 @@ import ( func TestAuthenticationReaderWriter(t *testing.T) { key := make([]byte, 16) rand.Read(key) - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) + aead := NewAesGcm(key) const payloadSize = 1024 * 80 rawPayload := make([]byte, payloadSize) @@ -71,7 +66,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { t.Error(r) } - _, err = reader.ReadMultiBuffer() + _, err := reader.ReadMultiBuffer() if err != io.EOF { t.Error("error: ", err) } @@ -80,11 +75,8 @@ func TestAuthenticationReaderWriter(t *testing.T) { func TestAuthenticationReaderWriterPacket(t *testing.T) { key := make([]byte, 16) common.Must2(rand.Read(key)) - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) + aead := NewAesGcm(key) cache := buf.New() iv := make([]byte, 12) diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 2a467d9b..2d553300 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -91,7 +91,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st } } if dest.Port == 0 { - dest.Port = net.Port(common.Must2(strconv.Atoi(port)).(int)) + dest.Port = net.Port(common.Must2(strconv.Atoi(port))) } if d.portMap != nil && d.portMap[port] != "" { h, p, _ := net.SplitHostPort(d.portMap[port]) @@ -99,7 +99,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st dest.Address = net.ParseAddress(h) } if len(p) > 0 { - dest.Port = net.Port(common.Must2(strconv.Atoi(p)).(int)) + dest.Port = net.Port(common.Must2(strconv.Atoi(p))) } } } diff --git a/proxy/shadowsocks/config.go b/proxy/shadowsocks/config.go index a6d2ef87..39c397fa 100644 --- a/proxy/shadowsocks/config.go +++ b/proxy/shadowsocks/config.go @@ -2,7 +2,6 @@ package shadowsocks import ( "bytes" - "crypto/aes" "crypto/cipher" "crypto/md5" "crypto/sha1" @@ -58,11 +57,7 @@ func (a *MemoryAccount) CheckIV(iv []byte) error { } func createAesGcm(key []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - common.Must(err) - gcm, err := cipher.NewGCM(block) - common.Must(err) - return gcm + return crypto.NewAesGcm(key) } func createChaCha20Poly1305(key []byte) cipher.AEAD { diff --git a/proxy/vmess/aead/encrypt.go b/proxy/vmess/aead/encrypt.go index 8995f2ea..a3faec7e 100644 --- a/proxy/vmess/aead/encrypt.go +++ b/proxy/vmess/aead/encrypt.go @@ -2,14 +2,13 @@ package aead import ( "bytes" - "crypto/aes" - "crypto/cipher" "crypto/rand" "encoding/binary" "io" "time" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" ) func SealVMessAEADHeader(key [16]byte, data []byte) []byte { @@ -34,15 +33,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte { payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] - payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderLengthAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey) payloadHeaderLengthAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderLengthAEADNonce, aeadPayloadLengthSerializedByte, generatedAuthID[:]) } @@ -54,15 +45,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte { payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] - payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey) payloadHeaderAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderAEADNonce, data, generatedAuthID[:]) } @@ -104,15 +87,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12] - payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderLengthAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderLengthAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey) decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:]) @@ -145,15 +120,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, return nil, false, bytesRead, err } - payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey) decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:]) diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index d678646b..d48eddd7 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -3,8 +3,6 @@ package encoding import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" "crypto/rand" "crypto/sha256" "encoding/binary" @@ -182,8 +180,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] - aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) - aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey) var aeadEncryptedResponseHeaderLength [18]byte var decryptedResponseHeaderLength int @@ -205,8 +202,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] - aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) - aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey) encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 99e7abc9..3a11c747 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -2,8 +2,6 @@ package encoding import ( "bytes" - "crypto/aes" - "crypto/cipher" "crypto/sha256" "encoding/binary" "hash/fnv" @@ -350,8 +348,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] - aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) - aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey) aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil) @@ -365,8 +362,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] - aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) - aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey) aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil) common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload))) diff --git a/transport/internet/kcp/cryptreal.go b/transport/internet/kcp/cryptreal.go index e86bba98..391d714e 100644 --- a/transport/internet/kcp/cryptreal.go +++ b/transport/internet/kcp/cryptreal.go @@ -1,15 +1,13 @@ package kcp import ( - "crypto/aes" "crypto/cipher" "crypto/sha256" - "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" ) func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD { hashedSeed := sha256.Sum256([]byte(seed)) - aesBlock := common.Must2(aes.NewCipher(hashedSeed[:16])).(cipher.Block) - return common.Must2(cipher.NewGCM(aesBlock)).(cipher.AEAD) + return crypto.NewAesGcm(hashedSeed[:]) } diff --git a/transport/internet/reality/reality.go b/transport/internet/reality/reality.go index dca4e951..20f13ba5 100644 --- a/transport/internet/reality/reality.go +++ b/transport/internet/reality/reality.go @@ -3,8 +3,6 @@ package reality import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" "crypto/ecdh" "crypto/ed25519" "crypto/hmac" @@ -169,8 +167,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati if _, err := hkdf.New(sha256.New, uConn.AuthKey, hello.Random[:20], []byte("REALITY")).Read(uConn.AuthKey); err != nil { return nil, err } - block, _ := aes.NewCipher(uConn.AuthKey) - aead, _ := cipher.NewGCM(block) + aead := crypto.NewAesGcm(uConn.AuthKey) if config.Show { fmt.Printf("REALITY localAddr: %v\tuConn.AuthKey[:16]: %v\tAEAD: %T\n", localAddr, uConn.AuthKey[:16], aead) } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index dee5f486..e409e61c 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -297,7 +297,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if transportConfiguration.DownloadSettings != nil { globalDialerAccess.Lock() if streamSettings.DownloadSettings == nil { - streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig) + streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)) if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate { streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings }