Xray-core/proxy/vless/encryption/server.go

329 lines
9.9 KiB
Go

package encryption
import (
"bytes"
"crypto/cipher"
"crypto/ecdh"
"crypto/mlkem"
"crypto/rand"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/xtls/xray-core/common/crypto"
"github.com/xtls/xray-core/common/errors"
"lukechampine.com/blake3"
)
// ServerSession is the per-ticket state the server keeps so that a client can
// later resume with a 0-RTT handshake.
type ServerSession struct {
	// PfsKey is the 64-byte ephemeral secret produced by the 1-RTT handshake
	// that issued the ticket: the ML-KEM-768 shared key followed by the
	// X25519 shared key.
	PfsKey []byte

	// NfsKeys remembers every nfsKey (as a [32]byte map key) that has already
	// been presented together with this ticket. Handshake consults it to
	// reject a replayed 0-RTT request.
	NfsKeys sync.Map
}
// ServerInstance holds the long-term key material, the handshake settings and
// the 0-RTT ticket bookkeeping of one server. It is populated by Init and
// must not be copied afterwards because it embeds a lock.
type ServerInstance struct {
	// NfsSKeys holds the parsed long-term private keys, in configuration
	// order. Each element is either an *ecdh.PrivateKey (X25519) or an
	// *mlkem.DecapsulationKey768. A nil slice means "not initialized".
	NfsSKeys []any
	// NfsPKeysBytes holds the encoded public counterpart of each NfsSKeys entry.
	NfsPKeysBytes [][]byte
	// Hash32s holds the BLAKE3-256 digest of each NfsPKeysBytes entry.
	Hash32s [][32]byte
	// RelaysLength is the number of relay bytes Handshake reads after the
	// 16-byte IV: per key 32 (X25519) or 1088 (ML-KEM-768) bytes plus a
	// 32-byte hash, minus the hash of the final key.
	RelaysLength int

	// XorMode > 0 makes Handshake XOR each relay with a CTR stream keyed by
	// the matching public key; XorMode == 2 additionally wraps the
	// connection in an XorConn.
	XorMode uint32
	// SecondsFrom and SecondsTo bound the lifetime, in seconds, that is drawn
	// for each issued ticket. When both are zero, 0-RTT is refused.
	SecondsFrom, SecondsTo int64
	// PaddingLens and PaddingGaps are filled by ParsePadding and handed to
	// CreatPadding for every 1-RTT server hello.
	PaddingLens, PaddingGaps [][3]int

	// RWLock guards Closed, Lasts, Tickets and Sessions.
	RWLock sync.RWMutex
	// Closed is set by Close and makes the ticket-expiry goroutine exit.
	Closed bool
	// Lasts maps a Unix minute to the most recently issued ticket that is
	// due for removal at that minute.
	Lasts map[int64][16]byte
	// Tickets lists the live tickets in the order they were issued.
	Tickets [][16]byte
	// Sessions maps a ticket to its resumption state.
	Sessions map[[16]byte]*ServerSession
}
// Init parses the long-term (NFS) private keys and the handshake settings into
// the instance and, when tickets are enabled, starts the goroutine that
// expires them.
//
// Parameters:
//   - nfsSKeysBytes: one private key per entry, in relay order. A 32-byte
//     entry is parsed as an X25519 private key; any other length is handed to
//     mlkem.NewDecapsulationKey768.
//   - xorMode, secondsFrom, secondsTo: stored unchanged in the fields of the
//     same name.
//   - padding: parsed by ParsePadding into PaddingLens / PaddingGaps.
//
// It returns an error when the instance is already initialized, when
// nfsSKeysBytes is empty, when a key cannot be parsed, or when ParsePadding
// fails. On error the instance is left untouched, so Handshake keeps reporting
// "uninitialized" and Init may be called again.
func (i *ServerInstance) Init(nfsSKeysBytes [][]byte, xorMode uint32, secondsFrom, secondsTo int64, padding string) (err error) {
	if i.NfsSKeys != nil {
		return errors.New("already initialized")
	}
	l := len(nfsSKeysBytes)
	if l == 0 {
		return errors.New("empty nfsSKeysBytes")
	}

	// Everything is assembled in locals and only committed once all inputs
	// have been validated; a half-filled NfsSKeys would otherwise pass the
	// nil checks here and in Handshake.
	nfsSKeys := make([]any, l)
	nfsPKeysBytes := make([][]byte, l)
	hash32s := make([][32]byte, l)
	relaysLength := 0
	for j, k := range nfsSKeysBytes {
		if len(k) == 32 {
			sk, keyErr := ecdh.X25519().NewPrivateKey(k)
			if keyErr != nil {
				return keyErr
			}
			nfsSKeys[j] = sk
			nfsPKeysBytes[j] = sk.PublicKey().Bytes()
			relaysLength += 32 + 32 // X25519 public key + hash of the next key
		} else {
			dk, keyErr := mlkem.NewDecapsulationKey768(k)
			if keyErr != nil {
				return keyErr
			}
			nfsSKeys[j] = dk
			nfsPKeysBytes[j] = dk.EncapsulationKey().Bytes()
			relaysLength += 1088 + 32 // ML-KEM-768 ciphertext + hash of the next key
		}
		hash32s[j] = blake3.Sum256(nfsPKeysBytes[j])
	}
	relaysLength -= 32 // the last relay is not followed by a hash

	// Start from the current field values so ParsePadding receives exactly
	// what it did when it was handed the fields directly.
	paddingLens, paddingGaps := i.PaddingLens, i.PaddingGaps
	if err = ParsePadding(padding, &paddingLens, &paddingGaps); err != nil {
		return err
	}

	i.NfsSKeys = nfsSKeys
	i.NfsPKeysBytes = nfsPKeysBytes
	i.Hash32s = hash32s
	i.RelaysLength = relaysLength
	i.XorMode = xorMode
	i.SecondsFrom = secondsFrom
	i.SecondsTo = secondsTo
	i.PaddingLens = paddingLens
	i.PaddingGaps = paddingGaps

	if i.SecondsFrom > 0 || i.SecondsTo > 0 {
		i.Lasts = make(map[int64][16]byte)
		i.Tickets = make([][16]byte, 0, 1024)
		i.Sessions = make(map[[16]byte]*ServerSession)
		go i.expireTicketsLoop()
	}
	return nil
}

// expireTicketsLoop wakes up once a minute and drops every ticket (and its
// session) up to the one recorded in Lasts for the current minute. It returns
// the first time it wakes up after Close has set Closed.
func (i *ServerInstance) expireTicketsLoop() {
	for {
		time.Sleep(time.Minute)
		i.RWLock.Lock()
		if i.Closed {
			i.RWLock.Unlock()
			return
		}
		minute := time.Now().Unix() / 60
		last := i.Lasts[minute]
		if last == ([16]byte{}) {
			// Each round takes slightly longer than 60s, so the wake-ups
			// drift and can occasionally skip a wall-clock minute. An entry
			// still present for the previous minute was never processed;
			// expire through it now instead of only discarding the marker.
			last = i.Lasts[minute-1]
		}
		delete(i.Lasts, minute)
		delete(i.Lasts, minute-1) // for insurance
		if last != ([16]byte{}) {
			// Tickets is in issue order, so everything up to and including
			// last is due.
			for j, ticket := range i.Tickets {
				delete(i.Sessions, ticket)
				if ticket == last {
					i.Tickets = i.Tickets[j+1:]
					break
				}
			}
		}
		i.RWLock.Unlock()
	}
}
// Close marks the instance as closed under the write lock. The ticket-expiry
// goroutine started by Init observes the flag on its next wake-up and exits.
// Close always returns nil.
func (i *ServerInstance) Close() (err error) {
	i.RWLock.Lock()
	defer i.RWLock.Unlock()
	i.Closed = true
	return nil
}
// Handshake runs the server side of the handshake on conn and returns a
// CommonConn whose UnitedKey, AEAD and PeerAEAD have been set up.
//
// Steps, in the order the code performs them:
//  1. Read a 16-byte IV plus RelaysLength bytes of relays and walk NfsSKeys,
//     deriving nfsKey from the client's X25519 public key or ML-KEM-768
//     ciphertext. Between two relays a 32-byte hash is decrypted and must
//     equal the Hash32s entry of the next key.
//  2. Read the 18-byte sealed length. If it does not open, flip c.UseAES and
//     try once more.
//  3. length == 32: 0-RTT. The ticket is looked up in Sessions, the nfsKey is
//     checked against the session's replay set, and the keys are derived from
//     the stored PfsKey.
//  4. Otherwise: 1-RTT. The client's ML-KEM-768 and X25519 ephemeral public
//     keys are read, a fresh ticket is issued (and stored when its lifetime is
//     positive), the server hello is written in padded fragments, and the
//     client's padding is read and authenticated.
//
// fallback, when non-nil, receives every raw byte read up to and including
// the sealed length. It is reset to nil as soon as that record authenticates;
// on earlier errors the bytes stay there for the caller.
func (i *ServerInstance) Handshake(conn net.Conn, fallback *[]byte) (*CommonConn, error) {
	if i.NfsSKeys == nil {
		return nil, errors.New("uninitialized")
	}
	c := NewCommonConn(conn, true)

	ivAndRelays := make([]byte, 16+i.RelaysLength)
	if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
		return nil, err
	}
	if fallback != nil {
		*fallback = append(*fallback, ivAndRelays...)
	}
	iv := ivAndRelays[:16]
	relays := ivAndRelays[16:]

	var nfsKey []byte
	var lastCTR cipher.Stream
	for j, k := range i.NfsSKeys {
		if lastCTR != nil {
			lastCTR.XORKeyStream(relays, relays[:32]) // recover this relay
		}
		index := 32
		if _, ok := k.(*mlkem.DecapsulationKey768); ok {
			index = 1088
		}
		if i.XorMode > 0 {
			NewCTR(i.NfsPKeysBytes[j], iv).XORKeyStream(relays, relays[:index]) // we don't use buggy elligator2, because we have PSK :)
		}
		switch k := k.(type) {
		case *ecdh.PrivateKey:
			publicKey, err := ecdh.X25519().NewPublicKey(relays[:index])
			if err != nil {
				return nil, err
			}
			if publicKey.Bytes()[31] > 127 { // we just don't want the observer can change even one bit without breaking the connection, though it has nothing to do with security
				return nil, errors.New("the highest bit of the last byte of the peer-sent X25519 public key is not 0")
			}
			nfsKey, err = k.ECDH(publicKey)
			if err != nil {
				return nil, err
			}
		case *mlkem.DecapsulationKey768:
			var err error
			nfsKey, err = k.Decapsulate(relays[:index])
			if err != nil {
				return nil, err
			}
		}
		if j == len(i.NfsSKeys)-1 {
			break
		}
		// Not the final key: the next 32 bytes, once decrypted with the key
		// just derived, must name the next configured public key.
		relays = relays[index:]
		lastCTR = NewCTR(nfsKey, iv)
		lastCTR.XORKeyStream(relays, relays[:32])
		if !bytes.Equal(relays[:32], i.Hash32s[j+1][:]) {
			return nil, errors.New("unexpected hash32: ", fmt.Sprintf("%v", relays[:32]))
		}
		relays = relays[32:]
	}

	nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
	encryptedLength := make([]byte, 18)
	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
		return nil, err
	}
	if fallback != nil {
		*fallback = append(*fallback, encryptedLength...)
	}
	decryptedLength := make([]byte, 2)
	if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
		// The client may have picked the other cipher; retry once with it.
		c.UseAES = !c.UseAES
		nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
		if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
			return nil, err
		}
	}
	if fallback != nil {
		// The peer authenticated, so nothing is kept for fallback any more.
		*fallback = nil
	}
	length := DecodeLength(decryptedLength)

	if length == 32 {
		// 0-RTT: the client presents a previously issued ticket.
		if i.SecondsFrom == 0 && i.SecondsTo == 0 {
			return nil, errors.New("0-RTT is not allowed")
		}
		encryptedTicket := make([]byte, 32)
		if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
			return nil, err
		}
		ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
		if err != nil {
			return nil, err
		}
		i.RWLock.RLock()
		s := i.Sessions[[16]byte(ticket)]
		i.RWLock.RUnlock()
		if s == nil {
			noises := make([]byte, crypto.RandBetween(1279, 2279)) // matches 1-RTT's server hello length for "random", though it is not important, just for example
			// Redraw until the bytes do not decode as a header.
			var headerErr error
			for headerErr == nil {
				rand.Read(noises)
				_, headerErr = DecodeHeader(noises)
			}
			conn.Write(noises) // make client do new handshake
			return nil, errors.New("expired ticket")
		}
		if _, loaded := s.NfsKeys.LoadOrStore([32]byte(nfsKey), true); loaded { // prevents bad client also
			return nil, errors.New("replay detected")
		}
		c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
		c.PreWrite = make([]byte, 16)
		rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
		c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
		c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
		if i.XorMode == 2 {
			c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
		}
		return c, nil
	}

	// 1-RTT: the client sends its ephemeral public keys.
	if length < 1184+32+16 { // client may send more public keys in the future's version
		return nil, errors.New("too short length")
	}
	encryptedPfsPublicKey := make([]byte, length)
	if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
		return nil, err
	}
	if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
		return nil, err
	}
	mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
	if err != nil {
		return nil, err
	}
	mlkem768Key, encapsulatedPfsKey := mlkem768EKey.Encapsulate()
	peerX25519PKey, err := ecdh.X25519().NewPublicKey(encryptedPfsPublicKey[1184 : 1184+32])
	if err != nil {
		return nil, err
	}
	// Surface a key-generation failure as an error; a nil key would panic on
	// the ECDH call below.
	x25519SKey, err := ecdh.X25519().GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	x25519Key, err := x25519SKey.ECDH(peerX25519PKey)
	if err != nil {
		return nil, err
	}
	pfsKey := make([]byte, 32+32) // no more capacity
	copy(pfsKey, mlkem768Key)
	copy(pfsKey[32:], x25519Key)
	pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
	c.UnitedKey = append(pfsKey, nfsKey...)
	c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
	c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)

	// Issue a ticket: 16 random bytes whose leading bytes are overwritten
	// with the encoded lifetime in seconds.
	ticket := [16]byte{}
	rand.Read(ticket[:])
	var seconds int64
	if i.SecondsTo == 0 {
		seconds = i.SecondsFrom * crypto.RandBetween(50, 100) / 100
	} else {
		seconds = crypto.RandBetween(i.SecondsFrom, i.SecondsTo)
	}
	copy(ticket[:], EncodeLength(int(seconds)))
	if seconds > 0 {
		i.RWLock.Lock()
		// Remember this ticket as the latest one due at its removal minute.
		i.Lasts[(time.Now().Unix()+max(i.SecondsFrom, i.SecondsTo))/60+2] = ticket
		i.Tickets = append(i.Tickets, ticket)
		i.Sessions[ticket] = &ServerSession{PfsKey: pfsKey}
		i.RWLock.Unlock()
	}

	// Server hello layout: sealed pfsPublicKey | sealed ticket | sealed
	// padding length | sealed padding.
	pfsKeyExchangeLength := 1088 + 32 + 16
	encryptedTicketLength := 32
	paddingLength, paddingLens, paddingGaps := CreatPadding(i.PaddingLens, i.PaddingGaps)
	serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
	nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
	c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket[:], nil)
	padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
	c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
	c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
	paddingLens[0] = pfsKeyExchangeLength + encryptedTicketLength + paddingLens[0]
	for n, l := range paddingLens { // sends padding in a fragmented way, to create variable traffic pattern, before inner VLESS flow takes control
		if l > 0 {
			if _, err := conn.Write(serverHello[:l]); err != nil {
				return nil, err
			}
			serverHello = serverHello[l:]
		}
		if len(paddingGaps) > n {
			time.Sleep(paddingGaps[n])
		}
	}

	// important: allows client sends padding slowly, eliminating 1-RTT's traffic pattern
	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
		return nil, err
	}
	if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
		return nil, err
	}
	encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
	if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
		return nil, err
	}
	if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
		return nil, err
	}
	if i.XorMode == 2 {
		c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, ticket[:]), NewCTR(c.UnitedKey, iv), 0, 0)
	}
	return c, nil
}