@@ -3,16 +3,18 @@
 
 # modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
 # to support tensor parallel
-import torch
-from collections import defaultdict, abc
 import warnings
+from collections import abc, defaultdict
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
-from colossalai.context import ParallelMode
+
+import torch
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from packaging import version
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
 
 
 class _MultiDeviceReplicator(object):
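
For context on the "# to support tensor parallel" comment in the hunk above, here is a minimal sketch, assuming ColossalAI's global context API (gpc.is_initialized, gpc.get_world_size, gpc.get_group), of the idea this file layers on top of PyTorch's GradScaler: the found-inf flag produced during unscaling is all-reduced across the tensor-parallel process group so every rank agrees on whether to skip the optimizer step. The helper name _sync_found_inf is hypothetical, not part of this diff.

import torch
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def _sync_found_inf(found_inf: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: found_inf is 1.0 if any local gradient contained
    # inf/NaN after unscaling, 0.0 otherwise (as in torch.cuda.amp.GradScaler).
    if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
        # Take the max across the tensor-parallel group so that one rank
        # finding inf/NaN makes every rank skip the step together.
        dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.TENSOR))
    return found_inf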