|
|
|
@ -1,21 +1,22 @@
|
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
from typing import Union
|
|
|
|
|
from typing import Any, Iterable, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
from typing import Iterable, Any, Tuple
|
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from colossalai.utils import conditional_context
|
|
|
|
|
|
|
|
|
|
from colossalai.engine import BaseGradientHandler
|
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
|
|
|
from colossalai.utils import conditional_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradAccumOptimizer(ColossalaiOptimizer):
|
|
|
|
|
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
|
|
|
|
|
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
|
|
|
|
|
before accumulation size is reached.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -161,7 +162,7 @@ class GradAccumDataloader:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradAccumLrSchedulerByStep(_LRScheduler):
|
|
|
|
|
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
|
|
|
|
|
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
|
|
|
|
|
before accumulation size is reached.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|