import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase

__all__ = ["Booster"]


class Booster:
    """
    Booster is a high-level API for training neural networks. It provides a unified interface for
    training with different precision, accelerator, and plugin.

    ```python
    # Following is pseudocode

    colossalai.launch(...)
    plugin = GeminiPlugin(...)
    booster = Booster(mixed_precision='fp16', plugin=plugin)

    model = GPT2()
    optimizer = HybridAdam(model.parameters())
    dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    lr_scheduler = LinearWarmupScheduler()
    criterion = GPTLMLoss()

    model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
        model, optimizer, criterion, dataloader, lr_scheduler
    )

    for epoch in range(max_epochs):
        for input_ids, attention_mask in dataloader:
            outputs = model(input_ids.cuda(), attention_mask.cuda())
            loss = criterion(outputs.logits, input_ids)
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    ```

    Args:
        device (str or torch.device): The device to run the training. Default: None.
            If plugin is not used or plugin doesn't control the device,
            this argument will be set as training device ('cuda' will be used if argument is None).
        mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
            If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
            'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
        plugin (Plugin): The plugin to run the training. Default: None.

    Raises:
        AssertionError: If ``plugin`` is given but is not an instance of ``Plugin``.
        ValueError: If ``mixed_precision`` is used and is neither a string nor a ``MixedPrecision`` instance.
    """

    def __init__(
        self,
        device: Optional[str] = None,
        mixed_precision: Optional[Union[MixedPrecision, str]] = None,
        plugin: Optional[Plugin] = None,
    ) -> None:
        if plugin is not None:
            assert isinstance(
                plugin, Plugin
            ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
        self.plugin = plugin

        # set accelerator: a plugin that controls the device makes the `device` argument moot
        if self.plugin and self.plugin.control_device():
            self.accelerator = None
            if device is not None:
                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
        else:
            device = device or "cuda"
            self.accelerator = Accelerator(device)

        # set precision: a plugin that controls the precision makes `mixed_precision` moot
        if self.plugin and self.plugin.control_precision():
            if mixed_precision is not None:
                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
            self.mixed_precision = None
        elif mixed_precision is None:
            self.mixed_precision = None
        elif isinstance(mixed_precision, str):
            # the user will take the default arguments for amp training
            self.mixed_precision = mixed_precision_factory(mixed_precision)
        elif isinstance(mixed_precision, MixedPrecision):
            # the user can customize the arguments by passing the precision object
            self.mixed_precision = mixed_precision
        else:
            raise ValueError(
                "Expected the argument mixed_precision to be a string or an instance of MixedPrecision, "
                f"but got {type(mixed_precision)}."
            )

        # set checkpoint io: prefer the plugin's own implementation when it provides one
        if self.plugin is not None and self.plugin.control_checkpoint_io():
            self.checkpoint_io = self.plugin.get_checkpoint_io()
        else:
            self.checkpoint_io = GeneralCheckpointIO()

    def boost(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, Optional[Optimizer], Optional[Callable], Optional[DataLoader], Optional[LRScheduler]]:
        """
        Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.

        Args:
            model (nn.Module): Convert model into a wrapped model for distributive training.
                The model might be decorated or partitioned by plugin's strategy after execution of this method.
            optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
                The optimizer's param groups or states might be decorated or partitioned by plugin's strategy
                after execution of this method. Defaults to None.
            criterion (Callable, optional): The function that calculates loss. Defaults to None.
            dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
            lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.

        Returns:
            tuple: The boosted arguments, in the order
            ``(model, optimizer, criterion, dataloader, lr_scheduler)``.
        """
        # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
        # TODO(FrankLeeeee): consider multi-dataloader case
        # let the plugin transform every component first
        if self.plugin:
            model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                model, optimizer, criterion, dataloader, lr_scheduler
            )

        # NOTE(review): the accelerator created in ``__init__`` is only applied when a plugin is
        # present and does not control the device; with ``plugin=None`` the model is returned
        # without going through the accelerator. Confirm whether that is intended.
        if self.plugin and not self.plugin.control_device():
            # transform model for accelerator
            model = self.accelerator.configure_model(model)

        if self.mixed_precision and (
            self.plugin is None or (self.plugin and not self.plugin.control_precision())
        ):
            # transform model for mixed precision
            # when mixed_precision is specified and the plugin is not given or does not control the precision
            model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
        """Execution of backward during training step.

        The call is delegated to ``optimizer.backward(loss)``, so the optimizer must expose a
        ``backward`` method (presumably the wrapped optimizer returned by :meth:`boost`).

        Args:
            loss (torch.Tensor): The loss for backpropagation.
            optimizer (Optimizer): The optimizer to be updated.
        """
        # TODO(frank lee): implement this method with plugin
        optimizer.backward(loss)

    def execute_pipeline(
        self,
        data_iter: Iterator,
        model: nn.Module,
        criterion: Callable[[Any, Any], torch.Tensor],
        optimizer: Optional[Optimizer] = None,
        return_loss: bool = True,
        return_outputs: bool = False,
    ) -> Dict[str, Any]:
        """
        Execute forward & backward when utilizing pipeline parallel.
        Return loss or Huggingface style model outputs if needed.

        Warning: This function is tailored for the scenario of pipeline parallel.
        As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
        when doing pipeline parallel training with booster, which will cause unexpected errors.

        Args:
            data_iter (Iterator): The iterator for getting the next batch of data. Usually there are two ways to
                obtain this argument:
                1. wrap the dataloader to iterator through: iter(dataloader)
                2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
            model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that
                supports pipeline.
            criterion (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments:
                model outputs and inputs, and returns loss tensor.
                'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
            optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing
                forward (i.e. evaluation). Defaults to None.
            return_loss (bool, optional): Whether to return loss in the dict returned by this method.
                Defaults to True.
            return_outputs (bool, optional): Whether to return Huggingface style model outputs in the dict returned
                by this method. Defaults to False.

        Returns:
            Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
                ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
                ret_dict['outputs'] is the Huggingface style model outputs during forward if return_outputs
                is set to True, else None.

        Raises:
            AssertionError: If the configured plugin is not a ``PipelinePluginBase``.
        """
        assert isinstance(
            self.plugin, PipelinePluginBase
        ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
        return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)

    def no_sync(
        self, model: Optional[nn.Module] = None, optimizer: Optional[OptimizerWrapper] = None
    ) -> contextmanager:
        """Context manager to disable gradient synchronization across DP process groups.
        Support torch DDP and Low Level ZeRO-1 for now.

        Args:
            model (nn.Module): The model to be disabled gradient synchronization, for DDP
            optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO-1

        Returns:
            contextmanager: Context to disable gradient synchronization.

        Raises:
            AssertionError: If no plugin is configured or the plugin does not support ``no_sync``.
        """
        assert (
            self.plugin is not None
        ), "no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
        assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
        return self.plugin.no_sync(model, optimizer)

    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
        """Load model from checkpoint.

        Args:
            model (nn.Module or ModelWrapper): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
        """
        self.checkpoint_io.load_model(model, checkpoint, strict)

    def save_model(
        self,
        model: Union[nn.Module, ModelWrapper],
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ) -> None:
        """Save model to checkpoint.

        Args:
            model (nn.Module or ModelWrapper): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It is a file path if ``shard=False``. Otherwise, it is a directory path.
            shard (bool, optional): Whether to save checkpoint a sharded way.
                If true, the checkpoint will be a folder with the same format as Huggingface transformers
                checkpoint. Otherwise, it will be a single file. Defaults to False.
            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device.
                Default: True.
            prefix (str, optional): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Defaults to None.
            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB.
                This is useful only when ``shard=True``. Defaults to 1024.
            use_safetensors (bool, optional): Whether to save the checkpoint in the safetensors format.
                Default: False.
        """
        self.checkpoint_io.save_model(
            model,
            checkpoint=checkpoint,
            shard=shard,
            gather_dtensor=gather_dtensor,
            prefix=prefix,
            size_per_shard=size_per_shard,
            use_safetensors=use_safetensors,
        )

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
        """Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): An optimizer boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
        """
        self.checkpoint_io.load_optimizer(optimizer, checkpoint)

    def save_optimizer(
        self,
        optimizer: Optimizer,
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
    ) -> None:
        """
        Save optimizer to checkpoint.

        Args:
            optimizer (Optimizer): An optimizer boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It is a file path if ``shard=False``. Otherwise, it is a directory path.
            shard (bool, optional): Whether to save checkpoint a sharded way.
                If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str, optional): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Defaults to None.
            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB.
                This is useful only when ``shard=True``. Defaults to 1024.
        """
        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
        """Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
        """Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)