mirror of https://github.com/hpcaitech/ColossalAI
64 lines · 1.4 KiB · Python
#!/usr/bin/env python

from typing import Union

import torch

from .base_accelerator import BaseAccelerator

try:
    import torch_npu  # noqa
except ImportError:
    pass


__all__ = ["NpuAccelerator"]


class NpuAccelerator(BaseAccelerator):
    """
    Accelerator class for Huawei NPU devices.
    """

    def __init__(self):
        super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)

    # =======================
    # device APIs
    # =======================
    def current_device(self) -> int:
        """
        Return the current device index.
        """
        return torch.npu.current_device()

    def set_device(self, device: Union[torch.device, int]) -> None:
        """
        Bind the current process to a device.
        """
        torch.npu.set_device(device)

    def get_device_name(self, device: Union[torch.device, int]) -> str:
        """
        Return the name of the device.
        """
        return torch.npu.get_device_name(device)

    def synchronize(self, device: Union[torch.device, int] = None):
        """
        Synchronize the current process.
        """
        torch.npu.synchronize(device)

    def is_available(self):
        """
        Check if the accelerator is available.
        """
        return torch.npu.is_available()

    def device_count(self):
        """
        Return the number of devices on the machine.
        """
        return torch.npu.device_count()
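
Usage sketch (not part of the file above): a minimal example of how this accelerator class might be exercised on a host with Huawei Ascend NPUs and torch_npu installed. The colossalai.accelerator import path is an assumption based on this module's location in the repository; adjust it to match your installation.

    from colossalai.accelerator import NpuAccelerator  # assumed public import path

    accelerator = NpuAccelerator()
    if accelerator.is_available():
        accelerator.set_device(0)                        # bind this process to NPU 0
        print(accelerator.get_device_name(0))            # human-readable device name
        print("device count:", accelerator.device_count())
        accelerator.synchronize()                        # wait for pending work on the NPU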