import warnings
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin

__all__ = ['Booster']


class Booster:
    """
    Booster is a high-level API for training neural networks. It provides a unified interface for
    training with different precision, accelerator, and plugin.

    Examples:
        >>> colossalai.launch(...)
        >>> plugin = GeminiPlugin(stage=3, ...)
        >>> booster = Booster(mixed_precision='fp16', plugin=plugin)
        >>>
        >>> model = GPT2()
        >>> optimizer = Adam(model.parameters())
        >>> dataloader = DataLoader(dataset)
        >>> lr_scheduler = LinearWarmupScheduler()
        >>> criterion = GPTLMLoss()
        >>>
        >>> model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
        ...     model, optimizer, criterion, dataloader, lr_scheduler)
        >>>
        >>> for epoch in range(max_epochs):
        >>>     for input_ids, attention_mask in dataloader:
        >>>         outputs = model(input_ids, attention_mask)
        >>>         loss = criterion(outputs.logits, input_ids)
        >>>         booster.backward(loss, optimizer)
        >>>         optimizer.step()
        >>>         lr_scheduler.step()
        >>>         optimizer.zero_grad()

    Args:
        device (str or torch.device): The device to run the training. Default: 'cuda'.
        mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
            If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
            'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
        plugin (Plugin): The plugin to run the training. Default: None.
    """

    def __init__(self,
                 device: str = 'cuda',
                 mixed_precision: Optional[Union[MixedPrecision, str]] = None,
                 plugin: Optional[Plugin] = None) -> None:
        """
        Resolve the accelerator, mixed precision handler and checkpoint IO for this booster.

        A plugin that reports control over the device or the precision takes priority over the
        corresponding argument; the argument is then ignored and a warning is emitted.

        Raises:
            AssertionError: If ``plugin`` is given but is not a ``Plugin`` instance.
            ValueError: If ``mixed_precision`` is neither a string nor a ``MixedPrecision`` instance.
        """
        if plugin is not None:
            assert isinstance(
                plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
        self.plugin = plugin

        # set accelerator
        if self.plugin and self.plugin.control_device():
            self.accelerator = None
            warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
        else:
            self.accelerator = Accelerator(device)

        # set precision
        if self.plugin and self.plugin.control_precision():
            warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
            self.mixed_precision = None
        elif mixed_precision is None:
            self.mixed_precision = None
        elif isinstance(mixed_precision, str):
            # the user will take the default arguments for amp training
            self.mixed_precision = mixed_precision_factory(mixed_precision)
        elif isinstance(mixed_precision, MixedPrecision):
            # the user can customize the arguments by passing the precision object
            self.mixed_precision = mixed_precision
        else:
            raise ValueError(
                'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, '
                f'but got {type(mixed_precision)}.')

        # set checkpoint io: the plugin may supply its own, otherwise use the general one
        if self.plugin is not None and self.plugin.control_checkpoint_io():
            self.checkpoint_io = self.plugin.get_checkpoint_io()
        else:
            self.checkpoint_io = GeneralCheckpointIO()

    def boost(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, Optimizer, Optional[Callable], Optional[DataLoader], Optional[LRScheduler]]:
        """
        Boost the model, optimizer, criterion, lr_scheduler, and dataloader.

        Args:
            model (nn.Module): The model to be boosted.
            optimizer (Optimizer): The optimizer to be boosted.
            criterion (Callable): The criterion to be boosted.
            dataloader (DataLoader): The dataloader to be boosted.
            lr_scheduler (LRScheduler): The lr_scheduler to be boosted.

        Returns:
            A 5-tuple ``(model, optimizer, criterion, dataloader, lr_scheduler)`` in the same order
            as the arguments.
        """
        # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
        # TODO(FrankLeeeee): consider multi-dataloader case
        if self.plugin:
            # let the plugin transform every component first
            model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                model, optimizer, criterion, dataloader, lr_scheduler)

        # NOTE(review): the accelerator is only applied when a plugin is present and does not
        # control the device; without a plugin the `device` argument has no effect here.
        # Confirm whether that is intended before changing it.
        if self.plugin and not self.plugin.control_device():
            # transform model for accelerator
            model = self.accelerator.configure(model)

        if self.mixed_precision and (self.plugin is None or
                                     (self.plugin and not self.plugin.control_precision())):
            # transform model for mixed precision
            # when mixed_precision is specified and the plugin is not given or does not control the precision
            model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
        """
        Run the backward pass for ``loss`` by calling ``optimizer.backward(loss)``.

        The optimizer must expose a ``backward`` method; presumably this is the optimizer
        returned by :meth:`boost` rather than a plain ``torch.optim.Optimizer`` (TODO confirm).
        """
        # TODO: implement this method with plugin
        optimizer.backward(loss)

    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: nn.Module,
                         criterion: Callable[[torch.Tensor], torch.Tensor],
                         optimizer: Optimizer,
                         return_loss: bool = True,
                         return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
        """
        Placeholder for the pipeline forward/backward pass.

        Not implemented yet: the method currently does nothing and returns ``None`` regardless
        of ``return_loss`` and ``return_outputs``.
        """
        # TODO: implement this method
        # run pipeline forward backward pass
        # return loss or outputs if needed
        pass

    def no_sync(self, model: nn.Module) -> contextmanager:
        """
        Return the plugin's ``no_sync`` context for ``model``.

        Raises:
            AssertionError: If no plugin was provided or the plugin does not support ``no_sync``.
        """
        assert self.plugin is not None, \
            'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
        supports_no_sync = self.plugin.support_no_sync
        if callable(supports_no_sync):
            # The other plugin capability probes used in this class are methods. A bound method
            # is always truthy, so it has to be called for the check below to mean anything.
            supports_no_sync = supports_no_sync()
        assert supports_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
        return self.plugin.no_sync(model)

    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        """Load ``checkpoint`` into ``model`` through ``self.checkpoint_io``."""
        self.checkpoint_io.load_model(model, checkpoint, strict)

    def save_model(self,
                   model: nn.Module,
                   checkpoint: str,
                   prefix: Optional[str] = None,
                   shard: bool = False,
                   size_per_shard: int = 1024):
        """Save ``model`` to ``checkpoint`` through ``self.checkpoint_io``; all options are forwarded unchanged."""
        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """Load ``checkpoint`` into ``optimizer`` through ``self.checkpoint_io``."""
        self.checkpoint_io.load_optimizer(optimizer, checkpoint)

    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        """Save ``optimizer`` to ``checkpoint`` through ``self.checkpoint_io``; all options are forwarded unchanged."""
        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """Save ``lr_scheduler`` to ``checkpoint`` through ``self.checkpoint_io``."""
        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """Load ``checkpoint`` into ``lr_scheduler`` through ``self.checkpoint_io``."""
        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)