diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 775f61e2..4a12761d 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -414,7 +414,7 @@ type TLSConfig struct { VerifyPeerCertInNames []string `json:"verifyPeerCertInNames"` ECHServerKeys string `json:"echServerKeys"` ECHConfigList string `json:"echConfigList"` - ECHForceQuery bool `json:"echForceQuery"` + ECHForceQuery string `json:"echForceQuery"` ECHSocketSettings *SocketConfig `json:"echSockopt"` } @@ -494,6 +494,12 @@ func (c *TLSConfig) Build() (proto.Message, error) { } config.EchServerKeys = EchPrivateKey } + switch c.ECHForceQuery { + case "none", "half", "full", "": + config.EchForceQuery = c.ECHForceQuery + default: + return nil, errors.New(`invalid "echForceQuery": `, c.ECHForceQuery) + } config.EchForceQuery = c.ECHForceQuery config.EchConfigList = c.ECHConfigList if c.ECHSocketSettings != nil { diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 4f7700ca..90427b8d 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -8,7 +8,6 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" - "github.com/xtls/xray-core/features/dns" "os" "slices" "strings" @@ -451,7 +450,7 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { if len(c.EchConfigList) > 0 || len(c.EchServerKeys) > 0 { err := ApplyECH(c, config) if err != nil { - if c.EchForceQuery || errors.Cause(err) != dns.ErrEmptyResponse { + if c.EchForceQuery == "full" { errors.LogError(context.Background(), err) } else { errors.LogInfo(context.Background(), err) diff --git a/transport/internet/tls/config.pb.go b/transport/internet/tls/config.pb.go index c30d5ef3..b93af678 100644 --- a/transport/internet/tls/config.pb.go +++ b/transport/internet/tls/config.pb.go @@ -220,7 +220,7 @@ type Config struct { VerifyPeerCertInNames []string `protobuf:"bytes,17,rep,name=verify_peer_cert_in_names,json=verifyPeerCertInNames,proto3" json:"verify_peer_cert_in_names,omitempty"` EchServerKeys []byte `protobuf:"bytes,18,opt,name=ech_server_keys,json=echServerKeys,proto3" json:"ech_server_keys,omitempty"` EchConfigList string `protobuf:"bytes,19,opt,name=ech_config_list,json=echConfigList,proto3" json:"ech_config_list,omitempty"` - EchForceQuery bool `protobuf:"varint,20,opt,name=ech_force_query,json=echForceQuery,proto3" json:"ech_force_query,omitempty"` + EchForceQuery string `protobuf:"bytes,20,opt,name=ech_force_query,json=echForceQuery,proto3" json:"ech_force_query,omitempty"` EchSocketSettings *internet.SocketConfig `protobuf:"bytes,21,opt,name=ech_socket_settings,json=echSocketSettings,proto3" json:"ech_socket_settings,omitempty"` } @@ -380,11 +380,11 @@ func (x *Config) GetEchConfigList() string { return "" } -func (x *Config) GetEchForceQuery() bool { +func (x *Config) GetEchForceQuery() string { if x != nil { return x.EchForceQuery } - return false + return "" } func (x *Config) GetEchSocketSettings() *internet.SocketConfig { @@ -483,7 +483,7 @@ var file_transport_internet_tls_config_proto_rawDesc = []byte{ 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x66, 0x6f, - 0x72, 0x63, 0x65, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x72, 0x63, 0x65, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x46, 0x6f, 0x72, 0x63, 0x65, 0x51, 0x75, 0x65, 0x72, 0x79, 0x12, 0x55, 0x0a, 0x13, 0x65, 0x63, 0x68, 0x5f, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x5f, 0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x15, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x78, 0x72, diff --git a/transport/internet/tls/config.proto b/transport/internet/tls/config.proto index c9906333..6d39bc56 100644 --- a/transport/internet/tls/config.proto +++ b/transport/internet/tls/config.proto @@ -98,7 +98,7 @@ message Config { string ech_config_list = 19; - bool ech_force_query = 20; + string ech_force_query = 20; SocketConfig ech_socket_settings = 21; } diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 5069f06d..2cd07c9d 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -9,10 +9,6 @@ import ( "encoding/base64" "encoding/binary" "fmt" - utls "github.com/refraction-networking/utls" - "github.com/xtls/xray-core/common/crypto" - dns2 "github.com/xtls/xray-core/features/dns" - "golang.org/x/net/http2" "io" "net/http" "net/url" @@ -21,6 +17,11 @@ import ( "sync/atomic" "time" + utls "github.com/refraction-networking/utls" + "github.com/xtls/xray-core/common/crypto" + dns2 "github.com/xtls/xray-core/features/dns" + "golang.org/x/net/http2" + "github.com/miekg/dns" "github.com/xtls/reality" "github.com/xtls/reality/hpke" @@ -52,10 +53,18 @@ func ApplyECH(c *Config, config *tls.Config) error { // for client if len(c.EchConfigList) != 0 { + ECHForceQuery := c.EchForceQuery + switch ECHForceQuery { + case "none", "half", "full": + case "": + ECHForceQuery = "none" // default to none + default: + panic("Invalid ECHForceQuery: " + c.EchForceQuery) + } defer func() { // if failed to get ECHConfig, use an invalid one to make connection fail - if err != nil { - if c.EchForceQuery { + if err != nil || len(ECHConfig) == 0 { + if ECHForceQuery == "full" { ECHConfig = []byte{1, 1, 4, 5, 1, 4} } } @@ -106,32 +115,40 @@ type echConfigRecord struct { } var ( - // key value must be like this: "example.com|udp://1.1.1.1" + // The keys for both maps must be generated by ECHCacheKey(). GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]() clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]() ) +// sockopt can be nil if not specified. +// if for clientForECHDOH, domain can be empty. +func ECHCacheKey(server, domain string, sockopt *internet.SocketConfig) string { + return server + "|" + domain + "|" + fmt.Sprintf("%p", sockopt) +} + // Update updates the ECH config for given domain and server. // this method is concurrent safe, only one update request will be sent, others get the cache. // if isLockedUpdate is true, it will not try to acquire the lock. -func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery bool, sockopt *internet.SocketConfig) ([]byte, error) { +func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) { if !isLockedUpdate { c.UpdateLock.Lock() defer c.UpdateLock.Unlock() } // Double check cache after acquiring lock configRecord := c.configRecord.Load() - if configRecord.expire.After(time.Now()) { + if configRecord.expire.After(time.Now()) && configRecord.err == nil { errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain) return configRecord.config, configRecord.err } // Query ECH config from DNS server errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) echConfig, ttl, err := dnsQuery(server, domain, sockopt) - if err != nil { - if forceQuery || ttl == 0 { - return nil, err - } + // if in "full", directly return + if err != nil && forceQuery == "full" { + return nil, err + } + if ttl == 0 { + ttl = dns2.DefaultTTL } configRecord = &echConfigRecord{ config: echConfig, @@ -144,8 +161,8 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo // QueryRecord returns the ECH config for given domain. // If the record is not in cache or expired, it will query the DNS server and update the cache. -func QueryRecord(domain string, server string, forceQuery bool, sockopt *internet.SocketConfig) ([]byte, error) { - GlobalECHConfigCacheKey := domain + "|" + server + "|" + fmt.Sprintf("%p", sockopt) +func QueryRecord(domain string, server string, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) { + GlobalECHConfigCacheKey := ECHCacheKey(server, domain, sockopt) echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey) if !ok { echConfigCache = &ECHConfigCache{} @@ -153,7 +170,7 @@ func QueryRecord(domain string, server string, forceQuery bool, sockopt *interne echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache) } configRecord := echConfigCache.configRecord.Load() - if configRecord.expire.After(time.Now()) { + if configRecord.expire.After(time.Now()) && (configRecord.err == nil || forceQuery == "none") { errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) return configRecord.config, configRecord.err } @@ -196,7 +213,7 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b return nil, 0, err } var client *http.Client - serverKey := server + "|" + fmt.Sprintf("%p", sockopt) + serverKey := ECHCacheKey(server, "", sockopt) if client, _ = clientForECHDOH.Load(serverKey); client == nil { // All traffic sent by core should via xray's internet.DialSystem // This involves the behavior of some Android VPN GUI clients @@ -307,7 +324,8 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b } } } - return nil, dns2.DefaultTTL, dns2.ErrEmptyResponse + // empty is valid, means no ECH config found + return nil, dns2.DefaultTTL, nil } // reference github.com/OmarTariq612/goech diff --git a/transport/internet/tls/ech_test.go b/transport/internet/tls/ech_test.go index 009e80a1..bdf87868 100644 --- a/transport/internet/tls/ech_test.go +++ b/transport/internet/tls/ech_test.go @@ -1,7 +1,6 @@ package tls import ( - "fmt" "io" "net/http" "strings" @@ -41,7 +40,7 @@ func TestECHDial(t *testing.T) { } wg.Wait() // check cache - echConfigCache, ok := GlobalECHConfigCache.Load("encryptedsni.com|udp://1.1.1.1" + "|" + fmt.Sprintf("%p", config.EchSocketSettings)) + echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://1.1.1.1", "encryptedsni.com", nil)) if !ok { t.Error("ECH config cache not found") @@ -60,22 +59,12 @@ func TestECHDial(t *testing.T) { func TestECHDialFail(t *testing.T) { config := &Config{ ServerName: "cloudflare.com", - EchConfigList: "udp://1.1.1.1", + EchConfigList: "udp://127.0.0.1", + EchForceQuery: "half", } - TLSConfig := config.GetTLSConfig() - TLSConfig.NextProtos = []string{"http/1.1"} - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: TLSConfig, - }, - } - resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace") - common.Must(err) - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - common.Must(err) + config.GetTLSConfig() // check cache - echConfigCache, ok := GlobalECHConfigCache.Load("cloudflare.com|udp://1.1.1.1" + "|" + fmt.Sprintf("%p", config.EchSocketSettings)) + echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://127.0.0.1", "cloudflare.com", nil)) if !ok { t.Error("ECH config cache not found") }