[zero] use colo model data api in optimv2 (#511)

pull/510/head
Jiarui Fang 2022-03-24 17:19:34 +08:00 committed by GitHub
parent 9330be0f3c
commit bca0c49a9d
1 changed file with 3 additions and 3 deletions


@@ -15,8 +15,8 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from ._utils import has_inf_or_nan
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
 class OptimState(Enum):
@@ -161,7 +161,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-p.col_attr.sharded_data_tensor.copy_payload(p.data)
+colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
 if not is_param_sharded:
 # We gather full fp16 param here
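
The substantive change routes the fp32-master-to-fp16-shard copy through the colo model data API (colo_model_data_tensor_move) instead of calling copy_payload on the sharded tensor directly. Below is a minimal, hypothetical sketch of what such a helper could do, assuming the point of the API is to keep a global model-data memory counter in sync with the copy; ModelDataTracer, ShardedTensorLike, and GLOBAL_TRACER are illustrative names only, not ColossalAI's actual implementation.

# Hypothetical sketch only: the names below do not mirror ColossalAI's real code.
import torch


class ModelDataTracer:
    """Toy byte counter standing in for a model-data memory tracker."""

    def __init__(self) -> None:
        self.usage_bytes = 0

    def add_tensor(self, t: torch.Tensor) -> None:
        self.usage_bytes += t.numel() * t.element_size()

    def delete_tensor(self, t: torch.Tensor) -> None:
        self.usage_bytes -= t.numel() * t.element_size()


GLOBAL_TRACER = ModelDataTracer()


class ShardedTensorLike:
    """Minimal stand-in for p.col_attr.sharded_data_tensor."""

    def __init__(self, payload: torch.Tensor) -> None:
        self.payload = payload

    def copy_payload(self, src: torch.Tensor) -> None:
        # copy_ casts fp32 -> fp16 implicitly when dtypes differ
        self.payload.copy_(src.view(self.payload.shape))


def colo_model_data_tensor_move(src: torch.Tensor, tgt: ShardedTensorLike) -> None:
    """Copy src into tgt's payload and keep the usage counter consistent."""
    GLOBAL_TRACER.delete_tensor(tgt.payload)   # stop counting the old payload
    tgt.copy_payload(src)                      # fp32 master -> fp16 shard
    GLOBAL_TRACER.add_tensor(tgt.payload)      # count the refreshed payload


if __name__ == "__main__":
    master = torch.randn(4, dtype=torch.float32)                # fp32 master weight
    shard = ShardedTensorLike(torch.zeros(4, dtype=torch.float16))
    GLOBAL_TRACER.add_tensor(shard.payload)                     # register the shard as model data
    colo_model_data_tensor_move(master, shard)
    print(shard.payload.dtype, GLOBAL_TRACER.usage_bytes)       # torch.float16 8

In the real code path the target payload may also live on a different device than the source (the TODO above mentions CPU fp32 -> GPU fp16), which is presumably where centralizing the copy and its memory accounting behind one API pays off.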