diff --git a/pkg/cloudprovider/providers/azure/BUILD b/pkg/cloudprovider/providers/azure/BUILD index ac0f0f8697..8cb737eed7 100644 --- a/pkg/cloudprovider/providers/azure/BUILD +++ b/pkg/cloudprovider/providers/azure/BUILD @@ -68,6 +68,7 @@ go_test( "azure_cache_test.go", "azure_loadbalancer_test.go", "azure_metrics_test.go", + "azure_routes_test.go", "azure_standard_test.go", "azure_storage_test.go", "azure_storageaccount_test.go", @@ -79,6 +80,7 @@ go_test( importpath = "k8s.io/kubernetes/pkg/cloudprovider/providers/azure", deps = [ "//pkg/api/v1/service:go_default_library", + "//pkg/cloudprovider:go_default_library", "//pkg/cloudprovider/providers/azure/auth:go_default_library", "//pkg/kubelet/apis:go_default_library", "//vendor/github.com/Azure/azure-sdk-for-go/arm/compute:go_default_library", diff --git a/pkg/cloudprovider/providers/azure/azure_routes.go b/pkg/cloudprovider/providers/azure/azure_routes.go index eef61003ff..162bd007c9 100644 --- a/pkg/cloudprovider/providers/azure/azure_routes.go +++ b/pkg/cloudprovider/providers/azure/azure_routes.go @@ -28,18 +28,23 @@ import ( ) // ListRoutes lists all managed routes that belong to the specified clusterName -func (az *Cloud) ListRoutes(clusterName string) (routes []*cloudprovider.Route, err error) { +func (az *Cloud) ListRoutes(clusterName string) ([]*cloudprovider.Route, error) { glog.V(10).Infof("list: START clusterName=%q", clusterName) routeTable, existsRouteTable, err := az.getRouteTable() + return processRoutes(routeTable, existsRouteTable, err) +} + +// Injectable for testing +func processRoutes(routeTable network.RouteTable, exists bool, err error) ([]*cloudprovider.Route, error) { if err != nil { return nil, err } - if !existsRouteTable { + if !exists { return []*cloudprovider.Route{}, nil } var kubeRoutes []*cloudprovider.Route - if routeTable.Routes != nil { + if routeTable.RouteTablePropertiesFormat != nil && routeTable.Routes != nil { kubeRoutes = make([]*cloudprovider.Route, len(*routeTable.Routes)) for i, route := range *routeTable.Routes { instance := mapRouteNameToNodeName(*route.Name) @@ -58,49 +63,54 @@ func (az *Cloud) ListRoutes(clusterName string) (routes []*cloudprovider.Route, return kubeRoutes, nil } +func (az *Cloud) createRouteTableIfNotExists(clusterName string, kubeRoute *cloudprovider.Route) error { + if _, existsRouteTable, err := az.getRouteTable(); err != nil { + glog.V(2).Infof("create error: couldn't get routetable. clusterName=%q instance=%q cidr=%q", clusterName, kubeRoute.TargetNode, kubeRoute.DestinationCIDR) + return err + } else if existsRouteTable { + return nil + } + return az.createRouteTable() +} + +func (az *Cloud) createRouteTable() error { + routeTable := network.RouteTable{ + Name: to.StringPtr(az.RouteTableName), + Location: to.StringPtr(az.Location), + RouteTablePropertiesFormat: &network.RouteTablePropertiesFormat{}, + } + + glog.V(3).Infof("create: creating routetable. routeTableName=%q", az.RouteTableName) + respChan, errChan := az.RouteTablesClient.CreateOrUpdate(az.ResourceGroup, az.RouteTableName, routeTable, nil) + resp := <-respChan + err := <-errChan + glog.V(10).Infof("RouteTablesClient.CreateOrUpdate(%q): end", az.RouteTableName) + if az.CloudProviderBackoff && shouldRetryAPIRequest(resp.Response, err) { + glog.V(2).Infof("create backing off: creating routetable. routeTableName=%q", az.RouteTableName) + retryErr := az.CreateOrUpdateRouteTableWithRetry(routeTable) + if retryErr != nil { + err = retryErr + glog.V(2).Infof("create abort backoff: creating routetable. routeTableName=%q", az.RouteTableName) + } + } + if err != nil { + return err + } + + glog.V(10).Infof("RouteTablesClient.Get(%q): start", az.RouteTableName) + _, err = az.RouteTablesClient.Get(az.ResourceGroup, az.RouteTableName, "") + glog.V(10).Infof("RouteTablesClient.Get(%q): end", az.RouteTableName) + return err +} + // CreateRoute creates the described managed route // route.Name will be ignored, although the cloud-provider may use nameHint // to create a more user-meaningful name. func (az *Cloud) CreateRoute(clusterName string, nameHint string, kubeRoute *cloudprovider.Route) error { glog.V(2).Infof("create: creating route. clusterName=%q instance=%q cidr=%q", clusterName, kubeRoute.TargetNode, kubeRoute.DestinationCIDR) - - routeTable, existsRouteTable, err := az.getRouteTable() - if err != nil { - glog.V(2).Infof("create error: couldn't get routetable. clusterName=%q instance=%q cidr=%q", clusterName, kubeRoute.TargetNode, kubeRoute.DestinationCIDR) + if err := az.createRouteTableIfNotExists(clusterName, kubeRoute); err != nil { return err } - if !existsRouteTable { - routeTable = network.RouteTable{ - Name: to.StringPtr(az.RouteTableName), - Location: to.StringPtr(az.Location), - RouteTablePropertiesFormat: &network.RouteTablePropertiesFormat{}, - } - - glog.V(3).Infof("create: creating routetable. routeTableName=%q", az.RouteTableName) - respChan, errChan := az.RouteTablesClient.CreateOrUpdate(az.ResourceGroup, az.RouteTableName, routeTable, nil) - resp := <-respChan - err := <-errChan - glog.V(10).Infof("RouteTablesClient.CreateOrUpdate(%q): end", az.RouteTableName) - if az.CloudProviderBackoff && shouldRetryAPIRequest(resp.Response, err) { - glog.V(2).Infof("create backing off: creating routetable. routeTableName=%q", az.RouteTableName) - retryErr := az.CreateOrUpdateRouteTableWithRetry(routeTable) - if retryErr != nil { - err = retryErr - glog.V(2).Infof("create abort backoff: creating routetable. routeTableName=%q", az.RouteTableName) - } - } - if err != nil { - return err - } - - glog.V(10).Infof("RouteTablesClient.Get(%q): start", az.RouteTableName) - routeTable, err = az.RouteTablesClient.Get(az.ResourceGroup, az.RouteTableName, "") - glog.V(10).Infof("RouteTablesClient.Get(%q): end", az.RouteTableName) - if err != nil { - return err - } - } - targetIP, err := az.getIPForMachine(kubeRoute.TargetNode) if err != nil { return err diff --git a/pkg/cloudprovider/providers/azure/azure_routes_test.go b/pkg/cloudprovider/providers/azure/azure_routes_test.go new file mode 100644 index 0000000000..02041f0fa2 --- /dev/null +++ b/pkg/cloudprovider/providers/azure/azure_routes_test.go @@ -0,0 +1,175 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "fmt" + "reflect" + "testing" + + "k8s.io/kubernetes/pkg/cloudprovider" + + "github.com/Azure/azure-sdk-for-go/arm/network" + "github.com/Azure/go-autorest/autorest/to" +) + +func TestCreateRoute(t *testing.T) { + fake := newFakeRouteTablesClient() + cloud := &Cloud{ + RouteTablesClient: fake, + Config: Config{ + ResourceGroup: "foo", + RouteTableName: "bar", + Location: "location", + }, + } + expectedTable := network.RouteTable{ + Name: &cloud.RouteTableName, + Location: &cloud.Location, + } + + err := cloud.createRouteTable() + if err != nil { + t.Errorf("unexpected error in creating route table: %v", err) + t.FailNow() + } + + table := fake.FakeStore["foo"]["bar"] + if *table.Location != *expectedTable.Location { + t.Errorf("mismatch: %s vs %s", *table.Location, *expectedTable.Location) + } + if *table.Name != *expectedTable.Name { + t.Errorf("mismatch: %s vs %s", *table.Name, *expectedTable.Name) + } +} + +func TestProcessRoutes(t *testing.T) { + tests := []struct { + rt network.RouteTable + exists bool + err error + expectErr bool + expectedError string + expectedRoute []cloudprovider.Route + name string + }{ + { + err: fmt.Errorf("test error"), + expectErr: true, + expectedError: "test error", + }, + { + exists: false, + name: "doesn't exist", + }, + { + rt: network.RouteTable{}, + exists: true, + name: "nil routes", + }, + { + rt: network.RouteTable{ + RouteTablePropertiesFormat: &network.RouteTablePropertiesFormat{}, + }, + exists: true, + name: "no routes", + }, + { + rt: network.RouteTable{ + RouteTablePropertiesFormat: &network.RouteTablePropertiesFormat{ + Routes: &[]network.Route{ + { + Name: to.StringPtr("name"), + RoutePropertiesFormat: &network.RoutePropertiesFormat{ + AddressPrefix: to.StringPtr("1.2.3.4/16"), + }, + }, + }, + }, + }, + exists: true, + expectedRoute: []cloudprovider.Route{ + { + Name: "name", + TargetNode: mapRouteNameToNodeName("name"), + DestinationCIDR: "1.2.3.4/16", + }, + }, + name: "one route", + }, + { + rt: network.RouteTable{ + RouteTablePropertiesFormat: &network.RouteTablePropertiesFormat{ + Routes: &[]network.Route{ + { + Name: to.StringPtr("name"), + RoutePropertiesFormat: &network.RoutePropertiesFormat{ + AddressPrefix: to.StringPtr("1.2.3.4/16"), + }, + }, + { + Name: to.StringPtr("name2"), + RoutePropertiesFormat: &network.RoutePropertiesFormat{ + AddressPrefix: to.StringPtr("5.6.7.8/16"), + }, + }, + }, + }, + }, + exists: true, + expectedRoute: []cloudprovider.Route{ + { + Name: "name", + TargetNode: mapRouteNameToNodeName("name"), + DestinationCIDR: "1.2.3.4/16", + }, + { + Name: "name2", + TargetNode: mapRouteNameToNodeName("name2"), + DestinationCIDR: "5.6.7.8/16", + }, + }, + name: "more routes", + }, + } + for _, test := range tests { + routes, err := processRoutes(test.rt, test.exists, test.err) + if test.expectErr { + if err == nil { + t.Errorf("%s: unexpected non-error", test.name) + continue + } + if err.Error() != test.expectedError { + t.Errorf("%s: Expected error: %v, saw error: %v", test.name, test.expectedError, err.Error()) + continue + } + } + if !test.expectErr && err != nil { + t.Errorf("%s; unexpected error: %v", test.name, err) + continue + } + if len(routes) != len(test.expectedRoute) { + t.Errorf("%s: Unexpected difference: %#v vs %#v", test.name, routes, test.expectedRoute) + continue + } + for ix := range test.expectedRoute { + if !reflect.DeepEqual(test.expectedRoute[ix], *routes[ix]) { + t.Errorf("%s: Unexpected difference: %#v vs %#v", test.name, test.expectedRoute[ix], *routes[ix]) + } + } + } +} diff --git a/pkg/cloudprovider/providers/azure/azure_test.go b/pkg/cloudprovider/providers/azure/azure_test.go index 5fab9e965b..04f4067519 100644 --- a/pkg/cloudprovider/providers/azure/azure_test.go +++ b/pkg/cloudprovider/providers/azure/azure_test.go @@ -19,6 +19,8 @@ package azure import ( "fmt" "math" + "net/http" + "net/http/httptest" "strings" "testing" @@ -1560,19 +1562,57 @@ func validateEmptyConfig(t *testing.T, config string) { if azureCloud.CloudProviderBackoff != false { t.Errorf("got incorrect value for CloudProviderBackoff") } - // rate limits should be disabled by default if not explicitly enabled in config if azureCloud.CloudProviderRateLimit != false { t.Errorf("got incorrect value for CloudProviderRateLimit") } } +func TestGetZone(t *testing.T) { + data := `{"ID":"_azdev","UD":"0","FD":"99"}` + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, data) + })) + defer ts.Close() + + cloud := &Cloud{} + cloud.Location = "eastus" + + zone, err := cloud.getZoneFromURL(ts.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if zone.FailureDomain != "99" { + t.Errorf("Unexpected value: %s, expected '99'", zone.FailureDomain) + } + if zone.Region != cloud.Location { + t.Errorf("Expected: %s, saw: %s", cloud.Location, zone.Region) + } +} + +func TestFetchFaultDomain(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"ID":"_azdev","UD":"0","FD":"99"}`) + })) + defer ts.Close() + + faultDomain, err := fetchFaultDomain(ts.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if faultDomain == nil { + t.Errorf("Unexpected nil fault domain") + } + if *faultDomain != "99" { + t.Errorf("Expected '99', saw '%s'", *faultDomain) + } +} func TestDecodeInstanceInfo(t *testing.T) { response := `{"ID":"_azdev","UD":"0","FD":"99"}` faultDomain, err := readFaultDomain(strings.NewReader(response)) if err != nil { - t.Error("Unexpected error in ReadFaultDomain") + t.Errorf("Unexpected error in ReadFaultDomain: %v", err) } if faultDomain == nil { diff --git a/pkg/cloudprovider/providers/azure/azure_zones.go b/pkg/cloudprovider/providers/azure/azure_zones.go index 75d0c41251..91df85a036 100644 --- a/pkg/cloudprovider/providers/azure/azure_zones.go +++ b/pkg/cloudprovider/providers/azure/azure_zones.go @@ -40,11 +40,16 @@ type instanceInfo struct { // GetZone returns the Zone containing the current failure zone and locality region that the program is running in func (az *Cloud) GetZone() (cloudprovider.Zone, error) { + return az.getZoneFromURL(instanceInfoURL) +} + +// This is injectable for testing. +func (az *Cloud) getZoneFromURL(url string) (cloudprovider.Zone, error) { faultMutex.Lock() defer faultMutex.Unlock() if faultDomain == nil { var err error - faultDomain, err = fetchFaultDomain() + faultDomain, err = fetchFaultDomain(url) if err != nil { return cloudprovider.Zone{}, err } @@ -75,8 +80,8 @@ func (az *Cloud) GetZoneByNodeName(nodeName types.NodeName) (cloudprovider.Zone, return az.vmSet.GetZoneByNodeName(string(nodeName)) } -func fetchFaultDomain() (*string, error) { - resp, err := http.Get(instanceInfoURL) +func fetchFaultDomain(url string) (*string, error) { + resp, err := http.Get(url) if err != nil { return nil, err }