|
|
|
@ -1,21 +1,22 @@
|
|
|
|
|
#!/usr/bin/env python |
|
|
|
|
# -*- encoding: utf-8 -*- |
|
|
|
|
|
|
|
|
|
from typing import Union |
|
|
|
|
from typing import Any, Iterable, Tuple, Union |
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
from torch import Tensor |
|
|
|
|
from typing import Iterable, Any, Tuple |
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer |
|
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel |
|
|
|
|
from torch.optim import Optimizer |
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
from colossalai.utils import conditional_context |
|
|
|
|
|
|
|
|
|
from colossalai.engine import BaseGradientHandler |
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer |
|
|
|
|
from colossalai.utils import conditional_context |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradAccumOptimizer(ColossalaiOptimizer): |
|
|
|
|
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps |
|
|
|
|
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps |
|
|
|
|
before accumulation size is reached. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -161,7 +162,7 @@ class GradAccumDataloader:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradAccumLrSchedulerByStep(_LRScheduler): |
|
|
|
|
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps |
|
|
|
|
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps |
|
|
|
|
before accumulation size is reached. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|