import torch.nn as nn
from typing import List
from colossalai.engine import BaseGradientHandler
from typing import Iterable
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler

__all__ = [
    'accumulate_gradient',
    'GradAccumDataloader',
    'GradAccumOptimizer',
    'GradAccumLrSchedulerByStep',
    'GradAccumGradientHandler',
]


def accumulate_gradient(model: nn.Module,
                        optimizer: Optimizer,
                        dataloader: Iterable,
                        accumulate_size: int,
                        gradient_handlers: List[BaseGradientHandler] = None,
                        lr_scheduler: _LRScheduler = None):
    r"""Wrap training components in their gradient-accumulation counterparts.

    The optimizer and dataloader are always wrapped. The gradient handlers and
    the learning rate scheduler are wrapped only when they are provided; a
    ``None`` argument is handed back as ``None``.

    Args:
        model (:class:`torch.nn.Module`): the model being trained; it is forwarded
            to the :class:`GradAccumOptimizer` wrapper.
        optimizer (:class:`torch.optim.Optimizer`): the optimizer to wrap.
        dataloader (:class:`torch.utils.data.DataLoader` or iterable objects): the
            data source to wrap, would be called like iter(dataloader).
        accumulate_size (int): the number of steps to accumulate gradients; the
            same value is passed to every wrapper.
        gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
            gradient handler objects, each wrapped individually. Defaults to None.
        lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
            the learning rate scheduler to wrap. Defaults to None.

    Returns:
        tuple: ``(optimizer, dataloader, gradient_handlers, lr_scheduler)`` where
        ``optimizer`` is a :class:`GradAccumOptimizer`, ``dataloader`` is a
        :class:`GradAccumDataloader`, ``gradient_handlers`` is a new list of
        :class:`GradAccumGradientHandler` (or None), and ``lr_scheduler`` is a
        :class:`GradAccumLrSchedulerByStep` (or None).

    More details about `gradient_handlers` could be found in
    `Gradient_handler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/engine/gradient_handler>`_.

    More details about `lr_scheduler` could be found
    `lr_scheduler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/lr_scheduler>`_. and
    `how to adjust learning rate <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
    """
    # Wrappers are built in a fixed order: optimizer, dataloader, handlers, scheduler.
    accum_optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
    accum_dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)

    # Handlers are optional: None is passed through, anything else is wrapped one by one.
    accum_handlers = None
    if gradient_handlers is not None:
        accum_handlers = []
        for handler in gradient_handlers:
            accum_handlers.append(GradAccumGradientHandler(handler, accumulate_size))

    # The scheduler is optional as well.
    accum_scheduler = (None if lr_scheduler is None else GradAccumLrSchedulerByStep(
        lr_scheduler, accumulate_size=accumulate_size))

    return accum_optimizer, accum_dataloader, accum_handlers, accum_scheduler