mirror of https://github.com/hpcaitech/ColossalAI
[booster] init module structure and definition (#3056)
parent
faa8526b85
commit
f19b49e164
|
@ -0,0 +1,5 @@
|
||||||
|
from .accelerator import Accelerator
|
||||||
|
from .booster import Booster
|
||||||
|
from .environment_table import EnvironmentTable
|
||||||
|
from .plugin import Plugin
|
||||||
|
from .precision import Precision
|
|
@ -0,0 +1,14 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
__all__ = ['Accelerator']
|
||||||
|
|
||||||
|
|
||||||
|
class Accelerator:
|
||||||
|
|
||||||
|
def __init__(self, device: torch.device):
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
|
@ -0,0 +1,66 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from .plugin import Plugin
|
||||||
|
|
||||||
|
__all__ = ['Booster']
|
||||||
|
|
||||||
|
|
||||||
|
class Booster:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
device: Union[str, torch.device] = 'cuda',
|
||||||
|
precision: str = 'fp32',
|
||||||
|
grad_clipping_type: str = 'norm',
|
||||||
|
grad_clipping_value: float = 0.0,
|
||||||
|
plugin: Optional[Plugin] = None) -> None:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
def boost(
|
||||||
|
self, *args: Union[nn.Module, Optimizer, LRScheduler, DataLoader]
|
||||||
|
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
def execute_pipeline(self,
|
||||||
|
data_iter: Iterator,
|
||||||
|
model: nn.Module,
|
||||||
|
criterion: Callable[[torch.Tensor], torch.Tensor],
|
||||||
|
optimizer: Optimizer,
|
||||||
|
return_loss: bool = True,
|
||||||
|
return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
|
||||||
|
# TODO: implement this method
|
||||||
|
# run pipeline forward backward pass
|
||||||
|
# return loss or outputs if needed
|
||||||
|
pass
|
||||||
|
|
||||||
|
def no_sync(self, model: nn.Module) -> contextmanager:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self,
|
||||||
|
obj: Union[nn.Module, Optimizer, LRScheduler],
|
||||||
|
path_like: str,
|
||||||
|
plan: str = 'torch',
|
||||||
|
**kwargs) -> None:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load(self,
|
||||||
|
obj: Union[nn.Module, Optimizer, LRScheduler],
|
||||||
|
path_like: str,
|
||||||
|
plan: str = 'torch',
|
||||||
|
**kwargs) -> None:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
|
@ -0,0 +1,18 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
__all__ = ['EnvironmentTable']
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentTable:
|
||||||
|
|
||||||
|
def __init__(self, intra_op_world_sizes: List[int]):
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_master(self) -> bool:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: implement more utility methods as given in
|
||||||
|
# https://github.com/hpcaitech/ColossalAI/issues/3051
|
|
@ -0,0 +1,46 @@
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
__all__ = ['Plugin']
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_devices(self) -> List[torch.device]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_precisions(self) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def control_precision(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def control_device(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def support_no_sync(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device_mesh_shape(self) -> List[Tuple[int, ...]]:
|
||||||
|
pass
|
|
@ -0,0 +1,25 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
__all__ = ['Precision']
|
||||||
|
|
||||||
|
|
||||||
|
class Precision:
|
||||||
|
|
||||||
|
def __init__(self, precision_type: torch.dtype, grad_clipping_type: str, grad_clipping_value: float):
|
||||||
|
self.precision_type = precision_type
|
||||||
|
self.grad_clipping_type = grad_clipping_type
|
||||||
|
self.grad_clipping_value = grad_clipping_value
|
||||||
|
|
||||||
|
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
# TODO: implement this method
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
|
||||||
|
# TODO: implement this method
|
||||||
|
# inject grad clipping and unscale loss
|
||||||
|
pass
|
||||||
|
|
||||||
|
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||||
|
pass
|
Loading…
Reference in New Issue