diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
index 89c28c3be..cf66be1cd 100644
--- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
@@ -1,21 +1,22 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Union
+from typing import Any, Iterable, Tuple, Union
+
 import torch.nn as nn
 from torch import Tensor
-from typing import Iterable, Any, Tuple
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from colossalai.utils import conditional_context
+
 from colossalai.engine import BaseGradientHandler
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils import conditional_context
 
 
 class GradAccumOptimizer(ColossalaiOptimizer):
-    """A wrapper for the optimizer to enable gradient accumulation by skipping the steps 
+    """A wrapper for the optimizer to enable gradient accumulation by skipping the steps
     before accumulation size is reached.
 
     Args:
@@ -161,7 +162,7 @@ class GradAccumDataloader:
 
 
 class GradAccumLrSchedulerByStep(_LRScheduler):
-    """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps 
+    """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
     before accumulation size is reached.
 
     Args:
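The docstrings touched above describe the same idea for both wrappers: the caller invokes them every iteration, but the real optimizer step only fires once the accumulation size is reached, so gradients from the skipped iterations pile up. Below is a minimal, self-contained sketch of that skip-until-accumulation-size pattern; the class and parameter names (AccumulationOptimizerSketch, accumulate_size) are invented for illustration and this is not ColossalAI's actual GradAccumOptimizer implementation.

# Sketch only: illustrates gradient accumulation by skipping optimizer steps.
# Not the ColossalAI implementation; names and behaviour here are assumptions.
import torch
import torch.nn as nn
from torch.optim import Optimizer


class AccumulationOptimizerSketch:
    """Call ``step()`` every iteration; the wrapped optimizer only steps (and
    gradients are only cleared) after ``accumulate_size`` calls, so gradients
    from the skipped iterations accumulate in ``param.grad``."""

    def __init__(self, optim: Optimizer, accumulate_size: int):
        self.optim = optim
        self.accumulate_size = accumulate_size
        self.accumulate_step = 0

    def backward(self, loss: torch.Tensor) -> None:
        # Scale the loss so the accumulated gradient approximates the gradient
        # of one large batch rather than the sum of several small batches.
        (loss / self.accumulate_size).backward()

    def step(self) -> None:
        self.accumulate_step += 1
        if self.accumulate_step % self.accumulate_size == 0:
            # Accumulation size reached: perform the real optimizer step.
            self.optim.step()

    def zero_grad(self) -> None:
        # Only clear gradients on the iteration where a real step happened;
        # otherwise keep them so they accumulate across skipped iterations.
        if self.accumulate_step % self.accumulate_size == 0:
            self.accumulate_step = 0
            self.optim.zero_grad()


if __name__ == "__main__":
    model = nn.Linear(4, 2)
    wrapper = AccumulationOptimizerSketch(
        torch.optim.SGD(model.parameters(), lr=0.1), accumulate_size=4)
    for it in range(8):
        out = model(torch.randn(8, 4))
        loss = out.pow(2).mean()
        wrapper.backward(loss)
        wrapper.step()       # real optimizer step only on every 4th call
        wrapper.zero_grad()  # gradients kept across the skipped iterations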