ColossalAI/colossalai/accelerator/npu_accelerator.py

64 lines
1.4 KiB
Python

#!/usr/bin/env python
from typing import Union
import torch
from .base_accelerator import BaseAccelerator
try:
import torch_npu # noqa
except ImportError:
pass
__all__ = ["NpuAccelerator"]
class NpuAccelerator(BaseAccelerator):
    """
    Accelerator class for Huawei NPU devices.

    Every device API below is a thin wrapper around the function of the same
    name in the ``torch.npu`` namespace. Apart from :meth:`is_available`, the
    wrappers access ``torch.npu`` directly and therefore raise
    ``AttributeError`` if that namespace does not exist.
    NOTE(review): ``torch.npu`` is presumably registered by the optional
    ``torch_npu`` import at the top of this module -- confirm.
    """

    def __init__(self):
        # Register with the base class under the name "npu", with "hccl" as the
        # communication backend and the is_synchronous flag turned off.
        super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)

    # =======================
    # device APIs
    # =======================
    def current_device(self) -> int:
        """
        Return the current device index.

        Returns:
            The value reported by ``torch.npu.current_device()``.
        """
        return torch.npu.current_device()

    def set_device(self, device: Union[torch.device, int]) -> None:
        """
        Bind the current process to a device.

        Args:
            device: Device object or device index, forwarded unchanged to
                ``torch.npu.set_device``.
        """
        torch.npu.set_device(device)

    def get_device_name(self, device: Union[torch.device, int]) -> str:
        """
        Return the name of the device.

        Args:
            device: Device object or device index, forwarded unchanged to
                ``torch.npu.get_device_name``.

        Returns:
            The value reported by ``torch.npu.get_device_name(device)``.
        """
        return torch.npu.get_device_name(device)

    def synchronize(self, device: Union[torch.device, int, None] = None) -> None:
        """
        Synchronize the current process.

        Args:
            device: Device object or device index, forwarded unchanged to
                ``torch.npu.synchronize``. Defaults to ``None``.
                NOTE(review): ``None`` presumably selects the current device,
                mirroring ``torch.cuda.synchronize`` -- confirm against
                ``torch_npu``.
        """
        torch.npu.synchronize(device)

    def is_available(self) -> bool:
        """
        Check if the accelerator is available.

        This probe never raises just because the ``torch.npu`` namespace is
        absent; a missing namespace simply means the accelerator is reported
        as unavailable.

        Returns:
            ``False`` when ``torch`` has no ``npu`` attribute, otherwise the
            result of ``torch.npu.is_available()``.
        """
        # Look the namespace up defensively: this method is the one callers use
        # to decide whether any other NPU API may be touched at all.
        npu_module = getattr(torch, "npu", None)
        if npu_module is None:
            return False
        return npu_module.is_available()

    def device_count(self) -> int:
        """
        Return the number of devices on the machine.

        Returns:
            The value reported by ``torch.npu.device_count()``.
        """
        return torch.npu.device_count()