diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index a975ed8da..e69cf6654 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -15,8 +15,8 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-
-from ._utils import has_inf_or_nan
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
 
 
 class OptimState(Enum):
@@ -161,7 +161,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.copy_payload(p.data)
+                colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
 
                 if not is_param_sharded:
                     # We gather full fp16 param here
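
For context on what the changed line is intended to do (copy the fp32 master weight, which may live on CPU, into the fp16 sharded payload on GPU), here is a minimal, illustrative sketch of that kind of payload move. The ShardedTensorStub class and move_master_to_fp16_payload helper are hypothetical stand-ins for this note only; they are not ColossalAI's actual colo_model_data_tensor_move implementation or API.

    # Illustrative sketch: copy an fp32 master tensor into an fp16 payload,
    # converting dtype (and device, if the payload lives on GPU) in one step.
    import torch


    class ShardedTensorStub:
        """Hypothetical stand-in for a sharded data tensor holding a payload."""

        def __init__(self, payload: torch.Tensor):
            self.payload = payload


    def move_master_to_fp16_payload(master: torch.Tensor, sharded: ShardedTensorStub) -> None:
        # Cast the master weight to the payload's dtype/device
        # (e.g. CPU fp32 -> GPU fp16) and copy it in place.
        sharded.payload.copy_(
            master.to(dtype=sharded.payload.dtype, device=sharded.payload.device))


    if __name__ == "__main__":
        master = torch.randn(4, dtype=torch.float32)                       # fp32 master weight
        sharded = ShardedTensorStub(torch.zeros(4, dtype=torch.float16))   # fp16 payload
        move_master_to_fp16_payload(master, sharded)
        print(sharded.payload.dtype)  # torch.float16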