You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
# Public API of this module: only the Accelerator class is exported by a star-import.
__all__ = ["Accelerator"]
|
|
|
|
|
|
|
|
# Device type strings that an Accelerator may currently be created for.
_supported_devices = [
    "cpu",
    "cuda",
    # To be supported: "xpu", "npu", "tpu".
]
|
|
|
|
|
|
|
|
|
|
|
|
class Accelerator:
    """
    Accelerator is an abstraction for the hardware device that is used to run the model.

    Args:
        device (str): The device to be used. Currently only 'cpu' and 'cuda' are supported.

    Raises:
        AssertionError: If ``device`` is not listed in ``_supported_devices``.
    """

    def __init__(self, device: str):
        # Raised explicitly rather than through an ``assert`` statement so the check is
        # not stripped when Python runs with ``-O``. The exception type and the message
        # are the same as the assert produced.
        if device not in _supported_devices:
            raise AssertionError(
                f"Device {device} is not supported yet, supported devices include {_supported_devices}"
            )
        self.device = device

    def bind(self):
        """
        Set the default device for the current process.

        Nothing is done for 'cpu'. For 'cuda', the CUDA device that is currently
        selected in this process is (re-)selected explicitly.

        Raises:
            ValueError: If ``self.device`` is neither 'cpu' nor 'cuda'.
        """
        if self.device == 'cpu':
            # No per-process device selection is performed for the CPU.
            return

        if self.device == 'cuda':
            # TODO(FrankLeeeee): use global environment to check if it is a dist job
            # if is_distributed:
            #     local_rank = EnvTable().get_local_rank()
            #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))

            # ``torch.device('cuda')`` carries no index and therefore names no concrete
            # GPU, so pass the explicit index of the currently selected device instead.
            torch.cuda.set_device(torch.cuda.current_device())
            return

        # Reachable if ``self.device`` was reassigned after construction.
        raise ValueError(f"Device {self.device} is not supported yet")

    def configure_model(self, model: nn.Module) -> nn.Module:
        """
        Move the model to the device.

        Args:
            model (nn.Module): The model to be moved.

        Returns:
            nn.Module: The result of ``model.to(torch.device(self.device))``.
        """
        return model.to(torch.device(self.device))
|