RPRX 2025-08-29 14:05:39 +00:00 committed by GitHub
parent 56a45ad578
commit 82ea7a3cc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 54 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/crypto"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/protocol"
"lukechampine.com/blake3" "lukechampine.com/blake3"
) )
@ -66,7 +67,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if i.NfsPKeys == nil { if i.NfsPKeys == nil {
return nil, errors.New("uninitialized") return nil, errors.New("uninitialized")
} }
c := NewCommonConn(conn) c := NewCommonConn(conn, protocol.HasAESGCMHardwareSupport)
ivAndRealysLength := 16 + i.RelaysLength ivAndRealysLength := 16 + i.RelaysLength
pfsKeyExchangeLength := 18 + 1184 + 32 + 16 pfsKeyExchangeLength := 18 + 1184 + 32 + 16
@ -108,18 +109,18 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:]) lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:])
relays = relays[index+32:] relays = relays[index+32:]
} }
nfsGCM := NewGCM(iv, nfsKey) nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
if i.Seconds > 0 { if i.Seconds > 0 {
i.RWLock.RLock() i.RWLock.RLock()
if time.Now().Before(i.Expire) { if time.Now().Before(i.Expire) {
c.Client = i c.Client = i
c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection
nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil) nfsAEAD.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil) nfsAEAD.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
i.RWLock.RUnlock() i.RWLock.RUnlock()
c.PreWrite = clientHello[:ivAndRealysLength+18+32] c.PreWrite = clientHello[:ivAndRealysLength+18+32]
c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey) c.AEAD = NewAEAD(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey, c.UseAES)
if i.XorMode == 2 { if i.XorMode == 2 {
c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16) c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16)
} }
@ -129,15 +130,15 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
} }
pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength] pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength]
nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil) nfsAEAD.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
mlkem768DKey, _ := mlkem.GenerateKey768() mlkem768DKey, _ := mlkem.GenerateKey768()
x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader) x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...) pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...)
nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil) nfsAEAD.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:] padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:]
nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) nfsAEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) nfsAEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
if _, err := conn.Write(clientHello); err != nil { if _, err := conn.Write(clientHello); err != nil {
return nil, err return nil, err
@ -148,7 +149,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
return nil, err return nil, err
} }
nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil) nfsAEAD.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088]) mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088])
if err != nil { if err != nil {
return nil, err return nil, err
@ -165,14 +166,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
copy(pfsKey, mlkem768Key) copy(pfsKey, mlkem768Key)
copy(pfsKey[32:], x25519Key) copy(pfsKey[32:], x25519Key)
c.UnitedKey = append(pfsKey, nfsKey...) c.UnitedKey = append(pfsKey, nfsKey...)
c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey) c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1088+32], c.UnitedKey, c.UseAES)
encryptedTicket := make([]byte, 32) encryptedTicket := make([]byte, 32)
if _, err := io.ReadFull(conn, encryptedTicket); err != nil { if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
return nil, err return nil, err
} }
if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil { if _, err := c.PeerAEAD.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
return nil, err return nil, err
} }
seconds := DecodeLength(encryptedTicket) seconds := DecodeLength(encryptedTicket)
@ -189,7 +190,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedLength); err != nil { if _, err := io.ReadFull(conn, encryptedLength); err != nil {
return nil, err return nil, err
} }
if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { if _, err := c.PeerAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err return nil, err
} }
length := DecodeLength(encryptedLength[:2]) length := DecodeLength(encryptedLength[:2])

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"golang.org/x/crypto/chacha20poly1305"
"lukechampine.com/blake3" "lukechampine.com/blake3"
) )
@ -23,19 +24,21 @@ var OutBytesPool = sync.Pool{
type CommonConn struct { type CommonConn struct {
net.Conn net.Conn
UseAES bool
Client *ClientInstance Client *ClientInstance
UnitedKey []byte UnitedKey []byte
PreWrite []byte PreWrite []byte
GCM *GCM AEAD *AEAD
PeerGCM *GCM PeerAEAD *AEAD
PeerPadding []byte PeerPadding []byte
PeerInBytes []byte PeerInBytes []byte
PeerCache []byte PeerCache []byte
} }
func NewCommonConn(conn net.Conn) *CommonConn { func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
return &CommonConn{ return &CommonConn{
Conn: conn, Conn: conn,
UseAES: useAES,
PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
} }
} }
@ -55,12 +58,12 @@ func (c *CommonConn) Write(b []byte) (int, error) {
headerAndData := outBytes[:5+len(b)+16] headerAndData := outBytes[:5+len(b)+16]
EncodeHeader(headerAndData, len(b)+16) EncodeHeader(headerAndData, len(b)+16)
max := false max := false
if bytes.Equal(c.GCM.Nonce[:], MaxNonce) { if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
max = true max = true
} }
c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5]) c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
if max { if max {
c.GCM = NewGCM(headerAndData, c.UnitedKey) c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
} }
if c.PreWrite != nil { if c.PreWrite != nil {
headerAndData = append(c.PreWrite, headerAndData...) headerAndData = append(c.PreWrite, headerAndData...)
@ -77,12 +80,12 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
if c.PeerGCM == nil { // client's 0-RTT if c.PeerAEAD == nil { // client's 0-RTT
serverRandom := make([]byte, 16) serverRandom := make([]byte, 16)
if _, err := io.ReadFull(c.Conn, serverRandom); err != nil { if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
return 0, err return 0, err
} }
c.PeerGCM = NewGCM(serverRandom, c.UnitedKey) c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
if xorConn, ok := c.Conn.(*XorConn); ok { if xorConn, ok := c.Conn.(*XorConn); ok {
xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom) xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
} }
@ -91,7 +94,7 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil { if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
return 0, err return 0, err
} }
if _, err := c.PeerGCM.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil { if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
return 0, err return 0, err
} }
c.PeerPadding = nil c.PeerPadding = nil
@ -126,13 +129,13 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if len(dst) <= len(b) { if len(dst) <= len(b) {
dst = b[:len(dst)] // avoids another copy() dst = b[:len(dst)] // avoids another copy()
} }
var newGCM *GCM var newAEAD *AEAD
if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) { if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey) newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES)
} }
_, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader) _, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader)
if newGCM != nil { if newAEAD != nil {
c.PeerGCM = newGCM c.PeerAEAD = newAEAD
} }
if err != nil { if err != nil {
return 0, err return 0, err
@ -144,28 +147,32 @@ func (c *CommonConn) Read(b []byte) (int, error) {
return len(dst), nil return len(dst), nil
} }
type GCM struct { type AEAD struct {
cipher.AEAD cipher.AEAD
Nonce [12]byte Nonce [12]byte
} }
func NewGCM(ctx, key []byte) *GCM { func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
k := make([]byte, 32) k := make([]byte, 32)
blake3.DeriveKey(k, string(ctx), key) blake3.DeriveKey(k, string(ctx), key)
block, _ := aes.NewCipher(k) var aead cipher.AEAD
aead, _ := cipher.NewGCM(block) if useAES {
return &GCM{AEAD: aead} block, _ := aes.NewCipher(k)
//chacha20poly1305.New() aead, _ = cipher.NewGCM(block)
} else {
aead, _ = chacha20poly1305.New(k)
}
return &AEAD{AEAD: aead}
} }
func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
if nonce == nil { if nonce == nil {
nonce = IncreaseNonce(a.Nonce[:]) nonce = IncreaseNonce(a.Nonce[:])
} }
return a.AEAD.Seal(dst, nonce, plaintext, additionalData) return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
} }
func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if nonce == nil { if nonce == nil {
nonce = IncreaseNonce(a.Nonce[:]) nonce = IncreaseNonce(a.Nonce[:])
} }

View File

@ -102,7 +102,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if i.NfsSKeys == nil { if i.NfsSKeys == nil {
return nil, errors.New("uninitialized") return nil, errors.New("uninitialized")
} }
c := NewCommonConn(conn) c := NewCommonConn(conn, true)
ivAndRelays := make([]byte, 16+i.RelaysLength) ivAndRelays := make([]byte, 16+i.RelaysLength)
if _, err := io.ReadFull(conn, ivAndRelays); err != nil { if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
@ -151,16 +151,21 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
} }
relays = relays[32:] relays = relays[32:]
} }
nfsGCM := NewGCM(iv, nfsKey) nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
encryptedLength := make([]byte, 18) encryptedLength := make([]byte, 18)
if _, err := io.ReadFull(conn, encryptedLength); err != nil { if _, err := io.ReadFull(conn, encryptedLength); err != nil {
return nil, err return nil, err
} }
if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { decryptedLength := make([]byte, 2)
return nil, err if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
c.UseAES = !c.UseAES
nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err
}
} }
length := DecodeLength(encryptedLength[:2]) length := DecodeLength(decryptedLength)
if length == 32 { if length == 32 {
if i.Seconds == 0 { if i.Seconds == 0 {
@ -170,7 +175,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedTicket); err != nil { if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
return nil, err return nil, err
} }
ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil) ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -193,8 +198,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request) c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
c.PreWrite = make([]byte, 16) c.PreWrite = make([]byte, 16)
rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub") rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
c.GCM = NewGCM(c.PreWrite, c.UnitedKey) c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client) c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
if i.XorMode == 2 { if i.XorMode == 2 {
c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
} }
@ -208,7 +213,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
return nil, err return nil, err
} }
if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil { if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
return nil, err return nil, err
} }
mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184]) mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
@ -230,8 +235,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
copy(pfsKey[32:], x25519Key) copy(pfsKey[32:], x25519Key)
pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...) pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
c.UnitedKey = append(pfsKey, nfsKey...) c.UnitedKey = append(pfsKey, nfsKey...)
c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey) c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
ticket := make([]byte, 16) ticket := make([]byte, 16)
rand.Read(ticket) rand.Read(ticket)
copy(ticket, EncodeLength(int(i.Seconds*4/5))) copy(ticket, EncodeLength(int(i.Seconds*4/5)))
@ -240,11 +245,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
encryptedTicketLength := 32 encryptedTicketLength := 32
paddingLength := int(crypto.RandBetween(100, 1000)) paddingLength := int(crypto.RandBetween(100, 1000))
serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength) serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
nfsGCM.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil) nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil) c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:] padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
if _, err := conn.Write(serverHello); err != nil { if _, err := conn.Write(serverHello); err != nil {
return nil, err return nil, err
@ -264,14 +269,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedLength); err != nil { if _, err := io.ReadFull(conn, encryptedLength); err != nil {
return nil, err return nil, err
} }
if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err return nil, err
} }
encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2])) encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
if _, err := io.ReadFull(conn, encryptedPadding); err != nil { if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
return nil, err return nil, err
} }
if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
return nil, err return nil, err
} }