mirror of https://github.com/hpcaitech/ColossalAI
[booster] added the accelerator implementation (#3159)
parent 1ad3a636b1
commit a9b8402d93
@@ -3,12 +3,52 @@ import torch.nn as nn
 
 __all__ = ['Accelerator']
 
+_supported_devices = [
+    'cpu',
+    'cuda',
+
+    # To be supported
+    # 'xpu',
+    # 'npu',
+    # 'tpu',
+]
+
 
 class Accelerator:
+    """
+    Accelerator is an abstraction for the hardware device that is used to run the model.
 
-    def __init__(self, device: torch.device):
+    Args:
+        device (str): The device to be used. Currently only 'cpu' and 'cuda' are supported.
+    """
+
+    def __init__(self, device: str):
         self.device = device
 
-    def setup_model(self, model: nn.Module) -> nn.Module:
-        # TODO: implement this method
-        pass
+        assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+
+    def bind(self):
+        """
+        Set the default device for the current process.
+        """
+        if self.device == 'cpu':
+            pass
+        elif self.device == 'cuda':
+            # TODO(FrankLeeeee): use global environment to check if it is a dist job
+            # if is_distributed:
+            #     local_rank = EnvTable().get_local_rank()
+            #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
+            torch.cuda.set_device(torch.device('cuda'))
+            pass
+        else:
+            raise ValueError(f"Device {self.device} is not supported yet")
+
+    def configure_model(self, model: nn.Module) -> nn.Module:
+        """
+        Move the model to the device.
+
+        Args:
+            model (nn.Module): The model to be moved.
+        """
+        model = model.to(torch.device(self.device))
+        return model
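For context, this hunk appears to belong to colossalai/booster/accelerator.py; the path is inferred from the test import further down this page, not stated here. A minimal usage sketch of the new class, assuming that inferred import path and an installed colossalai at this commit:

import torch
import torch.nn as nn

from colossalai.booster.accelerator import Accelerator

# choose 'cuda' when a GPU is present, otherwise fall back to 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

accelerator = Accelerator(device)    # __init__ asserts the device is in _supported_devices
accelerator.bind()                   # sets the default device for the current process
model = accelerator.configure_model(nn.Linear(8, 8))    # equivalent to model.to(torch.device(device))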
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
 
@@ -51,9 +52,16 @@ class Booster:
     """
 
     def __init__(self,
-                 device: Union[str, torch.device] = 'cuda',
+                 device: str = 'cuda',
                  mixed_precision: Union[MixedPrecision, str] = None,
                  plugin: Optional[Plugin] = None) -> None:
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+        # create accelerator
+        self.accelerator = Accelerator(device)
+        self.accelerator.bind()
+
         # validate and set precision
         if isinstance(mixed_precision, str):
             # the user will take the default arguments for amp training
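With this change, callers pass the device as a plain string rather than a torch.device, and the constructor creates and binds an Accelerator internally. A minimal sketch of the expected call, assuming Booster is re-exported from colossalai.booster and that 'fp16' is an accepted mixed_precision shorthand (neither detail is shown in this hunk):

from colossalai.booster import Booster

# the Booster now owns an Accelerator(device) and sets the default device at construction time
booster = Booster(device='cuda', mixed_precision='fp16')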
@@ -78,6 +86,11 @@ class Booster:
             lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
             dataloader (DataLoader): The dataloader to be boosted.
         """
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+        model = self.accelerator.configure_model(model)
+
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(lsg): Add plugin control logic
         # e.g.
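For orientation, a sketch of how boost() would be invoked based on the parameters named in its docstring; the full signature and return value are not visible in this hunk, so the keyword arguments and tuple unpacking below are assumptions:

# hypothetical call; criterion is omitted because it is not shown in this hunk,
# and the model is moved to the accelerator's device inside boost()
model, optimizer, lr_scheduler, dataloader = booster.boost(
    model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader)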
@@ -0,0 +1,13 @@
+import pytest
+import torch.nn as nn
+from torchvision.models import resnet18
+
+from colossalai.booster.accelerator import Accelerator
+
+
+@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+def test_accelerator(device):
+    accelerator = Accelerator(device)
+    model = nn.Linear(8, 8)
+    model = accelerator.configure_model(model)
+    assert next(model.parameters()).device.type == device
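Note that the 'cuda' case of this test needs a GPU to pass. A common guard, not part of this commit, would be to skip that case when CUDA is unavailable:

import pytest
import torch
import torch.nn as nn

from colossalai.booster.accelerator import Accelerator


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_accelerator(device):
    # skip the GPU case on machines without CUDA (guard added for illustration)
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('CUDA is not available on this machine')
    accelerator = Accelerator(device)
    model = accelerator.configure_model(nn.Linear(8, 8))
    assert next(model.parameters()).device.type == device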
@@ -56,6 +56,7 @@ def test_torchrec_dlrm_models():
         data = data_gen_fn()
 
         # dlrm_interactionarch is not supported
+        # TODO(FrankLeeeee): support this model
         if name == 'dlrm_interactionarch':
             continue
 