mirror of https://github.com/hpcaitech/ColossalAI
#!/usr/bin/env python

from typing import Optional, Union

import torch

from .base_accelerator import BaseAccelerator

__all__ = ["CudaAccelerator"]


class CudaAccelerator(BaseAccelerator):
    """
    Accelerator class for Nvidia CUDA devices.
    """

    def __init__(self):
        # Collectives for NVIDIA GPUs go through the NCCL backend; is_synchronous=False
        # reflects that CUDA kernels are launched asynchronously from the host.
        super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)

    # =======================
    # device APIs
    # =======================
    def current_device(self) -> int:
        """
        Return the current device index.
        """
        return torch.cuda.current_device()

    def set_device(self, device: Union[torch.device, int]) -> None:
        """
        Bind the current process to a device.
        """
        torch.cuda.set_device(device)

    def get_device_name(self, device: Union[torch.device, int]) -> str:
        """
        Return the name of the device.
        """
        return torch.cuda.get_device_name(device)

    def synchronize(self, device: Optional[Union[torch.device, int]] = None) -> None:
        """
        Wait for all kernels on the given device to complete; if device is None,
        the current device is synchronized.
        """
        torch.cuda.synchronize(device)

    def is_available(self) -> bool:
        """
        Check if the accelerator is available.
        """
        return torch.cuda.is_available()

    def device_count(self) -> int:
        """
        Return the number of devices on the machine.
        """
        return torch.cuda.device_count()
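

# ----------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal example of how
# CudaAccelerator might be exercised. Because this module uses a relative
# import, the demo only runs when the module is executed as part of its
# package (e.g. `python -m <package>.cuda_accelerator`); the guard below
# also keeps it inert on normal imports.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    accelerator = CudaAccelerator()
    if accelerator.is_available():
        # Report how many CUDA devices torch can see.
        print(f"CUDA devices: {accelerator.device_count()}")
        # Bind this process to device 0 and inspect it.
        accelerator.set_device(0)
        idx = accelerator.current_device()
        print(f"current device: {idx} ({accelerator.get_device_name(idx)})")
        # Block until all queued kernels on the current device finish.
        accelerator.synchronize()
    else:
        print("CUDA is not available on this machine.")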