2023-03-21 09:39:30 +00:00
|
|
|
import warnings
|
2023-03-09 03:27:46 +00:00
|
|
|
from contextlib import contextmanager
|
2023-09-12 02:47:23 +00:00
|
|
|
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Tuple, Union
|
2023-03-09 03:27:46 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
2023-03-27 02:24:14 +00:00
|
|
|
from colossalai.checkpoint_io import GeneralCheckpointIO
|
2023-07-04 04:00:33 +00:00
|
|
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
2023-03-27 02:24:14 +00:00
|
|
|
|
2023-03-20 05:59:24 +00:00
|
|
|
from .accelerator import Accelerator
|
2023-03-17 03:00:15 +00:00
|
|
|
from .mixed_precision import MixedPrecision, mixed_precision_factory
|
2023-03-09 03:27:46 +00:00
|
|
|
from .plugin import Plugin
|
2023-07-25 16:53:57 +00:00
|
|
|
from .plugin.pp_plugin_base import PipelinePluginBase
|
2023-03-09 03:27:46 +00:00
|
|
|
|
|
|
|
# Public API of this module: only ``Booster`` is exported by ``from <module> import *``.
__all__ = ['Booster']
|
|
|
|
|
|
|
|
|
|
|
|
class Booster:
    """
    Booster is a high-level API for training neural networks. It provides a unified interface for
    training with different precision, accelerator, and plugin.

    ```python
    # Following is pseudocode

    colossalai.launch(...)
    plugin = GeminiPlugin(...)
    booster = Booster(mixed_precision='fp16', plugin=plugin)

    model = GPT2()
    optimizer = HybridAdam(model.parameters())
    dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    lr_scheduler = LinearWarmupScheduler()
    criterion = GPTLMLoss()

    model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

    for epoch in range(max_epochs):
        for input_ids, attention_mask in dataloader:
            outputs = model(input_ids.cuda(), attention_mask.cuda())
            loss = criterion(outputs.logits, input_ids)
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    ```

    Args:
        device (str or torch.device): The device to run the training. Default: None.
            If plugin is not used or plugin doesn't control the device,
            this argument will be set as training device ('cuda' will be used if argument is None).
            If the plugin controls the device, this argument is ignored and a warning is issued.
        mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
            If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
            'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
            If the plugin controls the precision, this argument is ignored and a warning is issued.
        plugin (Plugin): The plugin to run the training. Default: None.

    Raises:
        AssertionError: If ``plugin`` is given but is not an instance of ``Plugin``.
        ValueError: If the plugin does not control the precision and ``mixed_precision`` is neither
            ``None``, a string, nor a ``MixedPrecision`` instance.
    """

    def __init__(self,
                 device: Optional[str] = None,
                 mixed_precision: Optional[Union[MixedPrecision, str]] = None,
                 plugin: Optional[Plugin] = None) -> None:
        if plugin is not None:
            assert isinstance(
                plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
        self.plugin = plugin

        # set accelerator
        # Invariant relied upon by boost(): self.accelerator is None exactly when the plugin controls the device.
        if self.plugin and self.plugin.control_device():
            self.accelerator = None
            if device is not None:
                warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
        else:
            device = device or 'cuda'
            self.accelerator = Accelerator(device)

        # set precision
        if self.plugin and self.plugin.control_precision():
            if mixed_precision is not None:
                warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
            self.mixed_precision = None
        elif mixed_precision is None:
            self.mixed_precision = None
        else:
            # validate and set precision
            if isinstance(mixed_precision, str):
                # the user will take the default arguments for amp training
                self.mixed_precision = mixed_precision_factory(mixed_precision)
            elif isinstance(mixed_precision, MixedPrecision):
                # the user can customize the arguments by passing the precision object
                self.mixed_precision = mixed_precision
            else:
                raise ValueError(
                    f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
                )

        # checkpoint IO: delegated to the plugin when it controls checkpointing, otherwise the general implementation
        if self.plugin is not None and self.plugin.control_checkpoint_io():
            self.checkpoint_io = self.plugin.get_checkpoint_io()
        else:
            self.checkpoint_io = GeneralCheckpointIO()

    def boost(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, Optional[Optimizer], Optional[Callable], Optional[DataLoader], Optional[LRScheduler]]:
        """
        Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.

        The objects are transformed in this order:
        1. the plugin (if any) configures all five objects;
        2. the accelerator (created in ``__init__`` when no plugin controls the device) configures the model;
        3. the mixed precision (if set and not controlled by the plugin) configures the model, optimizer and criterion.

        Args:
            model (nn.Module): Convert model into a wrapped model for distributive training.
                The model might be decorated or partitioned by plugin's strategy after execution of this method.
            optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
                The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
            criterion (Callable, optional): The function that calculates loss. Defaults to None.
            dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
            lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.

        Returns:
            Tuple: The boosted ``(model, optimizer, criterion, dataloader, lr_scheduler)``, in the same order
            as the arguments.
        """
        # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
        # TODO(FrankLeeeee): consider multi-dataloader case
        # let the plugin wrap / partition the objects first
        if self.plugin:
            model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                model, optimizer, criterion, dataloader, lr_scheduler)

        # transform model for accelerator
        # The accelerator exists both when no plugin is given and when the plugin does not control the device,
        # so checking it directly also covers the no-plugin case (which the plugin-based check used to skip).
        if self.accelerator is not None:
            model = self.accelerator.configure(model)

        if self.mixed_precision is not None and (self.plugin is None or not self.plugin.control_precision()):
            # transform model for mixed precision
            # when mixed_precision is specified and the plugin is not given or does not control the precision
            model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
        """Execution of backward during training step.

        Args:
            loss (torch.Tensor): The loss for backpropagation.
            optimizer (Optimizer): The optimizer to be updated. It must expose a ``backward(loss)`` method;
                presumably this is the wrapped optimizer returned by ``boost`` (a plain torch optimizer has none).
        """
        # TODO(frank lee): implement this method with plugin
        optimizer.backward(loss)

    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: nn.Module,
                         criterion: Callable[[Any, Any], torch.Tensor],
                         optimizer: Optional[Optimizer] = None,
                         return_loss: bool = True,
                         return_outputs: bool = False) -> Dict[str, Any]:
        """
        Execute forward & backward when utilizing pipeline parallel.
        Return loss or Huggingface style model outputs if needed.

        Warning: This function is tailored for the scenario of pipeline parallel.
        As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
        when doing pipeline parallel training with booster, which will cause unexpected errors.

        Args:
            data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
                1. wrap the dataloader to iterator through: iter(dataloader)
                2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
            model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
            criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
                'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
            optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
            return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
            return_outputs (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.

        Returns:
            Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
                ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
                ret_dict['outputs'] is the Huggingface style model outputs during forward if return_outputs is set to True, else None.

        Raises:
            AssertionError: If the booster's plugin is not a ``PipelinePluginBase`` (including when no plugin is set).
        """
        assert isinstance(self.plugin,
                          PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
        return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)

    def no_sync(self, model: Optional[nn.Module] = None, optimizer: Optional[OptimizerWrapper] = None) -> ContextManager:
        """Context manager to disable gradient synchronization across DP process groups.
        Support torch DDP and Low Level ZeRO-1 for now.

        Args:
            model (nn.Module, optional): The model to be disabled gradient synchronization, for DDP.
            optimizer (OptimizerWrapper, optional): The optimizer to be disabled gradient synchronization, for ZeRO-1.

        Returns:
            ContextManager: Context to disable gradient synchronization.

        Raises:
            AssertionError: If no plugin is set or the plugin does not support no_sync.
        """
        assert self.plugin is not None, 'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
        return self.plugin.no_sync(model, optimizer)

    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
        """Load model from checkpoint.

        Args:
            model (nn.Module or ModelWrapper): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
        """
        self.checkpoint_io.load_model(model, checkpoint, strict)

    def save_model(self,
                   model: Union[nn.Module, ModelWrapper],
                   checkpoint: str,
                   shard: bool = False,
                   gather_dtensor: bool = True,
                   prefix: Optional[str] = None,
                   size_per_shard: int = 1024,
                   use_safetensors: bool = False) -> None:
        """Save model to checkpoint.

        Args:
            model (nn.Module or ModelWrapper): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It is a file path if ``shard=False``. Otherwise, it is a directory path.
            shard (bool, optional): Whether to save checkpoint in a sharded way.
                If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str, optional): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Defaults to None.
            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
            use_safetensors (bool, optional): whether to save the checkpoint with safetensors. Default: False.
        """
        self.checkpoint_io.save_model(model,
                                      checkpoint=checkpoint,
                                      shard=shard,
                                      gather_dtensor=gather_dtensor,
                                      prefix=prefix,
                                      size_per_shard=size_per_shard,
                                      use_safetensors=use_safetensors)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
        """Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): An optimizer boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
        """
        self.checkpoint_io.load_optimizer(optimizer, checkpoint)

    def save_optimizer(self,
                       optimizer: Optimizer,
                       checkpoint: str,
                       shard: bool = False,
                       gather_dtensor: bool = True,
                       prefix: Optional[str] = None,
                       size_per_shard: int = 1024) -> None:
        """
        Save optimizer to checkpoint.

        Args:
            optimizer (Optimizer): An optimizer boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It is a file path if ``shard=False``. Otherwise, it is a directory path.
            shard (bool, optional): Whether to save checkpoint in a sharded way.
                If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
            prefix (str, optional): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Defaults to None.
            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
        """
        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
        """Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
        """Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
|