diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 2099883fb..e35b29c2a 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -11,7 +11,7 @@ from typing import Callable, Dict, List, Optional, Union import torch import torch.distributed as dist -from torch._six import inf +from torch import inf from torch.nn.parameter import Parameter from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py index e67434401..68928b232 100644 --- a/colossalai/zero/sharded_optim/_utils.py +++ b/colossalai/zero/sharded_optim/_utils.py @@ -3,7 +3,7 @@ from typing import Optional import torch import torch.distributed as dist -from torch._six import inf +from torch.six import inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from colossalai.tensor import ColoParameter