[booster] added the accelerator implementation (#3159)

pull/3178/head
Frank Lee 2023-03-20 13:59:24 +08:00 committed by GitHub
parent 1ad3a636b1
commit a9b8402d93
4 changed files with 72 additions and 5 deletions


@@ -3,12 +3,52 @@ import torch.nn as nn
 __all__ = ['Accelerator']
 
+_supported_devices = [
+    'cpu',
+    'cuda',
+
+    # To be supported
+    # 'xpu',
+    # 'npu',
+    # 'tpu',
+]
+
 
 class Accelerator:
+    """
+    Accelerator is an abstraction for the hardware device that is used to run the model.
 
-    def __init__(self, device: torch.device):
+    Args:
+        device (str): The device to be used. Currently only supports 'cpu' and 'cuda'.
+    """
+
+    def __init__(self, device: str):
         self.device = device
 
-    def setup_model(self, model: nn.Module) -> nn.Module:
-        # TODO: implement this method
-        pass
+        assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+
+    def bind(self):
+        """
+        Set the default device for the current process.
+        """
+        if self.device == 'cpu':
+            pass
+        elif self.device == 'cuda':
+            # TODO(FrankLeeeee): use global environment to check if it is a dist job
+            # if is_distributed:
+            #     local_rank = EnvTable().get_local_rank()
+            #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
+            torch.cuda.set_device(torch.device('cuda'))
+        else:
+            raise ValueError(f"Device {self.device} is not supported yet")
+
+    def configure_model(self, model: nn.Module) -> nn.Module:
+        """
+        Move the model to the device.
+
+        Args:
+            model (nn.Module): The model to be moved.
+        """
+        model = model.to(torch.device(self.device))
+        return model
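
The Accelerator API above is intentionally small: construct it with a device string, call bind() to make that device the process default, and call configure_model() to move a model onto it. A minimal usage sketch (the device fallback and the toy model are illustrative, not part of the commit):

import torch
import torch.nn as nn

from colossalai.booster.accelerator import Accelerator

# prefer CUDA when it is available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

accelerator = Accelerator(device)
accelerator.bind()                            # set the default device for this process

model = nn.Linear(16, 4)
model = accelerator.configure_model(model)    # moves the parameters to the chosen device
print(next(model.parameters()).device)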


@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
@@ -51,9 +52,16 @@ class Booster:
     """
 
     def __init__(self,
-                 device: Union[str, torch.device] = 'cuda',
+                 device: str = 'cuda',
                  mixed_precision: Union[MixedPrecision, str] = None,
                  plugin: Optional[Plugin] = None) -> None:
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+
+        # create accelerator
+        self.accelerator = Accelerator(device)
+        self.accelerator.bind()
+
         # validate and set precision
         if isinstance(mixed_precision, str):
             # the user will take the default arguments for amp training
@@ -78,6 +86,11 @@ class Booster:
             lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
             dataloader (DataLoader): The dataloader to be boosted.
         """
+        # TODO(FrankLeeeee): add plugin control logic
+        # if self.plugin is not None and self.plugin.control_accelerator:
+        #     ...
+        model = self.accelerator.configure_model(model)
+
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(lsg): Add plugin control logic
         # e.g.
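
Together, the booster.py hunks wire the accelerator into the Booster lifecycle: __init__ creates an Accelerator for the requested device and binds it, and boost() later routes the model through configure_model() before any plugin logic runs. A rough sketch of that flow, assuming Booster is importable from the colossalai.booster package (the re-export is not shown in this diff):

import torch.nn as nn

from colossalai.booster import Booster    # assumed package-level import

model = nn.Linear(32, 2)

# Accelerator('cpu') is created and bound inside __init__ (second hunk above)
booster = Booster(device='cpu', mixed_precision=None, plugin=None)

# boost() will move the model via accelerator.configure_model() (third hunk);
# the rest of its signature and return values lie outside the lines shown here.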


@@ -0,0 +1,13 @@
+import pytest
+import torch.nn as nn
+from torchvision.models import resnet18
+
+from colossalai.booster.accelerator import Accelerator
+
+
+@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+def test_accelerator(device):
+    accelerator = Accelerator(device)
+    model = nn.Linear(8, 8)
+    model = accelerator.configure_model(model)
+    assert next(model.parameters()).device.type == device
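
As committed, the 'cuda' case of this test will fail on a machine without a GPU, because configure_model() will try to move the model onto CUDA. A hypothetical guard (not part of this commit) that skips that case when CUDA is unavailable:

import pytest
import torch
import torch.nn as nn

from colossalai.booster.accelerator import Accelerator


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_accelerator_with_skip(device):
    # hypothetical variant of the test above; skips instead of failing on CPU-only machines
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('CUDA is not available on this machine')
    accelerator = Accelerator(device)
    model = nn.Linear(8, 8)
    model = accelerator.configure_model(model)
    assert next(model.parameters()).device.type == device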


@@ -56,6 +56,7 @@ def test_torchrec_dlrm_models():
         data = data_gen_fn()
 
         # dlrm_interactionarch is not supported
+        # TODO(FrankLeeeee): support this model
         if name == 'dlrm_interactionarch':
             continue