mirror of https://github.com/hpcaitech/ColossalAI
82 lines
2.0 KiB
Python
82 lines
2.0 KiB
Python
![]() |
#!/usr/bin/env python
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
__all__ = ["BaseAccelerator"]


class BaseAccelerator(ABC):
    """
    Abstract base class for accelerator backends.

    Stores three attributes fixed at construction time (accelerator name,
    communication backend name, synchronous flag) and exposes them as
    read-only properties. Concrete accelerators must implement the abstract
    device APIs declared below.
    """

    def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
        """
        Args:
            name (str): name of the accelerator.
            communication_backend (str): name of the backend communication library.
            is_synchronous (bool): whether the accelerator is a synchronous device.
        """
        self._name = name
        self._communication_backend = communication_backend
        self._is_synchronous = is_synchronous

    # =======================
    # immutable attributes
    # =======================

    @property
    def name(self) -> str:
        """
        Return the name of the accelerator.
        """
        return self._name

    @property
    def communication_backend(self) -> str:
        """
        Return the name of the backend communication library.
        """
        return self._communication_backend

    @property
    def is_synchronous(self) -> bool:
        """
        Return whether the accelerator is a synchronous device.
        """
        return self._is_synchronous

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})"

    # =======================
    # device APIs
    # =======================
    @abstractmethod
    def current_device(self) -> int:
        """
        Return the current device index.
        """

    @abstractmethod
    def set_device(self, device: Union[torch.device, int]) -> None:
        """
        Bind the current process to a device.

        Args:
            device (Union[torch.device, int]): the device, or device index, to bind to.
        """

    def get_device_name(self, device: Union[torch.device, int]) -> str:
        """
        Return the name of the device.

        This method is intentionally not abstract, so that existing subclasses
        which do not override it can still be instantiated. Calling it without
        an override raises instead of silently returning ``None``, which would
        violate the ``str`` return type.

        Args:
            device (Union[torch.device, int]): the device, or device index, to query.

        Raises:
            NotImplementedError: if the concrete accelerator does not override this method.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not implement get_device_name()")

    @abstractmethod
    def synchronize(self, device: Union[torch.device, int, None] = None) -> None:
        """
        Synchronize the current process.

        Args:
            device (Union[torch.device, int, None]): device to synchronize. Defaults
                to ``None``; how ``None`` is interpreted is left to the concrete
                implementation.
        """

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the accelerator is available.
        """

    @abstractmethod
    def device_count(self) -> int:
        """
        Return the number of devices on the machine.
        """