diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py
index de39b3e16..7b78998fb 100644
--- a/colossalai/amp/torch_amp/_grad_scaler.py
+++ b/colossalai/amp/torch_amp/_grad_scaler.py
@@ -3,16 +3,18 @@
 # modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
 # to support tensor parallel
-import torch
-from collections import defaultdict, abc
 import warnings
+from collections import abc, defaultdict
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
-from colossalai.context import ParallelMode
+
+import torch
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from packaging import version
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
 
 
 class _MultiDeviceReplicator(object):