diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index e90a785d..90f53fac 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -412,8 +412,9 @@ type TLSConfig struct { MasterKeyLog string `json:"masterKeyLog"` ServerNameToVerify string `json:"serverNameToVerify"` VerifyPeerCertInNames []string `json:"verifyPeerCertInNames"` - ECHConfigList string `json:"echConfigList"` ECHServerKeys string `json:"echServerKeys"` + ECHConfigList string `json:"echConfigList"` + ECHForceQuery bool `json:"echForceQuery"` } // Build implements Buildable. @@ -485,8 +486,6 @@ func (c *TLSConfig) Build() (proto.Message, error) { } config.VerifyPeerCertInNames = c.VerifyPeerCertInNames - config.EchConfigList = c.ECHConfigList - if c.ECHServerKeys != "" { EchPrivateKey, err := base64.StdEncoding.DecodeString(c.ECHServerKeys) if err != nil { @@ -494,6 +493,8 @@ func (c *TLSConfig) Build() (proto.Message, error) { } config.EchServerKeys = EchPrivateKey } + config.EchForceQuery = c.ECHForceQuery + config.EchConfigList = c.ECHConfigList return config, nil } diff --git a/transport/internet/tls/config.pb.go b/transport/internet/tls/config.pb.go index fc719c7b..c7944ae0 100644 --- a/transport/internet/tls/config.pb.go +++ b/transport/internet/tls/config.pb.go @@ -217,8 +217,9 @@ type Config struct { // @Document After allow_insecure (automatically), if the server's cert can't be verified by any of these names, pinned_peer_certificate_chain_sha256 will be tried. // @Critical VerifyPeerCertInNames []string `protobuf:"bytes,17,rep,name=verify_peer_cert_in_names,json=verifyPeerCertInNames,proto3" json:"verify_peer_cert_in_names,omitempty"` - EchConfigList string `protobuf:"bytes,18,opt,name=ech_config_list,json=echConfigList,proto3" json:"ech_config_list,omitempty"` - EchServerKeys []byte `protobuf:"bytes,19,opt,name=ech_server_keys,json=echServerKeys,proto3" json:"ech_server_keys,omitempty"` + EchServerKeys []byte `protobuf:"bytes,18,opt,name=ech_server_keys,json=echServerKeys,proto3" json:"ech_server_keys,omitempty"` + EchConfigList string `protobuf:"bytes,19,opt,name=ech_config_list,json=echConfigList,proto3" json:"ech_config_list,omitempty"` + EchForceQuery bool `protobuf:"varint,20,opt,name=ech_force_query,json=echForceQuery,proto3" json:"ech_force_query,omitempty"` } func (x *Config) Reset() { @@ -363,6 +364,13 @@ func (x *Config) GetVerifyPeerCertInNames() []string { return nil } +func (x *Config) GetEchServerKeys() []byte { + if x != nil { + return x.EchServerKeys + } + return nil +} + func (x *Config) GetEchConfigList() string { if x != nil { return x.EchConfigList @@ -370,11 +378,11 @@ func (x *Config) GetEchConfigList() string { return "" } -func (x *Config) GetEchServerKeys() []byte { +func (x *Config) GetEchForceQuery() bool { if x != nil { - return x.EchServerKeys + return x.EchForceQuery } - return nil + return false } var File_transport_internet_tls_config_proto protoreflect.FileDescriptor @@ -408,7 +416,7 @@ var file_transport_internet_tls_config_proto_rawDesc = []byte{ 0x4e, 0x43, 0x49, 0x50, 0x48, 0x45, 0x52, 0x4d, 0x45, 0x4e, 0x54, 0x10, 0x00, 0x12, 0x14, 0x0a, 0x10, 0x41, 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x54, 0x59, 0x5f, 0x56, 0x45, 0x52, 0x49, 0x46, 0x59, 0x10, 0x01, 0x12, 0x13, 0x0a, 0x0f, 0x41, 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x54, 0x59, - 0x5f, 0x49, 0x53, 0x53, 0x55, 0x45, 0x10, 0x02, 0x22, 0xea, 0x06, 0x0a, 0x06, 0x43, 0x6f, 0x6e, + 0x5f, 0x49, 0x53, 0x53, 0x55, 0x45, 0x10, 0x02, 0x22, 0x92, 0x07, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x6e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x49, 0x6e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x12, 0x4a, 0x0a, 0x0b, 0x63, 0x65, @@ -458,20 +466,22 @@ var file_transport_internet_tls_config_proto_rawDesc = []byte{ 0x65, 0x72, 0x69, 0x66, 0x79, 0x5f, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x69, 0x6e, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09, 0x52, 0x15, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x50, 0x65, 0x65, 0x72, 0x43, 0x65, 0x72, 0x74, 0x49, 0x6e, - 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x63, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, - 0x65, 0x63, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x26, 0x0a, - 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x6b, 0x65, 0x79, 0x73, - 0x18, 0x13, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x4b, 0x65, 0x79, 0x73, 0x42, 0x73, 0x0a, 0x1f, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, - 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x74, 0x6c, 0x73, 0x50, 0x01, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, - 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2f, 0x74, 0x6c, 0x73, 0xaa, 0x02, 0x1b, 0x58, - 0x72, 0x61, 0x79, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e, - 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x54, 0x6c, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x73, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d, + 0x65, 0x63, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x26, 0x0a, + 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x6c, 0x69, 0x73, 0x74, + 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x66, 0x6f, 0x72, + 0x63, 0x65, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, + 0x65, 0x63, 0x68, 0x46, 0x6f, 0x72, 0x63, 0x65, 0x51, 0x75, 0x65, 0x72, 0x79, 0x42, 0x73, 0x0a, + 0x1f, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, + 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x74, 0x6c, 0x73, + 0x50, 0x01, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, + 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x72, + 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, + 0x2f, 0x74, 0x6c, 0x73, 0xaa, 0x02, 0x1b, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x54, 0x72, 0x61, 0x6e, + 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x54, + 0x6c, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/transport/internet/tls/config.proto b/transport/internet/tls/config.proto index 0e78651f..97c25d57 100644 --- a/transport/internet/tls/config.proto +++ b/transport/internet/tls/config.proto @@ -92,7 +92,9 @@ message Config { */ repeated string verify_peer_cert_in_names = 17; - string ech_config_list = 18; + bytes ech_server_keys = 18; - bytes ech_server_keys = 19; -} + string ech_config_list = 19; + + bool ech_force_query = 20; +} \ No newline at end of file diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 982235db..2a78fd2e 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -32,8 +32,26 @@ func ApplyECH(c *Config, config *tls.Config) error { nameToQuery := c.ServerName var DNSServer string + // for server + if len(c.EchServerKeys) != 0 { + KeySets, err := ConvertToGoECHKeys(c.EchServerKeys) + if err != nil { + return errors.New("Failed to unmarshal ECHKeySetList: ", err) + } + config.EncryptedClientHelloKeys = KeySets + } + // for client if len(c.EchConfigList) != 0 { + defer func() { + // if failed to get ECHConfig, use an invalid one to make connection fail + if err != nil { + if c.EchForceQuery { + ECHConfig = []byte{1, 1, 4, 5, 1, 4} + } + } + config.EncryptedClientHelloConfigList = ECHConfig + }() // direct base64 config if strings.Contains(c.EchConfigList, "://") { // query config from dns @@ -51,7 +69,7 @@ func ApplyECH(c *Config, config *tls.Config) error { if nameToQuery == "" { return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query") } - ECHConfig, err = QueryRecord(nameToQuery, DNSServer) + ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery) if err != nil { return err } @@ -61,17 +79,6 @@ func ApplyECH(c *Config, config *tls.Config) error { return errors.New("Failed to unmarshal ECHConfigList: ", err) } } - - config.EncryptedClientHelloConfigList = ECHConfig - } - - // for server - if len(c.EchServerKeys) != 0 { - KeySets, err := ConvertToGoECHKeys(c.EchServerKeys) - if err != nil { - return errors.New("Failed to unmarshal ECHKeySetList: ", err) - } - config.EncryptedClientHelloKeys = KeySets } return nil @@ -86,9 +93,11 @@ type ECHConfigCache struct { type echConfigRecord struct { config []byte expire time.Time + err error } var ( + // key value must be like this: "example.com|udp://1.1.1.1" GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]() clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]() ) @@ -96,7 +105,7 @@ var ( // Update updates the ECH config for given domain and server. // this method is concurrent safe, only one update request will be sent, others get the cache. // if isLockedUpdate is true, it will not try to acquire the lock. -func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) { +func (c *ECHConfigCache) Update(domain string, server string, forceQuery bool, isLockedUpdate bool) ([]byte, error) { if !isLockedUpdate { c.UpdateLock.Lock() defer c.UpdateLock.Unlock() @@ -105,13 +114,23 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo configRecord := c.configRecord.Load() if configRecord.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain) - return configRecord.config, nil + return configRecord.config, configRecord.err } // Query ECH config from DNS server errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) echConfig, ttl, err := dnsQuery(server, domain) if err != nil { - return nil, err + if forceQuery { + return nil, err + } else { + configRecord = &echConfigRecord{ + config: nil, + expire: time.Now().Add(10 * time.Minute), + err: err, + } + c.configRecord.Store(configRecord) + return echConfig, err + } } configRecord = &echConfigRecord{ config: echConfig, @@ -123,30 +142,31 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo // QueryRecord returns the ECH config for given domain. // If the record is not in cache or expired, it will query the DNS server and update the cache. -func QueryRecord(domain string, server string) ([]byte, error) { - echConfigCache, ok := GlobalECHConfigCache.Load(domain) +func QueryRecord(domain string, server string, forceQuery bool) ([]byte, error) { + GlobalECHConfigCacheKey := domain + "|" + server + echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey) if !ok { echConfigCache = &ECHConfigCache{} echConfigCache.configRecord.Store(&echConfigRecord{}) - echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(domain, echConfigCache) + echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache) } configRecord := echConfigCache.configRecord.Load() if configRecord.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) - return configRecord.config, nil + return configRecord.config, configRecord.err } // If expire is zero value, it means we are in initial state, wait for the query to finish // otherwise return old value immediately and update in a goroutine // but if the cache is too old, wait for update if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) { - return echConfigCache.Update(domain, server, false) + return echConfigCache.Update(domain, server, false, forceQuery) } else { // If someone already acquired the lock, it means it is updating, do not start another update goroutine if echConfigCache.UpdateLock.TryLock() { go func() { defer echConfigCache.UpdateLock.Unlock() - echConfigCache.Update(domain, server, true) + echConfigCache.Update(domain, server, true, forceQuery) }() } return configRecord.config, nil @@ -165,7 +185,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) { m.Id = 0 msg, err := m.Pack() if err != nil { - return []byte{}, 0, err + return nil, 0, err } var client *http.Client if client, _ = clientForECHDOH.Load(server); client == nil { @@ -194,20 +214,20 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) { } req, err := http.NewRequest("POST", server, bytes.NewReader(msg)) if err != nil { - return []byte{}, 0, err + return nil, 0, err } req.Header.Set("Content-Type", "application/dns-message") resp, err := client.Do(req) if err != nil { - return []byte{}, 0, err + return nil, 0, err } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return []byte{}, 0, err + return nil, 0, err } if resp.StatusCode != http.StatusOK { - return []byte{}, 0, errors.New("query failed with response code:", resp.StatusCode) + return nil, 0, errors.New("query failed with response code:", resp.StatusCode) } dnsResolve = respBody } else if strings.HasPrefix(server, "udp://") { // for classic udp dns server @@ -231,24 +251,25 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) { } }() if err != nil { - return []byte{}, 0, err + return nil, 0, err } msg, err := m.Pack() if err != nil { - return []byte{}, 0, err + return nil, 0, err } conn.Write(msg) udpResponse := make([]byte, 512) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _, err = conn.Read(udpResponse) if err != nil { - return []byte{}, 0, err + return nil, 0, err } dnsResolve = udpResponse } respMsg := new(dns.Msg) err := respMsg.Unpack(dnsResolve) if err != nil { - return []byte{}, 0, errors.New("failed to unpack dns response for ECH: ", err) + return nil, 0, errors.New("failed to unpack dns response for ECH: ", err) } if len(respMsg.Answer) > 0 { for _, answer := range respMsg.Answer { @@ -262,7 +283,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) { } } } - return []byte{}, 0, errors.New("no ech record found") + return nil, 0, errors.New("no ech record found") } // reference github.com/OmarTariq612/goech diff --git a/transport/internet/tls/ech_test.go b/transport/internet/tls/ech_test.go index f0bb3c56..67aa67df 100644 --- a/transport/internet/tls/ech_test.go +++ b/transport/internet/tls/ech_test.go @@ -1,4 +1,4 @@ -package tls_test +package tls import ( "io" @@ -8,13 +8,12 @@ import ( "testing" "github.com/xtls/xray-core/common" - . "github.com/xtls/xray-core/transport/internet/tls" ) func TestECHDial(t *testing.T) { config := &Config{ - ServerName: "encryptedsni.com", - EchConfigList: "udp://1.1.1.1", + ServerName: "cloudflare.com", + EchConfigList: "encryptedsni.com+udp://1.1.1.1", } // test concurrent Dial(to test cache problem) wg := sync.WaitGroup{} @@ -28,7 +27,7 @@ func TestECHDial(t *testing.T) { TLSClientConfig: TLSConfig, }, } - resp, err := client.Get("https://encryptedsni.com/cdn-cgi/trace") + resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace") common.Must(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -40,4 +39,51 @@ func TestECHDial(t *testing.T) { }() } wg.Wait() + // check cache + echConfigCache, ok := GlobalECHConfigCache.Load("encryptedsni.com|udp://1.1.1.1") + if !ok { + t.Error("ECH config cache not found") + + } + ok = echConfigCache.UpdateLock.TryLock() + if !ok { + t.Error("ECH config cache dead lock detected") + } + echConfigCache.UpdateLock.Unlock() + configRecord := echConfigCache.configRecord.Load() + if configRecord == nil { + t.Error("ECH config record not found in cache") + } +} + +func TestECHDialFail(t *testing.T) { + config := &Config{ + ServerName: "cloudflare.com", + EchConfigList: "udp://1.1.1.1", + } + TLSConfig := config.GetTLSConfig() + TLSConfig.NextProtos = []string{"http/1.1"} + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: TLSConfig, + }, + } + resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace") + common.Must(err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + common.Must(err) + // check cache + echConfigCache, ok := GlobalECHConfigCache.Load("cloudflare.com|udp://1.1.1.1") + if !ok { + t.Error("ECH config cache not found") + } + configRecord := echConfigCache.configRecord.Load() + if configRecord == nil { + t.Error("ECH config record not found in cache") + return + } + if configRecord.err == nil { + t.Error("unexpected nil error in ECH config record") + } }