diff --git a/pkg/cloudprovider/providers/aws/BUILD b/pkg/cloudprovider/providers/aws/BUILD index eba94c8da4..8d9a8da5b4 100644 --- a/pkg/cloudprovider/providers/aws/BUILD +++ b/pkg/cloudprovider/providers/aws/BUILD @@ -18,6 +18,7 @@ go_library( "aws_routes.go", "aws_utils.go", "device_allocator.go", + "instances.go", "log_handler.go", "regions.go", "retry_handler.go", @@ -58,6 +59,7 @@ go_test( srcs = [ "aws_test.go", "device_allocator_test.go", + "instances_test.go", "regions_test.go", "retry_handler_test.go", "tags_test.go", diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index eec475e7f7..994d671ecf 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -204,7 +204,6 @@ type Services interface { type EC2 interface { // Query EC2 for instances matching the filter DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) - DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) // Attach a volume to an instance AttachVolume(*ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error) @@ -609,20 +608,6 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e return results, nil } -// Implementation of EC2.DescribeAddresses -func (s *awsSdkEC2) DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) { - requestTime := time.Now() - response, err := s.ec2.DescribeAddresses(request) - if err != nil { - recordAwsMetric("describe_address", 0, err) - return nil, fmt.Errorf("error listing AWS addresses: %v", err) - } - - timeTaken := time.Since(requestTime).Seconds() - recordAwsMetric("describe_address", timeTaken, nil) - return response.Addresses, nil -} - // Implements EC2.DescribeSecurityGroups func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { // Security groups are not paged @@ -996,38 +981,51 @@ func (c *Cloud) NodeAddresses(name types.NodeName) ([]v1.NodeAddress, error) { return addresses, nil } + instance, err := c.getInstanceByNodeName(name) if err != nil { return nil, fmt.Errorf("getInstanceByNodeName failed for %q with %v", name, err) } + return extractNodeAddresses(instance) +} + +// extractNodeAddresses maps the instance information from EC2 to an array of NodeAddresses +func extractNodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { + // Not clear if the order matters here, but we might as well indicate a sensible preference order + + if instance == nil { + return nil, fmt.Errorf("nil instance passed to extractNodeAddresses") + } addresses := []v1.NodeAddress{} - if !isNilOrEmpty(instance.PrivateIpAddress) { - ipAddress := *instance.PrivateIpAddress - ip := net.ParseIP(ipAddress) + privateIPAddress := aws.StringValue(instance.PrivateIpAddress) + if privateIPAddress != "" { + ip := net.ParseIP(privateIPAddress) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%s)", orEmpty(instance.InstanceId), ipAddress) + return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%s)", aws.StringValue(instance.InstanceId), privateIPAddress) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) } // TODO: Other IP addresses (multiple ips)? - if !isNilOrEmpty(instance.PublicIpAddress) { - ipAddress := *instance.PublicIpAddress - ip := net.ParseIP(ipAddress) + publicIPAddress := aws.StringValue(instance.PublicIpAddress) + if publicIPAddress != "" { + ip := net.ParseIP(publicIPAddress) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", orEmpty(instance.InstanceId), ipAddress) + return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", aws.StringValue(instance.InstanceId), publicIPAddress) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) } - if !isNilOrEmpty(instance.PrivateDnsName) { - addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: *instance.PrivateDnsName}) + privateDNSName := aws.StringValue(instance.PrivateDnsName) + if privateDNSName != "" { + addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName}) } - if !isNilOrEmpty(instance.PublicDnsName) { - addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: *instance.PublicDnsName}) + publicDNSName := aws.StringValue(instance.PublicDnsName) + if publicDNSName != "" { + addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName}) } return addresses, nil @@ -1037,45 +1035,17 @@ func (c *Cloud) NodeAddresses(name types.NodeName) ([]v1.NodeAddress, error) { // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) NodeAddressesByProviderID(providerID string) ([]v1.NodeAddress, error) { - instanceID, error := instanceIDFromProviderID(providerID) - - if error != nil { - return nil, error + instanceID, err := kubernetesInstanceID(providerID).mapToAWSInstanceID() + if err != nil { + return nil, err } - addresses, error := c.describeAddressesByInstanceID(instanceID) - - if error != nil { - return nil, error + instance, err := describeInstance(c.ec2, instanceID) + if err != nil { + return nil, err } - instances, error := c.describeInstancesByInstanceID(instanceID) - - if error != nil { - return nil, error - } - - nodeAddresses := []v1.NodeAddress{} - - for _, address := range addresses { - convertedAddress, error := convertAwsAddress(address) - if error != nil { - return nil, error - } - - nodeAddresses = append(nodeAddresses, convertedAddress...) - } - - for _, instance := range instances { - addresses, error := instanceAddresses(instance) - if error != nil { - return nil, error - } - - nodeAddresses = append(nodeAddresses, addresses...) - } - - return nodeAddresses, nil + return extractNodeAddresses(instance) } // ExternalID returns the cloud provider ID of the node with the specified nodeName (deprecated). @@ -1114,16 +1084,14 @@ func (c *Cloud) InstanceID(nodeName types.NodeName) (string, error) { // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) InstanceTypeByProviderID(providerID string) (string, error) { - instanceID, error := instanceIDFromProviderID(providerID) - - if error != nil { - return "", error + instanceID, err := kubernetesInstanceID(providerID).mapToAWSInstanceID() + if err != nil { + return "", err } - instance, error := c.describeInstanceByInstanceID(instanceID) - - if error != nil { - return "", error + instance, err := describeInstance(c.ec2, instanceID) + if err != nil { + return "", err } return aws.StringValue(instance.InstanceType), nil @@ -1138,7 +1106,7 @@ func (c *Cloud) InstanceType(nodeName types.NodeName) (string, error) { if err != nil { return "", fmt.Errorf("getInstanceByNodeName failed for %q with %v", nodeName, err) } - return orEmpty(inst.InstanceType), nil + return aws.StringValue(inst.InstanceType), nil } // Return a list of instances matching regex string. @@ -1307,22 +1275,7 @@ func (i *awsInstance) getInstanceType() *awsInstanceType { // Gets the full information about this instance from the EC2 API func (i *awsInstance) describeInstance() (*ec2.Instance, error) { - instanceID := i.awsID - request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{&instanceID}, - } - - instances, err := i.ec2.DescribeInstances(request) - if err != nil { - return nil, err - } - if len(instances) == 0 { - return nil, fmt.Errorf("no instances found for instance: %s", i.awsID) - } - if len(instances) > 1 { - return nil, fmt.Errorf("multiple instances found for instance: %s", i.awsID) - } - return instances[0], nil + return describeInstance(i.ec2, awsInstanceID(i.awsID)) } // Gets the mountDevice already assigned to the volume, or assigns an unused mountDevice. @@ -3407,25 +3360,6 @@ func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([ return instances, nil } -func (c *Cloud) describeInstancesByInstanceID(instanceID string) ([]*ec2.Instance, error) { - filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} - return c.describeInstances(filters) -} - -func (c *Cloud) describeInstanceByInstanceID(instanceID string) (*ec2.Instance, error) { - filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} - instances, err := c.describeInstances(filters) - if err != nil { - return nil, err - } - - if len(instances) != 1 { - return nil, fmt.Errorf("expected 1 instance, found %d for instanceID %s", len(instances), instanceID) - } - - return instances[0], nil -} - func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error) { filters = c.tagging.addFilters(filters) request := &ec2.DescribeInstancesInput{ @@ -3446,21 +3380,6 @@ func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error return matches, nil } -func (c *Cloud) describeAddressesByInstanceID(instanceID string) ([]*ec2.Address, error) { - filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} - params := &ec2.DescribeAddressesInput{ - Filters: filters, - } - - addresses, error := c.ec2.DescribeAddresses(params) - - if error != nil { - return nil, error - } - - return addresses, nil -} - // mapNodeNameToPrivateDNSName maps a k8s NodeName to an AWS Instance PrivateDNSName // This is a simple string cast func mapNodeNameToPrivateDNSName(nodeName types.NodeName) string { @@ -3518,78 +3437,6 @@ func (c *Cloud) getFullInstance(nodeName types.NodeName) (*awsInstance, *ec2.Ins return awsInstance, instance, err } -func instanceAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { - addresses := []v1.NodeAddress{} - privateDNSName := aws.StringValue(instance.PrivateDnsName) - unsafePrivateIP := aws.StringValue(instance.PrivateIpAddress) - publicDNSName := aws.StringValue(instance.PublicDnsName) - unsafePublicIP := aws.StringValue(instance.PublicIpAddress) - - if privateDNSName != "" { - addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName}) - } - - if unsafePrivateIP != "" { - ip := net.ParseIP(unsafePrivateIP) - if ip != nil { - addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) - } else { - return nil, fmt.Errorf("EC2 address had invalid private IP: %s", unsafePrivateIP) - } - } - - if publicDNSName != "" { - addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName}) - } - - if unsafePublicIP != "" { - ip := net.ParseIP(unsafePublicIP) - if ip != nil { - addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) - } else { - return nil, fmt.Errorf("EC2 address had invalid public IP: %s", unsafePublicIP) - } - } - - return addresses, nil -} - -func convertAwsAddress(address *ec2.Address) ([]v1.NodeAddress, error) { - nodeAddresses := []v1.NodeAddress{} - if aws.StringValue(address.PrivateIpAddress) != "" { - unsafeIP := *address.PrivateIpAddress - ip := net.ParseIP(unsafeIP) - if ip != nil { - nodeAddresses = append(nodeAddresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) - } else { - return nil, fmt.Errorf("EC2 address had invalid private IP: %s", unsafeIP) - } - } - - if aws.StringValue(address.PublicIp) != "" { - unsafeIP := *address.PublicIp - ip := net.ParseIP(unsafeIP) - if ip != nil { - nodeAddresses = append(nodeAddresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) - } else { - return nil, fmt.Errorf("EC2 address had invalid public IP: %s", unsafeIP) - } - } - - return nodeAddresses, nil -} - -var providerIDRegexp = regexp.MustCompile(`^aws://([^/]+)$`) - -func instanceIDFromProviderID(providerID string) (instanceID string, err error) { - matches := providerIDRegexp.FindStringSubmatch(providerID) - if len(matches) != 2 { - return "", fmt.Errorf("ProviderID \"%s\" didn't match expected format \"aws://InstanceID\"", providerID) - } - - return matches[1], nil -} - func setNodeDisk( nodeDiskMap map[types.NodeName]map[KubernetesVolumeID]bool, volumeID KubernetesVolumeID, diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index f75c1a571c..a50eac1fc8 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -322,12 +322,6 @@ func (self *FakeEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]* return matches, nil } -func (self *FakeEC2) DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) { - addresses := []*ec2.Address{} - - return addresses, nil -} - type FakeMetadata struct { aws *FakeAWSServices } @@ -1356,37 +1350,3 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) { } } } - -func TestInstanceIDFromProviderID(t *testing.T) { - testCases := []struct { - providerID string - instanceID string - fail bool - }{ - { - providerID: "aws://i-0194bbdb81a49b169", - instanceID: "i-0194bbdb81a49b169", - fail: false, - }, - { - providerID: "i-0194bbdb81a49b169", - instanceID: "", - fail: true, - }, - } - - for _, test := range testCases { - instanceID, err := instanceIDFromProviderID(test.providerID) - if (err != nil) != test.fail { - t.Errorf("%s yielded `err != nil` as %t. expected %t", test.providerID, (err != nil), test.fail) - } - - if test.fail { - continue - } - - if instanceID != test.instanceID { - t.Errorf("%s yielded %s. expected %s", test.providerID, instanceID, test.instanceID) - } - } -} diff --git a/pkg/cloudprovider/providers/aws/instances.go b/pkg/cloudprovider/providers/aws/instances.go new file mode 100644 index 0000000000..2d8c1ea965 --- /dev/null +++ b/pkg/cloudprovider/providers/aws/instances.go @@ -0,0 +1,100 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "fmt" + "net/url" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" +) + +// awsInstanceID represents the ID of the instance in the AWS API, e.g. i-12345678 +// The "traditional" format is "i-12345678" +// A new longer format is also being introduced: "i-12345678abcdef01" +// We should not assume anything about the length or format, though it seems +// reasonable to assume that instances will continue to start with "i-". +type awsInstanceID string + +func (i awsInstanceID) awsString() *string { + return aws.String(string(i)) +} + +// kubernetesInstanceID represents the id for an instance in the kubernetes API; +// the following form +// * aws://// +// * aws://// +// * +type kubernetesInstanceID string + +// mapToAWSInstanceID extracts the awsInstanceID from the kubernetesInstanceID +func (name kubernetesInstanceID) mapToAWSInstanceID() (awsInstanceID, error) { + s := string(name) + + if !strings.HasPrefix(s, "aws://") { + // Assume a bare aws volume id (vol-1234...) + // Build a URL with an empty host (AZ) + s = "aws://" + "/" + "/" + s + } + url, err := url.Parse(s) + if err != nil { + return "", fmt.Errorf("Invalid instance name (%s): %v", name, err) + } + if url.Scheme != "aws" { + return "", fmt.Errorf("Invalid scheme for AWS instance (%s)", name) + } + + awsID := "" + tokens := strings.Split(strings.Trim(url.Path, "/"), "/") + if len(tokens) == 1 { + // instanceId + awsID = tokens[0] + } else if len(tokens) == 2 { + // az/instanceId + awsID = tokens[1] + } + + // We sanity check the resulting volume; the two known formats are + // i-12345678 and i-12345678abcdef01 + // TODO: Regex match? + if awsID == "" || strings.Contains(awsID, "/") || !strings.HasPrefix(awsID, "i-") { + return "", fmt.Errorf("Invalid format for AWS instance (%s)", name) + } + + return awsInstanceID(awsID), nil +} + +// Gets the full information about this instance from the EC2 API +func describeInstance(ec2Client EC2, instanceID awsInstanceID) (*ec2.Instance, error) { + request := &ec2.DescribeInstancesInput{ + InstanceIds: []*string{instanceID.awsString()}, + } + + instances, err := ec2Client.DescribeInstances(request) + if err != nil { + return nil, err + } + if len(instances) == 0 { + return nil, fmt.Errorf("no instances found for instance: %s", instanceID) + } + if len(instances) > 1 { + return nil, fmt.Errorf("multiple instances found for instance: %s", instanceID) + } + return instances[0], nil +} diff --git a/pkg/cloudprovider/providers/aws/instances_test.go b/pkg/cloudprovider/providers/aws/instances_test.go new file mode 100644 index 0000000000..79637ad91d --- /dev/null +++ b/pkg/cloudprovider/providers/aws/instances_test.go @@ -0,0 +1,89 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "testing" +) + +func TestParseInstance(t *testing.T) { + tests := []struct { + Kubernetes kubernetesInstanceID + Aws awsInstanceID + ExpectError bool + }{ + { + Kubernetes: "aws:///us-east-1a/i-12345678", + Aws: "i-12345678", + }, + { + Kubernetes: "aws:////i-12345678", + Aws: "i-12345678", + }, + { + Kubernetes: "i-12345678", + Aws: "i-12345678", + }, + { + Kubernetes: "aws:///us-east-1a/i-12345678abcdef01", + Aws: "i-12345678abcdef01", + }, + { + Kubernetes: "aws:////i-12345678abcdef01", + Aws: "i-12345678abcdef01", + }, + { + Kubernetes: "i-12345678abcdef01", + Aws: "i-12345678abcdef01", + }, + { + Kubernetes: "vol-123456789", + ExpectError: true, + }, + { + Kubernetes: "aws:///us-east-1a/vol-12345678abcdef01", + ExpectError: true, + }, + { + Kubernetes: "aws://accountid/us-east-1a/vol-12345678abcdef01", + ExpectError: true, + }, + { + Kubernetes: "aws:///us-east-1a/vol-12345678abcdef01/suffix", + ExpectError: true, + }, + { + Kubernetes: "", + ExpectError: true, + }, + } + + for _, test := range tests { + awsID, err := test.Kubernetes.mapToAWSInstanceID() + if err != nil { + if !test.ExpectError { + t.Errorf("unexpected error parsing %s: %v", test.Kubernetes, err) + } + } else { + if test.ExpectError { + t.Errorf("expected error parsing %s", test.Kubernetes) + } else if test.Aws != awsID { + t.Errorf("unexpected value parsing %s, got %s", test.Kubernetes, awsID) + } + } + } +} diff --git a/pkg/cloudprovider/providers/aws/volumes.go b/pkg/cloudprovider/providers/aws/volumes.go index c9d11ef8a7..8ff342199d 100644 --- a/pkg/cloudprovider/providers/aws/volumes.go +++ b/pkg/cloudprovider/providers/aws/volumes.go @@ -24,7 +24,7 @@ import ( "github.com/aws/aws-sdk-go/aws" ) -// awsVolumeID represents the ID of the volume in the AWS API, e.g. vol-12345678a +// awsVolumeID represents the ID of the volume in the AWS API, e.g. vol-12345678 // The "traditional" format is "vol-12345678" // A new longer format is also being introduced: "vol-12345678abcdef01" // We should not assume anything about the length or format, though it seems