mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.9 KiB
44 lines
1.9 KiB
import torch.nn as nn
|
|
from typing import List
|
|
from colossalai.engine import BaseGradientHandler
|
|
from typing import Iterable
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
|
|
|
|
|
|
def accumulate_gradient(model: nn.Module,
|
|
optimizer: Optimizer,
|
|
dataloader: Iterable,
|
|
accumulate_size: int,
|
|
gradient_handlers: List[BaseGradientHandler] = None,
|
|
lr_scheduler: _LRScheduler = None):
|
|
"""
|
|
:param model: your model object
|
|
:type model: :class:`torch.nn.Module`
|
|
:param optimizer: your optimizer object
|
|
:type optimizer: :class:`torch.optim.Optimizer`
|
|
:param dataloader: your dataloader object
|
|
:type dataloader: Iterable
|
|
:param accumulate_size: the number of steps to accumulate gradients
|
|
:type accumulate_size: int
|
|
:param gradient_handlers: list of gradient handler objects. Default is None
|
|
:type gradient_handlers: List[:class:`colossalai.engine.BaseGradientHandler`]
|
|
:param lr_scheduler: your lr scheduler object. Default is None
|
|
:type lr_scheduler: `torch.optim.lr_scheduler._LRScheduler`
|
|
"""
|
|
optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
|
|
dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)
|
|
|
|
if gradient_handlers is not None:
|
|
gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers]
|
|
|
|
if lr_scheduler is not None:
|
|
lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
|
|
|
|
return optimizer, dataloader, gradient_handlers, lr_scheduler
|
|
|
|
|
|
__all__ = ['accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer',
|
|
'GradAccumLrSchedulerByStep', 'GradAccumGradientHandler']
|