|
|
|
#!/usr/bin/env python
|
|
|
|
from collections import OrderedDict
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
from .base_accelerator import BaseAccelerator
|
|
|
|
from .cpu_accelerator import CpuAccelerator
|
|
|
|
from .cuda_accelerator import CudaAccelerator
|
|
|
|
from .npu_accelerator import NpuAccelerator
|
|
|
|
|
|
|
|
__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"]

# The accelerator used by the current process. It stays None until it is
# assigned by set_accelerator() or auto_set_accelerator(); get_accelerator()
# triggers auto_set_accelerator() on first use.
_ACCELERATOR = None

# Name -> accelerator class. The insertion order is also the detection
# priority used by auto_set_accelerator(): cuda first, then npu, then cpu.
_ACCELERATOR_MAPPING = OrderedDict(
    [
        ("cuda", CudaAccelerator),
        ("npu", NpuAccelerator),
        ("cpu", CpuAccelerator),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
    """
    Set the global accelerator for the current process.

    Args:
        accelerator (Union[str, BaseAccelerator]): either the name of a registered accelerator
            (a key of ``_ACCELERATOR_MAPPING``, e.g. ``"cuda"``, ``"npu"`` or ``"cpu"``), in which
            case a new instance of that accelerator class is created, or an already-constructed
            ``BaseAccelerator`` instance, which is used as-is.

    Raises:
        KeyError: if ``accelerator`` is a string that is not a registered accelerator name.
        TypeError: if ``accelerator`` is neither a string nor a ``BaseAccelerator`` instance.
    """
    global _ACCELERATOR

    if isinstance(accelerator, str):
        try:
            accelerator_cls = _ACCELERATOR_MAPPING[accelerator]
        except KeyError:
            # Keep the KeyError type callers may already rely on, but make the message
            # actionable by listing the valid names; suppress the uninformative inner error.
            raise KeyError(
                f"Unknown accelerator '{accelerator}', expected one of {list(_ACCELERATOR_MAPPING)}"
            ) from None
        # Construct outside the try so errors raised by the constructor itself are not
        # misreported as an unknown accelerator name.
        _ACCELERATOR = accelerator_cls()
    elif isinstance(accelerator, BaseAccelerator):
        _ACCELERATOR = accelerator
    else:
        raise TypeError("accelerator must be either a string or an instance of BaseAccelerator")
|
|
|
|
|
|
|
|
|
|
|
|
def auto_set_accelerator() -> None:
    """
    Automatically check if any accelerator is available.
    If an accelerator is available, set it as the global accelerator.

    Candidates are tried in the insertion order of ``_ACCELERATOR_MAPPING``; the first usable
    one wins. The ``"cpu"`` entry is accepted without calling ``is_available()``. A candidate
    whose construction or ``is_available()`` check raises an ``Exception`` is skipped.

    Raises:
        RuntimeError: if no candidate was selected and no accelerator had been set before.
    """
    global _ACCELERATOR

    for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
        try:
            accelerator = accelerator_cls()
            usable = bool(accelerator_name == "cpu" or accelerator.is_available())
        except Exception:
            # Best-effort probing: a backend that fails to construct or to report availability
            # is skipped. Catching Exception (rather than a bare except) lets
            # KeyboardInterrupt / SystemExit propagate instead of being silently swallowed.
            continue
        if usable:
            _ACCELERATOR = accelerator
            break

    if _ACCELERATOR is None:
        raise RuntimeError("No accelerator is available.")
|
|
|
|
|
|
|
|
|
|
|
|
def get_accelerator() -> BaseAccelerator:
    """
    Return the accelerator of the current process.

    If no accelerator has been set yet, ``auto_set_accelerator()`` is called first so that
    one is selected automatically.

    Returns: the accelerator for the current process.
    """
    # This function only reads the module-level _ACCELERATOR, so no ``global`` statement is
    # needed; auto_set_accelerator() is responsible for assigning it.
    if _ACCELERATOR is not None:
        return _ACCELERATOR
    auto_set_accelerator()
    return _ACCELERATOR
|