mirror of https://github.com/hpcaitech/ColossalAI
[accelerator] init the accelerator module (#5129)
* [accelerator] init the accelerator module * polish code * polish code * polish code * polish codepull/5237/head
parent
68fcaa2225
commit
f4e72c9992
|
@ -1,4 +1,5 @@
|
|||
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
|
||||
from . import accelerator
|
||||
|
||||
try:
|
||||
# .version will be created by setup.py
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# 🚀 Accelerator
|
||||
|
||||
## 🔗 Table of Contents
|
||||
|
||||
- [🚀 Accelerator](#-accelerator)
|
||||
- [🔗 Table of Contents](#-table-of-contents)
|
||||
- [📚 Introduction](#-introduction)
|
||||
- [📌 Design and Acknowledgement](#-design-and-acknowledgement)
|
||||
|
||||
## 📚 Introduction
|
||||
|
||||
This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API.
|
||||
|
||||
## 📌 Design and Acknowledgement
|
||||
|
||||
Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work.
|
||||
|
||||
We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications:
|
||||
1. we updated the accelerator API names to be aligned with PyTorch's native API names.
|
||||
2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled.
|
|
@ -0,0 +1,13 @@
|
|||
from .api import auto_set_accelerator, get_accelerator, set_accelerator
|
||||
from .base_accelerator import BaseAccelerator
|
||||
from .cuda_accelerator import CudaAccelerator
|
||||
from .npu_accelerator import NpuAccelerator
|
||||
|
||||
__all__ = [
|
||||
"get_accelerator",
|
||||
"set_accelerator",
|
||||
"auto_set_accelerator",
|
||||
"BaseAccelerator",
|
||||
"CudaAccelerator",
|
||||
"NpuAccelerator",
|
||||
]
|
|
@ -0,0 +1,72 @@
|
|||
#!/usr/bin/env python
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
from .cuda_accelerator import CudaAccelerator
|
||||
from .npu_accelerator import NpuAccelerator
|
||||
|
||||
__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"]
|
||||
|
||||
|
||||
_ACCELERATOR = None
|
||||
|
||||
|
||||
# we use ordered dictionary here to associate the
|
||||
# order with device check priority
|
||||
# i.e. auto_set_accelerator will check cuda first
|
||||
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
|
||||
|
||||
|
||||
def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
|
||||
"""
|
||||
Set the global accelerator for the current process.
|
||||
|
||||
Args:
|
||||
accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs.
|
||||
"""
|
||||
|
||||
global _ACCELERATOR
|
||||
|
||||
if isinstance(accelerator, str):
|
||||
_ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]()
|
||||
elif isinstance(accelerator, BaseAccelerator):
|
||||
_ACCELERATOR = accelerator
|
||||
else:
|
||||
raise TypeError("accelerator must be either a string or an instance of BaseAccelerator")
|
||||
|
||||
|
||||
def auto_set_accelerator() -> None:
|
||||
"""
|
||||
Automatically check if any accelerator is available.
|
||||
If an accelerator is availabe, set it as the global accelerator.
|
||||
"""
|
||||
global _ACCELERATOR
|
||||
|
||||
for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
|
||||
try:
|
||||
accelerator = accelerator_cls()
|
||||
if accelerator.is_available():
|
||||
_ACCELERATOR = accelerator
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if _ACCELERATOR is None:
|
||||
raise RuntimeError(
|
||||
f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def get_accelerator() -> BaseAccelerator:
|
||||
"""
|
||||
Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized
|
||||
to the default accelerator type.
|
||||
|
||||
Returns: the accelerator for the current process.
|
||||
"""
|
||||
global _ACCELERATOR
|
||||
|
||||
if _ACCELERATOR is None:
|
||||
auto_set_accelerator()
|
||||
return _ACCELERATOR
|
|
@ -0,0 +1,81 @@
|
|||
#!/usr/bin/env python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["BaseAccelerator"]
|
||||
|
||||
|
||||
class BaseAccelerator(ABC):
|
||||
def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
|
||||
self._name = name
|
||||
self._communication_backend = communication_backend
|
||||
self._is_synchronous = is_synchronous
|
||||
|
||||
# =======================
|
||||
# immutable attributes
|
||||
# =======================
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Return the name of the accelerator.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def communication_backend(self) -> str:
|
||||
"""
|
||||
Return the name of the backend communication library.
|
||||
"""
|
||||
return self._communication_backend
|
||||
|
||||
@property
|
||||
def is_synchronous(self) -> bool:
|
||||
"""
|
||||
Return whether the accelerator is a synchronous device.
|
||||
"""
|
||||
return self._is_synchronous
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls_name = self.__class__.__name__
|
||||
return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})"
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
__all__ = ["CudaAccelerator"]
|
||||
|
||||
|
||||
class CudaAccelerator(BaseAccelerator):
|
||||
"""
|
||||
Accelerator class for Nvidia CUDA devices.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
return torch.cuda.get_device_name(device)
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return torch.cuda.is_available()
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.cuda.device_count()
|
|
@ -0,0 +1,63 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["NpuAccelerator"]
|
||||
|
||||
|
||||
class NpuAccelerator(BaseAccelerator):
|
||||
"""
|
||||
Accelerator class for Huawei NPU devices.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.npu.current_device()
|
||||
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
torch.npu.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
return torch.npu.get_device_name(device)
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
torch.npu.synchronize(device)
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return torch.npu.is_available()
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.npu.device_count()
|
Loading…
Reference in New Issue