From a9b8402d93ac69bb9a8b46e21cfe3697409972fe Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 20 Mar 2023 13:59:24 +0800
Subject: [PATCH] [booster] added the accelerator implementation (#3159)

---
 colossalai/booster/accelerator.py           | 48 +++++++++++++++++--
 colossalai/booster/booster.py               | 15 +++++-
 tests/test_booster/test_accelerator.py      | 13 +++++
 .../test_torchrec_model/test_dlrm_model.py  |  1 +
 4 files changed, 72 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_booster/test_accelerator.py

diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py
index 63ba193e3..fc2c4a400 100644
--- a/colossalai/booster/accelerator.py
+++ b/colossalai/booster/accelerator.py
@@ -3,12 +3,52 @@ import torch.nn as nn
 
 __all__ = ['Accelerator']
 
+_supported_devices = [
+    'cpu',
+    'cuda',
+
+    # To be supported
+    # 'xpu',
+    # 'npu',
+    # 'tpu',
+]
+
 
 class Accelerator:
+    """
+    Accelerator is an abstraction for the hardware device that is used to run the model.
 
-    def __init__(self, device: torch.device):
+    Args:
+        device (str): The device to be used. Currently only 'cpu' and 'cuda' are supported.
+    """
+
+    def __init__(self, device: str):
         self.device = device
 
-    def setup_model(self, model: nn.Module) -> nn.Module:
-        # TODO: implement this method
-        pass
+        assert self.device in _supported_devices, f"Device {self.device} is not supported yet; supported devices include {_supported_devices}"
+
+    def set_default_device(self):
+        """
+        Set the default device for the current process.
+        """
+        if self.device == 'cpu':
+            pass
+        elif self.device == 'cuda':
+            # TODO(FrankLeeeee): use global environment to check if it is a dist job
+            # if is_distributed:
+            #     local_rank = EnvTable().get_local_rank()
+            #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
+            torch.cuda.set_device(torch.device('cuda'))
+            pass
+        else:
+            raise ValueError(f"Device {self.device} is not supported yet")
+
+    def configure_model(self, model: nn.Module) -> nn.Module:
+        """
+        Move the model to the device.
+
+        Args:
+            model (nn.Module): The model to be moved.
+        """
+        model = model.to(torch.device(self.device))
+        return model
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 7b351ae34..7d7f21ca6 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
 
@@ -51,9 +52,16 @@ class Booster:
     """
 
     def __init__(self,
-                 device: Union[str, torch.device] = 'cuda',
+                 device: str = 'cuda',
                  mixed_precision: Union[MixedPrecision, str] = None,
                  plugin: Optional[Plugin] = None) -> None:
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+        # create accelerator
+        self.accelerator = Accelerator(device)
+        self.accelerator.set_default_device()
+
         # validate and set precision
         if isinstance(MixedPrecision, str):
             # the user will take the default arguments for amp training
@@ -78,6 +86,11 @@ class Booster:
             lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
             dataloader (DataLoader): The dataloader to be boosted.
         """
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+        model = self.accelerator.configure_model(model)
+
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(lsg): Add plugin control logic
         # e.g.
diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py
new file mode 100644
index 000000000..4bfa3fd06
--- /dev/null
+++ b/tests/test_booster/test_accelerator.py
@@ -0,0 +1,13 @@
+import pytest
+import torch.nn as nn
+from torchvision.models import resnet18
+
+from colossalai.booster.accelerator import Accelerator
+
+
+@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+def test_accelerator(device):
+    accelerator = Accelerator(device)
+    model = nn.Linear(8, 8)
+    model = accelerator.configure_model(model)
+    assert next(model.parameters()).device.type == device
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
index 27a882913..71ecf7fca 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
@@ -56,6 +56,7 @@ def test_torchrec_dlrm_models():
     data = data_gen_fn()
 
     # dlrm_interactionarch is not supported
+    # TODO(FrankLeeeee): support this model
     if name == 'dlrm_interactionarch':
         continue
 
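
Below, for reference, is a minimal usage sketch of the API this patch introduces. It is not part of the commit: nn.Linear and the 'cuda' device are illustrative placeholders, and the snippet assumes a CUDA-capable machine.

    import torch.nn as nn

    from colossalai.booster.accelerator import Accelerator

    # create an accelerator for a supported device ('cpu' or 'cuda')
    accelerator = Accelerator('cuda')

    # make it the default device for the current process
    accelerator.set_default_device()

    # configure_model moves the module onto the chosen device
    model = accelerator.configure_model(nn.Linear(8, 8))
    assert next(model.parameters()).device.type == 'cuda'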