Refine must2 and apply NewAesGcm() to all usage (#5011)

* Refine must2 and apply NewAesGcm() to all usage

* Remove unused package

* Fix test
pull/4988/merge
风扇滑翔翼 2025-08-11 09:37:46 +08:00 committed by GitHub
parent 0cceea75da
commit b1107b9810
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 37 additions and 96 deletions

View File

@ -18,31 +18,31 @@ func Test_parseResponse(t *testing.T) {
ans := new(dns.Msg) ans := new(dns.Msg)
ans.Id = 0 ans.Id = 0
p = append(p, common.Must2(ans.Pack()).([]byte)) p = append(p, common.Must2(ans.Pack()))
p = append(p, []byte{}) p = append(p, []byte{})
ans = new(dns.Msg) ans = new(dns.Msg)
ans.Id = 1 ans.Id = 1
ans.Answer = append(ans.Answer, ans.Answer = append(ans.Answer,
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR), common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")),
common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR), common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")),
common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR), common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")),
) )
p = append(p, common.Must2(ans.Pack()).([]byte)) p = append(p, common.Must2(ans.Pack()))
ans = new(dns.Msg) ans = new(dns.Msg)
ans.Id = 2 ans.Id = 2
ans.Answer = append(ans.Answer, ans.Answer = append(ans.Answer,
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR), common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR), common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR), common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR), common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR), common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")),
) )
p = append(p, common.Must2(ans.Pack()).([]byte)) p = append(p, common.Must2(ans.Pack()))
tests := []struct { tests := []struct {
name string name string

View File

@ -23,7 +23,9 @@ func Must(err error) {
} }
// Must2 panics if the second parameter is not nil, otherwise returns the first parameter. // Must2 panics if the second parameter is not nil, otherwise returns the first parameter.
func Must2(v interface{}, err error) interface{} { // This is useful when function returned "sth, err" and avoid many "if err != nil"
// Internal usage only, if user input can cause err, it must be handled
func Must2[T any](v T, err error) T {
Must(err) Must(err)
return v return v
} }

View File

@ -32,9 +32,7 @@ func NewAesCTRStream(key []byte, iv []byte) cipher.Stream {
// NewAesGcm creates a AEAD cipher based on AES-GCM. // NewAesGcm creates a AEAD cipher based on AES-GCM.
func NewAesGcm(key []byte) cipher.AEAD { func NewAesGcm(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key) block := common.Must2(aes.NewCipher(key))
common.Must(err) aead := common.Must2(cipher.NewGCM(block))
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead return aead
} }

View File

@ -2,8 +2,6 @@ package crypto_test
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand" "crypto/rand"
"io" "io"
"testing" "testing"
@ -18,11 +16,8 @@ import (
func TestAuthenticationReaderWriter(t *testing.T) { func TestAuthenticationReaderWriter(t *testing.T) {
key := make([]byte, 16) key := make([]byte, 16)
rand.Read(key) rand.Read(key)
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block) aead := NewAesGcm(key)
common.Must(err)
const payloadSize = 1024 * 80 const payloadSize = 1024 * 80
rawPayload := make([]byte, payloadSize) rawPayload := make([]byte, payloadSize)
@ -71,7 +66,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
t.Error(r) t.Error(r)
} }
_, err = reader.ReadMultiBuffer() _, err := reader.ReadMultiBuffer()
if err != io.EOF { if err != io.EOF {
t.Error("error: ", err) t.Error("error: ", err)
} }
@ -80,11 +75,8 @@ func TestAuthenticationReaderWriter(t *testing.T) {
func TestAuthenticationReaderWriterPacket(t *testing.T) { func TestAuthenticationReaderWriterPacket(t *testing.T) {
key := make([]byte, 16) key := make([]byte, 16)
common.Must2(rand.Read(key)) common.Must2(rand.Read(key))
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block) aead := NewAesGcm(key)
common.Must(err)
cache := buf.New() cache := buf.New()
iv := make([]byte, 12) iv := make([]byte, 12)

View File

@ -91,7 +91,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
} }
} }
if dest.Port == 0 { if dest.Port == 0 {
dest.Port = net.Port(common.Must2(strconv.Atoi(port)).(int)) dest.Port = net.Port(common.Must2(strconv.Atoi(port)))
} }
if d.portMap != nil && d.portMap[port] != "" { if d.portMap != nil && d.portMap[port] != "" {
h, p, _ := net.SplitHostPort(d.portMap[port]) h, p, _ := net.SplitHostPort(d.portMap[port])
@ -99,7 +99,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
dest.Address = net.ParseAddress(h) dest.Address = net.ParseAddress(h)
} }
if len(p) > 0 { if len(p) > 0 {
dest.Port = net.Port(common.Must2(strconv.Atoi(p)).(int)) dest.Port = net.Port(common.Must2(strconv.Atoi(p)))
} }
} }
} }

View File

@ -2,7 +2,6 @@ package shadowsocks
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/md5" "crypto/md5"
"crypto/sha1" "crypto/sha1"
@ -58,11 +57,7 @@ func (a *MemoryAccount) CheckIV(iv []byte) error {
} }
func createAesGcm(key []byte) cipher.AEAD { func createAesGcm(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key) return crypto.NewAesGcm(key)
common.Must(err)
gcm, err := cipher.NewGCM(block)
common.Must(err)
return gcm
} }
func createChaCha20Poly1305(key []byte) cipher.AEAD { func createChaCha20Poly1305(key []byte) cipher.AEAD {

View File

@ -2,14 +2,13 @@ package aead
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"io" "io"
"time" "time"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/crypto"
) )
func SealVMessAEADHeader(key [16]byte, data []byte) []byte { func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
@ -34,15 +33,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]
payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey)
if err != nil {
panic(err.Error())
}
payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderLengthAEADAESBlock)
if err != nil {
panic(err.Error())
}
payloadHeaderLengthAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderLengthAEADNonce, aeadPayloadLengthSerializedByte, generatedAuthID[:]) payloadHeaderLengthAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderLengthAEADNonce, aeadPayloadLengthSerializedByte, generatedAuthID[:])
} }
@ -54,15 +45,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]
payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey)
if err != nil {
panic(err.Error())
}
payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
if err != nil {
panic(err.Error())
}
payloadHeaderAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderAEADNonce, data, generatedAuthID[:]) payloadHeaderAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderAEADNonce, data, generatedAuthID[:])
} }
@ -104,15 +87,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12] payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12]
payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) payloadHeaderLengthAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey)
if err != nil {
panic(err.Error())
}
payloadHeaderLengthAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
if err != nil {
panic(err.Error())
}
decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:]) decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:])
@ -145,15 +120,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
return nil, false, bytesRead, err return nil, false, bytesRead, err
} }
payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey)
if err != nil {
panic(err.Error())
}
payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
if err != nil {
panic(err.Error())
}
decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:]) decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:])

View File

@ -3,8 +3,6 @@ package encoding
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/aes"
"crypto/cipher"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
@ -182,8 +180,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey)
aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
var aeadEncryptedResponseHeaderLength [18]byte var aeadEncryptedResponseHeaderLength [18]byte
var decryptedResponseHeaderLength int var decryptedResponseHeaderLength int
@ -205,8 +202,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey)
aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16) encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16)

View File

@ -2,8 +2,6 @@ package encoding
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"hash/fnv" "hash/fnv"
@ -350,8 +348,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey)
aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil) aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil)
@ -365,8 +362,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey)
aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil) aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil)
common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload))) common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload)))

View File

@ -1,15 +1,13 @@
package kcp package kcp
import ( import (
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/sha256" "crypto/sha256"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/crypto"
) )
func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD { func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD {
hashedSeed := sha256.Sum256([]byte(seed)) hashedSeed := sha256.Sum256([]byte(seed))
aesBlock := common.Must2(aes.NewCipher(hashedSeed[:16])).(cipher.Block) return crypto.NewAesGcm(hashedSeed[:])
return common.Must2(cipher.NewGCM(aesBlock)).(cipher.AEAD)
} }

View File

@ -3,8 +3,6 @@ package reality
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh" "crypto/ecdh"
"crypto/ed25519" "crypto/ed25519"
"crypto/hmac" "crypto/hmac"
@ -169,8 +167,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati
if _, err := hkdf.New(sha256.New, uConn.AuthKey, hello.Random[:20], []byte("REALITY")).Read(uConn.AuthKey); err != nil { if _, err := hkdf.New(sha256.New, uConn.AuthKey, hello.Random[:20], []byte("REALITY")).Read(uConn.AuthKey); err != nil {
return nil, err return nil, err
} }
block, _ := aes.NewCipher(uConn.AuthKey) aead := crypto.NewAesGcm(uConn.AuthKey)
aead, _ := cipher.NewGCM(block)
if config.Show { if config.Show {
fmt.Printf("REALITY localAddr: %v\tuConn.AuthKey[:16]: %v\tAEAD: %T\n", localAddr, uConn.AuthKey[:16], aead) fmt.Printf("REALITY localAddr: %v\tuConn.AuthKey[:16]: %v\tAEAD: %T\n", localAddr, uConn.AuthKey[:16], aead)
} }

View File

@ -297,7 +297,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
if transportConfiguration.DownloadSettings != nil { if transportConfiguration.DownloadSettings != nil {
globalDialerAccess.Lock() globalDialerAccess.Lock()
if streamSettings.DownloadSettings == nil { if streamSettings.DownloadSettings == nil {
streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig) streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings))
if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate { if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate {
streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings
} }