from typing import Iterable, List

import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.engine import BaseGradientHandler

from ._gradient_accumulation import (GradAccumDataloader, GradAccumGradientHandler, GradAccumLrSchedulerByStep,
                                     GradAccumOptimizer)


def accumulate_gradient(model: nn.Module,
                        optimizer: Optimizer,
                        dataloader: Iterable,
                        accumulate_size: int,
                        gradient_handlers: List[BaseGradientHandler] = None,
                        lr_scheduler: _LRScheduler = None):
    """Wrap the optimizer, dataloader, gradient handlers and lr scheduler so that
    gradients are accumulated over ``accumulate_size`` steps before each parameter update.

    :param model: your model object
    :type model: :class:`torch.nn.Module`
    :param optimizer: your optimizer object
    :type optimizer: :class:`torch.optim.Optimizer`
    :param dataloader: your dataloader object
    :type dataloader: Iterable
    :param accumulate_size: the number of steps over which gradients are accumulated
    :type accumulate_size: int
    :param gradient_handlers: list of gradient handler objects. Default is None.
    :type gradient_handlers: List[:class:`colossalai.engine.BaseGradientHandler`]
    :param lr_scheduler: your lr scheduler object. Default is None.
    :type lr_scheduler: :class:`torch.optim.lr_scheduler._LRScheduler`
    :return: the wrapped optimizer, dataloader, gradient handlers and lr scheduler;
        components passed as ``None`` are returned as ``None``
    """
    # Wrap the optimizer and dataloader so that gradients are accumulated for
    # `accumulate_size` steps before each parameter update.
    optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
    dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)

    # The gradient handlers and lr scheduler are optional and are only wrapped when provided.
    if gradient_handlers is not None:
        gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers]

    if lr_scheduler is not None:
        lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)

    return optimizer, dataloader, gradient_handlers, lr_scheduler


__all__ = ['accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer',
           'GradAccumLrSchedulerByStep', 'GradAccumGradientHandler']
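
# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not executed as part of this module).
# It assumes plain PyTorch components; the model, dataset and
# `accumulate_size` below are placeholder choices.
#
#   import torch
#   from torch.utils.data import DataLoader, TensorDataset
#
#   model = torch.nn.Linear(16, 2)
#   base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
#   base_dataloader = DataLoader(dataset, batch_size=8)
#
#   optimizer, dataloader, _, _ = accumulate_gradient(model=model,
#                                                     optimizer=base_optimizer,
#                                                     dataloader=base_dataloader,
#                                                     accumulate_size=4)
#
#   criterion = torch.nn.CrossEntropyLoss()
#   for data, label in dataloader:
#       loss = criterion(model(data), label)
#       loss.backward()
#       # the wrapped optimizer is expected to apply the actual update only
#       # once every `accumulate_size` calls to step()
#       optimizer.step()
#       optimizer.zero_grad()
# ----------------------------------------------------------------------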