|
|
@ -1,17 +1,18 @@
|
|
|
|
#!/usr/bin/env python
|
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Union
|
|
|
|
from typing import Any, Iterable, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch import Tensor
|
|
|
|
from torch import Tensor
|
|
|
|
from typing import Iterable, Any, Tuple
|
|
|
|
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
|
|
|
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.utils import conditional_context
|
|
|
|
|
|
|
|
from colossalai.engine import BaseGradientHandler
|
|
|
|
from colossalai.engine import BaseGradientHandler
|
|
|
|
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
|
|
|
|
|
|
from colossalai.utils import conditional_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradAccumOptimizer(ColossalaiOptimizer):
|
|
|
|
class GradAccumOptimizer(ColossalaiOptimizer):
|
|
|
|