|
|
@ -1,17 +1,18 @@ |
|
|
|
#!/usr/bin/env python |
|
|
|
#!/usr/bin/env python |
|
|
|
# -*- encoding: utf-8 -*- |
|
|
|
# -*- encoding: utf-8 -*- |
|
|
|
|
|
|
|
|
|
|
|
from typing import Union |
|
|
|
from typing import Any, Iterable, Tuple, Union |
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn as nn |
|
|
|
from torch import Tensor |
|
|
|
from torch import Tensor |
|
|
|
from typing import Iterable, Any, Tuple |
|
|
|
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer |
|
|
|
|
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel |
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel |
|
|
|
from torch.optim import Optimizer |
|
|
|
from torch.optim import Optimizer |
|
|
|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from colossalai.utils import conditional_context |
|
|
|
|
|
|
|
from colossalai.engine import BaseGradientHandler |
|
|
|
from colossalai.engine import BaseGradientHandler |
|
|
|
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer |
|
|
|
|
|
|
|
from colossalai.utils import conditional_context |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradAccumOptimizer(ColossalaiOptimizer): |
|
|
|
class GradAccumOptimizer(ColossalaiOptimizer): |
|
|
|