diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 97a6a4fd57..c8d8bf6e23 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -34,6 +34,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/autoscaling" @@ -569,6 +570,126 @@ type CloudConfig struct { //yourself in an non-AWS cloud and open an issue, please indicate that in the //issue body. DisableStrictZoneCheck bool + + // Allows AWS endpoints to be overridden + // Useful in deployments to private edge nodes where amazonaws.com does not resolve + OverrideEndpoints bool + + // Delimiter to use to separate servicename from its configuration parameters + // NOTE: semi-colon ';' truncates the input line in INI files, do not use ';' + // Defaults "|" + ServicenameDelimiter string + + // Delimiter to use to separate region of occurrence, url and signing region for each override + // NOTE: semi-colon ';' truncates the input line in INI files, do not use ';' + // Defaults to "," + OverrideSeparator string + + // Delimiter to use to separate overridden services + // NOTE: semi-colon ';' truncates the input line in INI files, do not use ';' + // Defaults to "&" + ServiceDelimiter string + + // These are of format servicename ServicenameDelimiter url OverrideSeparator signing_region ServiceDelimiter nextservice + // s3|region1, https://s3.foo.bar, some signing_region & ec2|region1, https://ec2.foo.bar, signing_region + ServiceOverrides string + } +} + +const ( + ServicenameDelimiterDefault = "|" + ServicesDelimiterDefault = "&" + OverrideSeparatorDefault = "," +) + +type CustomEndpoint struct { + Endpoint string + SigningRegion string +} + +var overridesActive = false +var overrides map[string]CustomEndpoint + +func IsOverridesActive() bool { + return overridesActive +} + +func SetOverridesDefaults(cfg *CloudConfig) error { + if cfg.Global.OverrideEndpoints { + if cfg.Global.ServiceDelimiter == "" { + cfg.Global.ServiceDelimiter = ServicesDelimiterDefault + } else if cfg.Global.ServiceDelimiter == ";" { + return fmt.Errorf("semi-colon may not be used as a service delimiter, it truncates the input") + } + if cfg.Global.ServicenameDelimiter == "" { + cfg.Global.ServicenameDelimiter = ServicenameDelimiterDefault + } else if cfg.Global.ServicenameDelimiter == ";" { + return fmt.Errorf("semi-colon may not be used as a service name delimiter, it truncates the input") + } + if cfg.Global.OverrideSeparator == "" { + cfg.Global.OverrideSeparator = OverrideSeparatorDefault + } else if cfg.Global.OverrideSeparator == ";" { + return fmt.Errorf("semi-colon may not be used as a override separator, it truncates the input") + } + } + return nil +} + +func MakeRegionEndpointSignature(serviceName, region string) string { + return fmt.Sprintf("%s__%s", strings.TrimSpace(serviceName), strings.TrimSpace(region)) +} + +func ParseOverrides(cfg *CloudConfig) error { + if cfg.Global.OverrideEndpoints { + if err := SetOverridesDefaults(cfg); err != nil { + return err + } + overrides = make(map[string]CustomEndpoint) + allOverrides := strings.Split(cfg.Global.ServiceOverrides, cfg.Global.ServiceDelimiter) + for _, o := range allOverrides { + if idx := strings.Index(o, cfg.Global.ServicenameDelimiter); idx != -1 { + name := strings.TrimSpace(o[:idx]) + values := o[idx+1:] + tuple := strings.Split(values, cfg.Global.OverrideSeparator) + if len(tuple) != 3 { + return errors.New(fmt.Sprintf("3 parameters (region, url, signing region) are required for [%s] in %s", + name, o)) + } + signature := MakeRegionEndpointSignature(name, tuple[0]) + overrides[signature] = CustomEndpoint{Endpoint: strings.TrimSpace(tuple[1]), SigningRegion: strings.TrimSpace(tuple[2])} + } else { + cfg.Global.OverrideEndpoints = false + overridesActive = false + return errors.New(fmt.Sprintf("Unable to find ServicenameSeparator [%s] in %s", + cfg.Global.ServicenameDelimiter, o)) + } + } + overridesActive = true + } else { + overridesActive = false + } + return nil +} + +func loadCustomResolver() func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + defaultResolver := endpoints.DefaultResolver() + defaultResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + return defaultResolver.EndpointFor(service, region, optFns...) + } + if IsOverridesActive() { + customResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + signature := MakeRegionEndpointSignature(service, region) + if ep, ok := overrides[signature]; ok { + return endpoints.ResolvedEndpoint{ + URL: ep.Endpoint, + SigningRegion: ep.SigningRegion, + }, nil + } + return defaultResolver.EndpointFor(service, region, optFns...) + } + return customResolverFn + } else { + return defaultResolverFn } } @@ -651,7 +772,8 @@ func (p *awsSDKProvider) Compute(regionName string) (EC2, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -672,7 +794,8 @@ func (p *awsSDKProvider) LoadBalancing(regionName string) (ELB, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -689,7 +812,8 @@ func (p *awsSDKProvider) LoadBalancingV2(regionName string) (ELBV2, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -707,7 +831,8 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -721,7 +846,7 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { } func (p *awsSDKProvider) Metadata() (EC2Metadata, error) { - sess, err := session.NewSession(&aws.Config{}) + sess, err := session.NewSession(&aws.Config{EndpointResolver: endpoints.ResolverFunc(loadCustomResolver())}) if err != nil { return nil, fmt.Errorf("unable to initialize AWS session: %v", err) } @@ -735,7 +860,8 @@ func (p *awsSDKProvider) KeyManagement(regionName string) (KMS, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -960,6 +1086,10 @@ func init() { return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) } + if err = ParseOverrides(cfg); err != nil { + return nil, fmt.Errorf("unable to parse custom endpoint overrides: %v", err) + } + sess, err := session.NewSession(&aws.Config{}) if err != nil { return nil, fmt.Errorf("unable to initialize AWS session: %v", err) diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 4d12932859..6a3aefedab 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -186,6 +186,295 @@ func TestReadAWSCloudConfig(t *testing.T) { } } +func TestOverridesActiveConfig(t *testing.T) { + tests := []struct { + name string + + + reader io.Reader + aws Services + + expectError bool + active bool + servicesOverridden []string + regions []string + }{ + { + "No overrides in config", + strings.NewReader("[global]\nServiceOverrides=s3|sregion, https://s3.foo.bar, sregion"), + nil, + false, false, + []string{}, []string{}, + }, + { + "Missing Servicename Separator", + strings.NewReader("[global]\nOverrideEndpoints=true\nServiceOverrides=s3sregion, https://s3.foo.bar, sregion"), + nil, + true, false, + []string{}, []string{}, + }, + { + "Missing Service Region", + strings.NewReader("[global]\nOverrideEndpoints=true\nServiceOverrides=s3|https://s3.foo.bar, sregion"), + nil, + true, false, + []string{}, []string{}, + }, + { + "Semi-colon in service delimiter", + strings.NewReader("[global]\nOverrideEndpoints=true\nServiceDelimiter=;"), + nil, + true, false, + []string{}, []string{}, + }, + { + "Semi-colon in service name delimiter", + strings.NewReader("[global]\nOverrideEndpoints=true\nServicenameDelimiter=;"), + nil, + true, false, + []string{}, []string{}, + }, + { + "Semi-colon in service name delimiter", + strings.NewReader("[global]\nOverrideEndpoints=true\nOverrideSeparator=;"), + nil, + true, false, + []string{}, []string{}, + }, + { + "Active Overrides", + strings.NewReader("[global]\nOverrideEndpoints=true\nServiceOverrides=s3|sregion, https://s3.foo.bar, sregion"), + nil, + false, true, + []string{"s3"}, []string{"sregion"}, + }, + { + "Multiple Overriden Services", + strings.NewReader("[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|sregion, https://s3.foo.bar, sregion & ec2|sregion, https://ec2.foo.bar, sregion"), + nil, + false, true, + []string{"s3", "ec2"}, []string{"sregion", "sregion"}, + }, + { + "Multiple Overriden Services in Multiple regions", + strings.NewReader("[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|region1, https://s3.foo.bar, sregion & ec2|region2, https://ec2.foo.bar, sregion"), + nil, + false, true, + []string{"s3", "ec2"}, []string{"region1", "region2"}, + }, + { + "Multiple regions, Same Service", + strings.NewReader("[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|region1, https://s3.foo.bar, sregion & s3|region2, https://s3.foo.bar, sregion"), + nil, + false, true, + []string{"s3", "s3"}, []string{"region1", "region2"}, + }, + } + + for _, test := range tests { + t.Logf("Running test case %s", test.name) + cfg, err := readAWSCloudConfig(test.reader) + if err == nil { + err = ParseOverrides(cfg) + } + if test.expectError { + if err == nil { + t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) + } + if IsOverridesActive() != test.active { + t.Errorf("Incorrect active flag (%v vs %v) for case: %s", + IsOverridesActive(), test.active, test.name) + } + } else { + if err != nil { + t.Errorf("Should succeed for case: %s", test.name) + } + if IsOverridesActive() != test.active { + t.Errorf("Incorrect active flag (%v vs %v) for case: %s", + IsOverridesActive(), test.active, test.name) + } + if len(overrides) != len(test.servicesOverridden) { + t.Errorf("Expected %d overridden services, received %d for case %s", + len(test.servicesOverridden), len(overrides), test.name) + } else { + for i, name := range test.servicesOverridden { + signature := MakeRegionEndpointSignature(name, test.regions[i]) + ep, ok := overrides[signature] + if !ok { + t.Errorf("Missing override for service %s in case %s", + name, test.name) + } else { + if ep.SigningRegion != "sregion" { + t.Errorf("Expected signing region 'sregion', received '%s' for case %s", + ep.SigningRegion, test.name) + } + targetName := fmt.Sprintf("https://%s.foo.bar", name) + if ep.Endpoint != targetName { + t.Errorf("Expected Endpoint '%s', received '%s' for case %s", + targetName, ep.Endpoint, test.name) + } + + fn := loadCustomResolver() + ep1, e := fn(name, test.regions[i], nil) + if e != nil { + t.Errorf("Expected a valid endpoint for %s in case %s", + name, test.name) + } else { + targetName := fmt.Sprintf("https://%s.foo.bar", name) + if ep1.URL != targetName { + t.Errorf("Expected endpoint url: %s, received %s in case %s", + targetName, ep1.URL, test.name) + } + if ep1.SigningRegion != "sregion" { + t.Errorf("Expected signing region 'sregion', received '%s' in case %s", + ep1.SigningRegion, test.name) + } + } + } + } + } + } + } +} + +func TestOverridesDefaults(t *testing.T) { + tests := []struct { + name string + + configString string + + expectError bool + active bool + servicesOverridden []string + defaults []string + }{ + { + "Bad Servicename Delimiter", + "[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|sregion, https://s3.foo.bar, sregion\n" + + "ServicenameDelimiter=?", + true, false, + []string{}, + []string{";", "?", ","}, + }, + { + "Custom ServicenameDelimiter", + "[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3?sregion, https://s3.foo.bar, sregion\n" + + "ServicenameDelimiter=?", + false, true, + []string{"s3"}, + []string{"&", "?", ","}, + }, + { + "Custom OverrideSeparator", + "[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|sregion + https://s3.foo.bar + sregion \n" + + "OverrideSeparator=+", + false, true, + []string{"s3"}, + []string{"&", "|", "+"}, + }, + { + "Custom Services Delimiter", + "[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|sregion, https://s3.foo.bar, sregion + ec2|sregion, https://ec2.foo.bar , sregion\n" + + "ServiceDelimiter=+", + false, true, + []string{"s3", "ec2"}, + []string{"+", "|", ","}, + }, + { + "Active Overrides", + "[global]\nOverrideEndpoints=true\n" + + "ServiceOverrides=s3|sregion, https://s3.foo.bar , sregion & ec2|sregion, https://ec2.foo.bar, sregion", + false, true, + []string{"s3", "ec2"}, + []string{"&", "|", ","}, + }, + } + + for _, test := range tests { + t.Logf("Running test case %s", test.name) + cfg, err := readAWSCloudConfig(strings.NewReader(test.configString)) + if err == nil { + err = ParseOverrides(cfg) + } + if test.expectError { + if err == nil { + t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) + } + if IsOverridesActive() != test.active { + t.Errorf("Incorrect active flag (%v vs %v) for case: %s", + IsOverridesActive(), test.active, test.name) + } + } else { + if err != nil { + t.Errorf("Should succeed for case: %s", test.name) + } + if IsOverridesActive() != test.active { + t.Errorf("Incorrect active flag (%v vs %v) for case: %s", + IsOverridesActive(), test.active, test.name) + } + if cfg.Global.ServiceDelimiter != test.defaults[0] { + t.Errorf("Incorrect ServiceDelimter (%s vs %s) for case %s", + cfg.Global.ServiceDelimiter, test.defaults[0], test.name) + } + if cfg.Global.ServicenameDelimiter != test.defaults[1] { + t.Errorf("Incorrect ServicenameDelimiter (%s vs %s) for case %s", + cfg.Global.ServicenameDelimiter, test.defaults[1], test.name) + } + if cfg.Global.OverrideSeparator != test.defaults[2] { + t.Errorf("Incorrect OverrideSeparator (%s vs %s) for case %s", + cfg.Global.OverrideSeparator, test.defaults[2], test.name) + } + if len(overrides) != len(test.servicesOverridden) { + t.Errorf("Expected %d overridden services, received %d for case %s", + len(test.servicesOverridden), len(overrides), test.name) + } else { + for _, name := range test.servicesOverridden { + signature := MakeRegionEndpointSignature(name, "sregion") + ep, ok := overrides[signature] + if !ok { + t.Errorf("Missing override for service %s in case %s", + name, test.name) + } else { + if ep.SigningRegion != "sregion" { + t.Errorf("Expected signing region 'sregion', received '%s' for case %s", + ep.SigningRegion, test.name) + } + targetName := fmt.Sprintf("https://%s.foo.bar", name) + if ep.Endpoint != targetName { + t.Errorf("Expected Endpoint '%s', received '%s' for case %s", + targetName, ep.Endpoint, test.name) + } + + fn := loadCustomResolver() + ep1, e := fn(name, "sregion", nil) + if e != nil { + t.Errorf("Expected a valid endpoint for %s in case %s", + name, test.name) + } else { + targetName := fmt.Sprintf("https://%s.foo.bar", name) + if ep1.URL != targetName { + t.Errorf("Expected endpoint url: %s, received %s in case %s", + targetName, ep1.URL, test.name) + } + if ep1.SigningRegion != "sregion" { + t.Errorf("Expected signing region 'sregion', received '%s' in case %s", + ep1.SigningRegion, test.name) + } + } + } + } + } + } + } +} + func TestNewAWSCloud(t *testing.T) { tests := []struct { name string