diff --git a/proxy/vless/encryption/client.go b/proxy/vless/encryption/client.go index 301c2328..77c0b334 100644 --- a/proxy/vless/encryption/client.go +++ b/proxy/vless/encryption/client.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/protocol" "lukechampine.com/blake3" ) @@ -66,7 +67,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if i.NfsPKeys == nil { return nil, errors.New("uninitialized") } - c := NewCommonConn(conn) + c := NewCommonConn(conn, protocol.HasAESGCMHardwareSupport) ivAndRealysLength := 16 + i.RelaysLength pfsKeyExchangeLength := 18 + 1184 + 32 + 16 @@ -108,18 +109,18 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:]) relays = relays[index+32:] } - nfsGCM := NewGCM(iv, nfsKey) + nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES) if i.Seconds > 0 { i.RWLock.RLock() if time.Now().Before(i.Expire) { c.Client = i c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection - nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil) - nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil) + nfsAEAD.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil) + nfsAEAD.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil) i.RWLock.RUnlock() c.PreWrite = clientHello[:ivAndRealysLength+18+32] - c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey) + c.AEAD = NewAEAD(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey, c.UseAES) if i.XorMode == 2 { c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16) } @@ -129,15 +130,15 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { } pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength] - nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil) + nfsAEAD.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil) mlkem768DKey, _ := mlkem.GenerateKey768() x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader) pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...) - nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil) + nfsAEAD.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil) padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:] - nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) - nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) + nfsAEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) + nfsAEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) if _, err := conn.Write(clientHello); err != nil { return nil, err @@ -148,7 +149,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { return nil, err } - nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil) + nfsAEAD.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil) mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088]) if err != nil { return nil, err @@ -165,14 +166,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { copy(pfsKey, mlkem768Key) copy(pfsKey[32:], x25519Key) c.UnitedKey = append(pfsKey, nfsKey...) - c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) - c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey) + c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES) + c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1088+32], c.UnitedKey, c.UseAES) encryptedTicket := make([]byte, 32) if _, err := io.ReadFull(conn, encryptedTicket); err != nil { return nil, err } - if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil { + if _, err := c.PeerAEAD.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil { return nil, err } seconds := DecodeLength(encryptedTicket) @@ -189,7 +190,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedLength); err != nil { return nil, err } - if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + if _, err := c.PeerAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { return nil, err } length := DecodeLength(encryptedLength[:2]) diff --git a/proxy/vless/encryption/common.go b/proxy/vless/encryption/common.go index 6f914c8d..00528acc 100644 --- a/proxy/vless/encryption/common.go +++ b/proxy/vless/encryption/common.go @@ -12,6 +12,7 @@ import ( "time" "github.com/xtls/xray-core/common/errors" + "golang.org/x/crypto/chacha20poly1305" "lukechampine.com/blake3" ) @@ -23,19 +24,21 @@ var OutBytesPool = sync.Pool{ type CommonConn struct { net.Conn + UseAES bool Client *ClientInstance UnitedKey []byte PreWrite []byte - GCM *GCM - PeerGCM *GCM + AEAD *AEAD + PeerAEAD *AEAD PeerPadding []byte PeerInBytes []byte PeerCache []byte } -func NewCommonConn(conn net.Conn) *CommonConn { +func NewCommonConn(conn net.Conn, useAES bool) *CommonConn { return &CommonConn{ Conn: conn, + UseAES: useAES, PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading } } @@ -55,12 +58,12 @@ func (c *CommonConn) Write(b []byte) (int, error) { headerAndData := outBytes[:5+len(b)+16] EncodeHeader(headerAndData, len(b)+16) max := false - if bytes.Equal(c.GCM.Nonce[:], MaxNonce) { + if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) { max = true } - c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5]) + c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5]) if max { - c.GCM = NewGCM(headerAndData, c.UnitedKey) + c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES) } if c.PreWrite != nil { headerAndData = append(c.PreWrite, headerAndData...) @@ -77,12 +80,12 @@ func (c *CommonConn) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - if c.PeerGCM == nil { // client's 0-RTT + if c.PeerAEAD == nil { // client's 0-RTT serverRandom := make([]byte, 16) if _, err := io.ReadFull(c.Conn, serverRandom); err != nil { return 0, err } - c.PeerGCM = NewGCM(serverRandom, c.UnitedKey) + c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES) if xorConn, ok := c.Conn.(*XorConn); ok { xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom) } @@ -91,7 +94,7 @@ func (c *CommonConn) Read(b []byte) (int, error) { if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil { return 0, err } - if _, err := c.PeerGCM.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil { + if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil { return 0, err } c.PeerPadding = nil @@ -126,13 +129,13 @@ func (c *CommonConn) Read(b []byte) (int, error) { if len(dst) <= len(b) { dst = b[:len(dst)] // avoids another copy() } - var newGCM *GCM - if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) { - newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey) + var newAEAD *AEAD + if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) { + newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES) } - _, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader) - if newGCM != nil { - c.PeerGCM = newGCM + _, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader) + if newAEAD != nil { + c.PeerAEAD = newAEAD } if err != nil { return 0, err @@ -144,28 +147,32 @@ func (c *CommonConn) Read(b []byte) (int, error) { return len(dst), nil } -type GCM struct { +type AEAD struct { cipher.AEAD Nonce [12]byte } -func NewGCM(ctx, key []byte) *GCM { +func NewAEAD(ctx, key []byte, useAES bool) *AEAD { k := make([]byte, 32) blake3.DeriveKey(k, string(ctx), key) - block, _ := aes.NewCipher(k) - aead, _ := cipher.NewGCM(block) - return &GCM{AEAD: aead} - //chacha20poly1305.New() + var aead cipher.AEAD + if useAES { + block, _ := aes.NewCipher(k) + aead, _ = cipher.NewGCM(block) + } else { + aead, _ = chacha20poly1305.New(k) + } + return &AEAD{AEAD: aead} } -func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { +func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte { if nonce == nil { nonce = IncreaseNonce(a.Nonce[:]) } return a.AEAD.Seal(dst, nonce, plaintext, additionalData) } -func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { +func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { if nonce == nil { nonce = IncreaseNonce(a.Nonce[:]) } diff --git a/proxy/vless/encryption/server.go b/proxy/vless/encryption/server.go index 8594fd25..ec0532bb 100644 --- a/proxy/vless/encryption/server.go +++ b/proxy/vless/encryption/server.go @@ -102,7 +102,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if i.NfsSKeys == nil { return nil, errors.New("uninitialized") } - c := NewCommonConn(conn) + c := NewCommonConn(conn, true) ivAndRelays := make([]byte, 16+i.RelaysLength) if _, err := io.ReadFull(conn, ivAndRelays); err != nil { @@ -151,16 +151,21 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { } relays = relays[32:] } - nfsGCM := NewGCM(iv, nfsKey) + nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES) encryptedLength := make([]byte, 18) if _, err := io.ReadFull(conn, encryptedLength); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { - return nil, err + decryptedLength := make([]byte, 2) + if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil { + c.UseAES = !c.UseAES + nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES) + if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil { + return nil, err + } } - length := DecodeLength(encryptedLength[:2]) + length := DecodeLength(decryptedLength) if length == 32 { if i.Seconds == 0 { @@ -170,7 +175,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedTicket); err != nil { return nil, err } - ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil) + ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil) if err != nil { return nil, err } @@ -193,8 +198,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request) c.PreWrite = make([]byte, 16) rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub") - c.GCM = NewGCM(c.PreWrite, c.UnitedKey) - c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client) + c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES) + c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client) if i.XorMode == 2 { c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client } @@ -208,7 +213,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil { + if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil { return nil, err } mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184]) @@ -230,8 +235,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { copy(pfsKey[32:], x25519Key) pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...) c.UnitedKey = append(pfsKey, nfsKey...) - c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) - c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey) + c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES) + c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES) ticket := make([]byte, 16) rand.Read(ticket) copy(ticket, EncodeLength(int(i.Seconds*4/5))) @@ -240,11 +245,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { encryptedTicketLength := 32 paddingLength := int(crypto.RandBetween(100, 1000)) serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength) - nfsGCM.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil) - c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil) + nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil) + c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil) padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:] - c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) - c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) + c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) + c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) if _, err := conn.Write(serverHello); err != nil { return nil, err @@ -264,14 +269,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedLength); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { return nil, err } encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2])) if _, err := io.ReadFull(conn, encryptedPadding); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { + if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { return nil, err }