Merge pull request #47395 from justinsb/followup_47215_2

Automatic merge from submit-queue

AWS cleanup

Rationalize the existing code.

```release-note
NONE
```

 Issue #47394
pull/6/head
Kubernetes Submit Queue 2017-06-13 08:50:05 -07:00 committed by GitHub
commit 4d31eca42d
6 changed files with 232 additions and 234 deletions

View File

@ -18,6 +18,7 @@ go_library(
"aws_routes.go",
"aws_utils.go",
"device_allocator.go",
"instances.go",
"log_handler.go",
"regions.go",
"retry_handler.go",
@ -58,6 +59,7 @@ go_test(
srcs = [
"aws_test.go",
"device_allocator_test.go",
"instances_test.go",
"regions_test.go",
"retry_handler_test.go",
"tags_test.go",

View File

@ -204,7 +204,6 @@ type Services interface {
type EC2 interface {
// Query EC2 for instances matching the filter
DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error)
DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error)
// Attach a volume to an instance
AttachVolume(*ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error)
@ -609,20 +608,6 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e
return results, nil
}
// Implementation of EC2.DescribeAddresses
func (s *awsSdkEC2) DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) {
requestTime := time.Now()
response, err := s.ec2.DescribeAddresses(request)
if err != nil {
recordAwsMetric("describe_address", 0, err)
return nil, fmt.Errorf("error listing AWS addresses: %v", err)
}
timeTaken := time.Since(requestTime).Seconds()
recordAwsMetric("describe_address", timeTaken, nil)
return response.Addresses, nil
}
// Implements EC2.DescribeSecurityGroups
func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) {
// Security groups are not paged
@ -996,38 +981,51 @@ func (c *Cloud) NodeAddresses(name types.NodeName) ([]v1.NodeAddress, error) {
return addresses, nil
}
instance, err := c.getInstanceByNodeName(name)
if err != nil {
return nil, fmt.Errorf("getInstanceByNodeName failed for %q with %v", name, err)
}
return extractNodeAddresses(instance)
}
// extractNodeAddresses maps the instance information from EC2 to an array of NodeAddresses
func extractNodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) {
// Not clear if the order matters here, but we might as well indicate a sensible preference order
if instance == nil {
return nil, fmt.Errorf("nil instance passed to extractNodeAddresses")
}
addresses := []v1.NodeAddress{}
if !isNilOrEmpty(instance.PrivateIpAddress) {
ipAddress := *instance.PrivateIpAddress
ip := net.ParseIP(ipAddress)
privateIPAddress := aws.StringValue(instance.PrivateIpAddress)
if privateIPAddress != "" {
ip := net.ParseIP(privateIPAddress)
if ip == nil {
return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%s)", orEmpty(instance.InstanceId), ipAddress)
return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%s)", aws.StringValue(instance.InstanceId), privateIPAddress)
}
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()})
}
// TODO: Other IP addresses (multiple ips)?
if !isNilOrEmpty(instance.PublicIpAddress) {
ipAddress := *instance.PublicIpAddress
ip := net.ParseIP(ipAddress)
publicIPAddress := aws.StringValue(instance.PublicIpAddress)
if publicIPAddress != "" {
ip := net.ParseIP(publicIPAddress)
if ip == nil {
return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", orEmpty(instance.InstanceId), ipAddress)
return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", aws.StringValue(instance.InstanceId), publicIPAddress)
}
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()})
}
if !isNilOrEmpty(instance.PrivateDnsName) {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: *instance.PrivateDnsName})
privateDNSName := aws.StringValue(instance.PrivateDnsName)
if privateDNSName != "" {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName})
}
if !isNilOrEmpty(instance.PublicDnsName) {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: *instance.PublicDnsName})
publicDNSName := aws.StringValue(instance.PublicDnsName)
if publicDNSName != "" {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName})
}
return addresses, nil
@ -1037,45 +1035,17 @@ func (c *Cloud) NodeAddresses(name types.NodeName) ([]v1.NodeAddress, error) {
// This method will not be called from the node that is requesting this ID. i.e. metadata service
// and other local methods cannot be used here
func (c *Cloud) NodeAddressesByProviderID(providerID string) ([]v1.NodeAddress, error) {
instanceID, error := instanceIDFromProviderID(providerID)
if error != nil {
return nil, error
instanceID, err := kubernetesInstanceID(providerID).mapToAWSInstanceID()
if err != nil {
return nil, err
}
addresses, error := c.describeAddressesByInstanceID(instanceID)
if error != nil {
return nil, error
instance, err := describeInstance(c.ec2, instanceID)
if err != nil {
return nil, err
}
instances, error := c.describeInstancesByInstanceID(instanceID)
if error != nil {
return nil, error
}
nodeAddresses := []v1.NodeAddress{}
for _, address := range addresses {
convertedAddress, error := convertAwsAddress(address)
if error != nil {
return nil, error
}
nodeAddresses = append(nodeAddresses, convertedAddress...)
}
for _, instance := range instances {
addresses, error := instanceAddresses(instance)
if error != nil {
return nil, error
}
nodeAddresses = append(nodeAddresses, addresses...)
}
return nodeAddresses, nil
return extractNodeAddresses(instance)
}
// ExternalID returns the cloud provider ID of the node with the specified nodeName (deprecated).
@ -1114,16 +1084,14 @@ func (c *Cloud) InstanceID(nodeName types.NodeName) (string, error) {
// This method will not be called from the node that is requesting this ID. i.e. metadata service
// and other local methods cannot be used here
func (c *Cloud) InstanceTypeByProviderID(providerID string) (string, error) {
instanceID, error := instanceIDFromProviderID(providerID)
if error != nil {
return "", error
instanceID, err := kubernetesInstanceID(providerID).mapToAWSInstanceID()
if err != nil {
return "", err
}
instance, error := c.describeInstanceByInstanceID(instanceID)
if error != nil {
return "", error
instance, err := describeInstance(c.ec2, instanceID)
if err != nil {
return "", err
}
return aws.StringValue(instance.InstanceType), nil
@ -1138,7 +1106,7 @@ func (c *Cloud) InstanceType(nodeName types.NodeName) (string, error) {
if err != nil {
return "", fmt.Errorf("getInstanceByNodeName failed for %q with %v", nodeName, err)
}
return orEmpty(inst.InstanceType), nil
return aws.StringValue(inst.InstanceType), nil
}
// Return a list of instances matching regex string.
@ -1307,22 +1275,7 @@ func (i *awsInstance) getInstanceType() *awsInstanceType {
// Gets the full information about this instance from the EC2 API
func (i *awsInstance) describeInstance() (*ec2.Instance, error) {
instanceID := i.awsID
request := &ec2.DescribeInstancesInput{
InstanceIds: []*string{&instanceID},
}
instances, err := i.ec2.DescribeInstances(request)
if err != nil {
return nil, err
}
if len(instances) == 0 {
return nil, fmt.Errorf("no instances found for instance: %s", i.awsID)
}
if len(instances) > 1 {
return nil, fmt.Errorf("multiple instances found for instance: %s", i.awsID)
}
return instances[0], nil
return describeInstance(i.ec2, awsInstanceID(i.awsID))
}
// Gets the mountDevice already assigned to the volume, or assigns an unused mountDevice.
@ -3407,25 +3360,6 @@ func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([
return instances, nil
}
func (c *Cloud) describeInstancesByInstanceID(instanceID string) ([]*ec2.Instance, error) {
filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)}
return c.describeInstances(filters)
}
func (c *Cloud) describeInstanceByInstanceID(instanceID string) (*ec2.Instance, error) {
filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)}
instances, err := c.describeInstances(filters)
if err != nil {
return nil, err
}
if len(instances) != 1 {
return nil, fmt.Errorf("expected 1 instance, found %d for instanceID %s", len(instances), instanceID)
}
return instances[0], nil
}
func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error) {
filters = c.tagging.addFilters(filters)
request := &ec2.DescribeInstancesInput{
@ -3446,21 +3380,6 @@ func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error
return matches, nil
}
func (c *Cloud) describeAddressesByInstanceID(instanceID string) ([]*ec2.Address, error) {
filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)}
params := &ec2.DescribeAddressesInput{
Filters: filters,
}
addresses, error := c.ec2.DescribeAddresses(params)
if error != nil {
return nil, error
}
return addresses, nil
}
// mapNodeNameToPrivateDNSName maps a k8s NodeName to an AWS Instance PrivateDNSName
// This is a simple string cast
func mapNodeNameToPrivateDNSName(nodeName types.NodeName) string {
@ -3518,78 +3437,6 @@ func (c *Cloud) getFullInstance(nodeName types.NodeName) (*awsInstance, *ec2.Ins
return awsInstance, instance, err
}
func instanceAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) {
addresses := []v1.NodeAddress{}
privateDNSName := aws.StringValue(instance.PrivateDnsName)
unsafePrivateIP := aws.StringValue(instance.PrivateIpAddress)
publicDNSName := aws.StringValue(instance.PublicDnsName)
unsafePublicIP := aws.StringValue(instance.PublicIpAddress)
if privateDNSName != "" {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName})
}
if unsafePrivateIP != "" {
ip := net.ParseIP(unsafePrivateIP)
if ip != nil {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()})
} else {
return nil, fmt.Errorf("EC2 address had invalid private IP: %s", unsafePrivateIP)
}
}
if publicDNSName != "" {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName})
}
if unsafePublicIP != "" {
ip := net.ParseIP(unsafePublicIP)
if ip != nil {
addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()})
} else {
return nil, fmt.Errorf("EC2 address had invalid public IP: %s", unsafePublicIP)
}
}
return addresses, nil
}
func convertAwsAddress(address *ec2.Address) ([]v1.NodeAddress, error) {
nodeAddresses := []v1.NodeAddress{}
if aws.StringValue(address.PrivateIpAddress) != "" {
unsafeIP := *address.PrivateIpAddress
ip := net.ParseIP(unsafeIP)
if ip != nil {
nodeAddresses = append(nodeAddresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()})
} else {
return nil, fmt.Errorf("EC2 address had invalid private IP: %s", unsafeIP)
}
}
if aws.StringValue(address.PublicIp) != "" {
unsafeIP := *address.PublicIp
ip := net.ParseIP(unsafeIP)
if ip != nil {
nodeAddresses = append(nodeAddresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()})
} else {
return nil, fmt.Errorf("EC2 address had invalid public IP: %s", unsafeIP)
}
}
return nodeAddresses, nil
}
var providerIDRegexp = regexp.MustCompile(`^aws://([^/]+)$`)
func instanceIDFromProviderID(providerID string) (instanceID string, err error) {
matches := providerIDRegexp.FindStringSubmatch(providerID)
if len(matches) != 2 {
return "", fmt.Errorf("ProviderID \"%s\" didn't match expected format \"aws://InstanceID\"", providerID)
}
return matches[1], nil
}
func setNodeDisk(
nodeDiskMap map[types.NodeName]map[KubernetesVolumeID]bool,
volumeID KubernetesVolumeID,

View File

@ -322,12 +322,6 @@ func (self *FakeEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*
return matches, nil
}
func (self *FakeEC2) DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) {
addresses := []*ec2.Address{}
return addresses, nil
}
type FakeMetadata struct {
aws *FakeAWSServices
}
@ -1356,37 +1350,3 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) {
}
}
}
func TestInstanceIDFromProviderID(t *testing.T) {
testCases := []struct {
providerID string
instanceID string
fail bool
}{
{
providerID: "aws://i-0194bbdb81a49b169",
instanceID: "i-0194bbdb81a49b169",
fail: false,
},
{
providerID: "i-0194bbdb81a49b169",
instanceID: "",
fail: true,
},
}
for _, test := range testCases {
instanceID, err := instanceIDFromProviderID(test.providerID)
if (err != nil) != test.fail {
t.Errorf("%s yielded `err != nil` as %t. expected %t", test.providerID, (err != nil), test.fail)
}
if test.fail {
continue
}
if instanceID != test.instanceID {
t.Errorf("%s yielded %s. expected %s", test.providerID, instanceID, test.instanceID)
}
}
}

View File

@ -0,0 +1,100 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
"fmt"
"net/url"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
)
// awsInstanceID represents the ID of the instance in the AWS API, e.g. i-12345678
// The "traditional" format is "i-12345678"
// A new longer format is also being introduced: "i-12345678abcdef01"
// We should not assume anything about the length or format, though it seems
// reasonable to assume that instances will continue to start with "i-".
type awsInstanceID string
func (i awsInstanceID) awsString() *string {
return aws.String(string(i))
}
// kubernetesInstanceID represents the id for an instance in the kubernetes API;
// the following form
// * aws:///<zone>/<awsInstanceId>
// * aws:////<awsInstanceId>
// * <awsInstanceId>
type kubernetesInstanceID string
// mapToAWSInstanceID extracts the awsInstanceID from the kubernetesInstanceID
func (name kubernetesInstanceID) mapToAWSInstanceID() (awsInstanceID, error) {
s := string(name)
if !strings.HasPrefix(s, "aws://") {
// Assume a bare aws volume id (vol-1234...)
// Build a URL with an empty host (AZ)
s = "aws://" + "/" + "/" + s
}
url, err := url.Parse(s)
if err != nil {
return "", fmt.Errorf("Invalid instance name (%s): %v", name, err)
}
if url.Scheme != "aws" {
return "", fmt.Errorf("Invalid scheme for AWS instance (%s)", name)
}
awsID := ""
tokens := strings.Split(strings.Trim(url.Path, "/"), "/")
if len(tokens) == 1 {
// instanceId
awsID = tokens[0]
} else if len(tokens) == 2 {
// az/instanceId
awsID = tokens[1]
}
// We sanity check the resulting volume; the two known formats are
// i-12345678 and i-12345678abcdef01
// TODO: Regex match?
if awsID == "" || strings.Contains(awsID, "/") || !strings.HasPrefix(awsID, "i-") {
return "", fmt.Errorf("Invalid format for AWS instance (%s)", name)
}
return awsInstanceID(awsID), nil
}
// Gets the full information about this instance from the EC2 API
func describeInstance(ec2Client EC2, instanceID awsInstanceID) (*ec2.Instance, error) {
request := &ec2.DescribeInstancesInput{
InstanceIds: []*string{instanceID.awsString()},
}
instances, err := ec2Client.DescribeInstances(request)
if err != nil {
return nil, err
}
if len(instances) == 0 {
return nil, fmt.Errorf("no instances found for instance: %s", instanceID)
}
if len(instances) > 1 {
return nil, fmt.Errorf("multiple instances found for instance: %s", instanceID)
}
return instances[0], nil
}

View File

@ -0,0 +1,89 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
"testing"
)
func TestParseInstance(t *testing.T) {
tests := []struct {
Kubernetes kubernetesInstanceID
Aws awsInstanceID
ExpectError bool
}{
{
Kubernetes: "aws:///us-east-1a/i-12345678",
Aws: "i-12345678",
},
{
Kubernetes: "aws:////i-12345678",
Aws: "i-12345678",
},
{
Kubernetes: "i-12345678",
Aws: "i-12345678",
},
{
Kubernetes: "aws:///us-east-1a/i-12345678abcdef01",
Aws: "i-12345678abcdef01",
},
{
Kubernetes: "aws:////i-12345678abcdef01",
Aws: "i-12345678abcdef01",
},
{
Kubernetes: "i-12345678abcdef01",
Aws: "i-12345678abcdef01",
},
{
Kubernetes: "vol-123456789",
ExpectError: true,
},
{
Kubernetes: "aws:///us-east-1a/vol-12345678abcdef01",
ExpectError: true,
},
{
Kubernetes: "aws://accountid/us-east-1a/vol-12345678abcdef01",
ExpectError: true,
},
{
Kubernetes: "aws:///us-east-1a/vol-12345678abcdef01/suffix",
ExpectError: true,
},
{
Kubernetes: "",
ExpectError: true,
},
}
for _, test := range tests {
awsID, err := test.Kubernetes.mapToAWSInstanceID()
if err != nil {
if !test.ExpectError {
t.Errorf("unexpected error parsing %s: %v", test.Kubernetes, err)
}
} else {
if test.ExpectError {
t.Errorf("expected error parsing %s", test.Kubernetes)
} else if test.Aws != awsID {
t.Errorf("unexpected value parsing %s, got %s", test.Kubernetes, awsID)
}
}
}
}

View File

@ -24,7 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
)
// awsVolumeID represents the ID of the volume in the AWS API, e.g. vol-12345678a
// awsVolumeID represents the ID of the volume in the AWS API, e.g. vol-12345678
// The "traditional" format is "vol-12345678"
// A new longer format is also being introduced: "vol-12345678abcdef01"
// We should not assume anything about the length or format, though it seems