diff --git a/pkg/cloudprovider/providers/azure/azure_fakes.go b/pkg/cloudprovider/providers/azure/azure_fakes.go index 54d9d63965..67d0b2be27 100644 --- a/pkg/cloudprovider/providers/azure/azure_fakes.go +++ b/pkg/cloudprovider/providers/azure/azure_fakes.go @@ -893,3 +893,7 @@ func (f *fakeVMSet) DetachDiskByName(diskName, diskURI string, nodeName types.No func (f *fakeVMSet) GetDataDisks(nodeName types.NodeName) ([]compute.DataDisk, error) { return nil, fmt.Errorf("unimplemented") } + +func (f *fakeVMSet) GetProvisioningStateByNodeName(name string) (string, error) { + return "", fmt.Errorf("unimplemented") +} diff --git a/pkg/cloudprovider/providers/azure/azure_instances.go b/pkg/cloudprovider/providers/azure/azure_instances.go index 7eb403a7b2..4f609fa72f 100644 --- a/pkg/cloudprovider/providers/azure/azure_instances.go +++ b/pkg/cloudprovider/providers/azure/azure_instances.go @@ -143,7 +143,17 @@ func (az *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID stri // InstanceShutdownByProviderID returns true if the instance is in safe state to detach volumes func (az *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID string) (bool, error) { - return false, cloudprovider.NotImplemented + nodeName, err := az.vmSet.GetNodeNameByProviderID(providerID) + if err != nil { + return false, err + } + + provisioningState, err := az.vmSet.GetProvisioningStateByNodeName(string(nodeName)) + if err != nil { + return false, err + } + + return strings.ToLower(provisioningState) == "stopped" || strings.ToLower(provisioningState) == "deallocated", nil } // getComputeMetadata gets compute information from instance metadata. diff --git a/pkg/cloudprovider/providers/azure/azure_standard.go b/pkg/cloudprovider/providers/azure/azure_standard.go index f5b39274c2..d75c7fdc8c 100644 --- a/pkg/cloudprovider/providers/azure/azure_standard.go +++ b/pkg/cloudprovider/providers/azure/azure_standard.go @@ -346,6 +346,15 @@ func (as *availabilitySet) GetInstanceIDByNodeName(name string) (string, error) return *machine.ID, nil } +func (as *availabilitySet) GetProvisioningStateByNodeName(name string) (provisioningState string, err error) { + vm, err := as.getVirtualMachine(types.NodeName(name)) + if err != nil { + return provisioningState, err + } + + return *vm.ProvisioningState, nil +} + // GetNodeNameByProviderID gets the node name by provider ID. func (as *availabilitySet) GetNodeNameByProviderID(providerID string) (types.NodeName, error) { // NodeName is part of providerID for standard instances. diff --git a/pkg/cloudprovider/providers/azure/azure_vmsets.go b/pkg/cloudprovider/providers/azure/azure_vmsets.go index 35f8d0cf44..4752d3321f 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmsets.go +++ b/pkg/cloudprovider/providers/azure/azure_vmsets.go @@ -64,4 +64,7 @@ type VMSet interface { DetachDiskByName(diskName, diskURI string, nodeName types.NodeName) error // GetDataDisks gets a list of data disks attached to the node. GetDataDisks(nodeName types.NodeName) ([]compute.DataDisk, error) + + // GetProvisioningStateByNodeName gets the provisioning state by node name. + GetProvisioningStateByNodeName(name string) (string, error) } diff --git a/pkg/cloudprovider/providers/azure/azure_vmss.go b/pkg/cloudprovider/providers/azure/azure_vmss.go index 2fe598e0cc..d2d1a3b82c 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss.go @@ -128,6 +128,15 @@ func (ss *scaleSet) getVmssVM(nodeName string) (ssName, instanceID string, vm co return ssName, instanceID, *(cachedVM.(*compute.VirtualMachineScaleSetVM)), nil } +func (ss *scaleSet) GetProvisioningStateByNodeName(name string) (provisioningState string, err error) { + _, _, vm, err := ss.getVmssVM(name) + if err != nil { + return provisioningState, err + } + + return *vm.ProvisioningState, nil +} + // getCachedVirtualMachineByInstanceID gets scaleSetVMInfo from cache. // The node must belong to one of scale sets. func (ss *scaleSet) getVmssVMByInstanceID(resourceGroup, scaleSetName, instanceID string) (vm compute.VirtualMachineScaleSetVM, err error) {