diff --git a/pkg/cloudprovider/providers/azure/azure_vmss_test.go b/pkg/cloudprovider/providers/azure/azure_vmss_test.go index 7830eab783..08afdd5690 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss_test.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss_test.go @@ -24,12 +24,15 @@ import ( "github.com/stretchr/testify/assert" ) -func newTestScaleSet() *scaleSet { - ss := newScaleSet(getTestCloud()) +func newTestScaleSet(scaleSetName string, vmList []string) *scaleSet { + cloud := getTestCloud() + setTestVirtualMachineCloud(cloud, scaleSetName, vmList) + ss := newScaleSet(cloud) + return ss.(*scaleSet) } -func setTestVirtualMachineScaleSets(ss *scaleSet, scaleSetName string, vmList []string) { +func setTestVirtualMachineCloud(ss *Cloud, scaleSetName string, vmList []string) { virtualMachineScaleSetsClient := newFakeVirtualMachineScaleSetsClient() scaleSets := make(map[string]map[string]compute.VirtualMachineScaleSet) scaleSets["rg"] = map[string]compute.VirtualMachineScaleSet{ @@ -63,13 +66,13 @@ func setTestVirtualMachineScaleSets(ss *scaleSet, scaleSetName string, vmList [] }, ID: &ID, InstanceID: &instanceID, - Location: &ss.Cloud.Location, + Location: &ss.Location, } } virtualMachineScaleSetVMsClient.setFakeStore(ssVMs) - ss.Cloud.VirtualMachineScaleSetsClient = virtualMachineScaleSetsClient - ss.Cloud.VirtualMachineScaleSetVMsClient = virtualMachineScaleSetVMsClient + ss.VirtualMachineScaleSetsClient = virtualMachineScaleSetsClient + ss.VirtualMachineScaleSetVMsClient = virtualMachineScaleSetVMsClient } func TestGetScaleSetVMInstanceID(t *testing.T) { @@ -102,8 +105,6 @@ func TestGetScaleSetVMInstanceID(t *testing.T) { } func TestGetInstanceIDByNodeName(t *testing.T) { - ss := newTestScaleSet() - testCases := []struct { description string scaleSet string @@ -136,7 +137,8 @@ func TestGetInstanceIDByNodeName(t *testing.T) { } for _, test := range testCases { - setTestVirtualMachineScaleSets(ss, test.scaleSet, test.vmList) + ss := newTestScaleSet(test.scaleSet, test.vmList) + real, err := ss.GetInstanceIDByNodeName(test.nodeName) if test.expectError { assert.Error(t, err, test.description)