From 3eab40cabd6fb24f50a9c3ad0c35980541f5aab6 Mon Sep 17 00:00:00 2001 From: Justin Santa Barbara Date: Mon, 9 Mar 2015 21:15:53 -0700 Subject: [PATCH] AWS support for Zones --- pkg/cloudprovider/aws/aws.go | 56 ++++++++++++++++++++++--- pkg/cloudprovider/aws/aws_test.go | 69 ++++++++++++++++++++++++------- 2 files changed, 103 insertions(+), 22 deletions(-) diff --git a/pkg/cloudprovider/aws/aws.go b/pkg/cloudprovider/aws/aws.go index 93363c17cd..224d124b1e 100644 --- a/pkg/cloudprovider/aws/aws.go +++ b/pkg/cloudprovider/aws/aws.go @@ -37,16 +37,22 @@ import ( type EC2 interface { // Query EC2 for instances matching the filter Instances(instIds []string, filter *ec2InstanceFilter) (resp *ec2.InstancesResp, err error) + + // Query the EC2 metadata service (used to discover instance-id etc) + GetMetaData(key string) ([]byte, error) } // AWSCloud is an implementation of Interface, TCPLoadBalancer and Instances for Amazon Web Services. type AWSCloud struct { - ec2 EC2 - cfg *AWSCloudConfig + ec2 EC2 + cfg *AWSCloudConfig + availabilityZone string + region aws.Region } type AWSCloudConfig struct { Global struct { + // TODO: Is there any use for this? We can get it from the instance metadata service Region string } } @@ -82,6 +88,14 @@ func (self *GoamzEC2) Instances(instanceIds []string, filter *ec2InstanceFilter) return self.ec2.Instances(instanceIds, goamzFilter) } +func (self *GoamzEC2) GetMetaData(key string) ([]byte, error) { + v, err := aws.GetMetaData(key) + if err != nil { + return nil, fmt.Errorf("Error querying AWS metadata for key %s: %v", key, err) + } + return v, nil +} + type AuthFunc func() (auth aws.Auth, err error) func init() { @@ -125,18 +139,36 @@ func newAWSCloud(config io.Reader, authFunc AuthFunc) (*AWSCloud, error) { return nil, err } + // TODO: We can get the region very easily from the instance-metadata service region, ok := aws.Regions[cfg.Global.Region] if !ok { return nil, fmt.Errorf("not a valid AWS region: %s", cfg.Global.Region) } - ec2 := ec2.New(auth, region) return &AWSCloud{ - ec2: &GoamzEC2{ec2: ec2}, - cfg: cfg, + ec2: &GoamzEC2{ec2: ec2.New(auth, region)}, + cfg: cfg, + region: region, }, nil } +func (self *AWSCloud) getAvailabilityZone() (string, error) { + // TODO: Do we need sync.Mutex here? + availabilityZone := self.availabilityZone + if self.availabilityZone == "" { + availabilityZoneBytes, err := self.ec2.GetMetaData("placement/availability-zone") + if err != nil { + return "", err + } + if availabilityZoneBytes == nil || len(availabilityZoneBytes) == 0 { + return "", fmt.Errorf("Unable to determine availability-zone from instance metadata") + } + availabilityZone = string(availabilityZoneBytes) + self.availabilityZone = availabilityZone + } + return availabilityZone, nil +} + func (aws *AWSCloud) Clusters() (cloudprovider.Clusters, bool) { return nil, false } @@ -153,7 +185,7 @@ func (aws *AWSCloud) Instances() (cloudprovider.Instances, bool) { // Zones returns an implementation of Zones for Amazon Web Services. func (aws *AWSCloud) Zones() (cloudprovider.Zones, bool) { - return nil, false + return aws, true } // IPAddress is an implementation of Instances.IPAddress. @@ -246,3 +278,15 @@ func (aws *AWSCloud) List(filter string) ([]string, error) { func (v *AWSCloud) GetNodeResources(name string) (*api.NodeResources, error) { return nil, nil } + +// GetZone implements Zones.GetZone +func (self *AWSCloud) GetZone() (cloudprovider.Zone, error) { + availabilityZone, err := self.getAvailabilityZone() + if err != nil { + return cloudprovider.Zone{}, err + } + return cloudprovider.Zone{ + FailureDomain: availabilityZone, + Region: self.region.Name, + }, nil +} diff --git a/pkg/cloudprovider/aws/aws_test.go b/pkg/cloudprovider/aws/aws_test.go index d08f99ab71..2b1012e644 100644 --- a/pkg/cloudprovider/aws/aws_test.go +++ b/pkg/cloudprovider/aws/aws_test.go @@ -76,28 +76,47 @@ func TestNewAWSCloud(t *testing.T) { } type FakeEC2 struct { - instances func(instanceIds []string, filter *ec2InstanceFilter) (resp *ec2.InstancesResp, err error) + instances []ec2.Instance + availabilityZone string } -func (ec2 *FakeEC2) Instances(instanceIds []string, filter *ec2InstanceFilter) (resp *ec2.InstancesResp, err error) { - return ec2.instances(instanceIds, filter) +func (self *FakeEC2) Instances(instanceIds []string, filter *ec2InstanceFilter) (resp *ec2.InstancesResp, err error) { + matches := []ec2.Instance{} + for _, instance := range self.instances { + if filter == nil || filter.Matches(instance) { + matches = append(matches, instance) + } + } + return &ec2.InstancesResp{"", + []ec2.Reservation{ + {"", "", "", nil, matches}}}, nil +} + +func (self *FakeEC2) GetMetaData(key string) ([]byte, error) { + if key == "placement/availability-zone" { + return []byte(self.availabilityZone), nil + } else { + return nil, nil + } } func mockInstancesResp(instances []ec2.Instance) (aws *AWSCloud) { + availabilityZone := "us-west-2d" return &AWSCloud{ - &FakeEC2{ - func(instanceIds []string, filter *ec2InstanceFilter) (resp *ec2.InstancesResp, err error) { - matches := []ec2.Instance{} - for _, instance := range instances { - if filter == nil || filter.Matches(instance) { - matches = append(matches, instance) - } - } - return &ec2.InstancesResp{"", - []ec2.Reservation{ - {"", "", "", nil, matches}}}, nil - }}, - nil} + ec2: &FakeEC2{ + instances: instances, + availabilityZone: availabilityZone, + }, + } +} + +func mockAvailabilityZone(region string, availabilityZone string) *AWSCloud { + return &AWSCloud{ + ec2: &FakeEC2{ + availabilityZone: availabilityZone, + }, + region: aws.Regions[region], + } } func TestList(t *testing.T) { @@ -163,3 +182,21 @@ func TestIPAddress(t *testing.T) { t.Errorf("Expected %v, got %v", e, a) } } + +func TestGetRegion(t *testing.T) { + aws := mockAvailabilityZone("us-west-2", "us-west-2e") + zones, ok := aws.Zones() + if !ok { + t.Fatalf("Unexpected missing zones impl") + } + zone, err := zones.GetZone() + if err != nil { + t.Fatalf("unexpected error %v", err) + } + if zone.Region != "us-west-2" { + t.Errorf("Unexpected region: %s", zone.Region) + } + if zone.FailureDomain != "us-west-2e" { + t.Errorf("Unexpected FailureDomain: %s", zone.FailureDomain) + } +}