mirror of https://github.com/hpcaitech/ColossalAI
[zero] use colo model data api in optimv2 (#511)
parent 9330be0f3c
commit bca0c49a9d
@@ -15,8 +15,8 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-
-from ._utils import has_inf_or_nan
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
 
 
 class OptimState(Enum):
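For context, a minimal sketch of what the two imports provide, assuming the usual semantics rather than quoting the library source. The `_sketch` names are hypothetical stand-ins; the real `colo_model_data_tensor_move` also updates ColossalAI's model-data memory accounting, which is omitted here.

import torch

def has_inf_or_nan_sketch(t: torch.Tensor) -> bool:
    # Overflow probe used by the sharded optimizer's dynamic grad scaler:
    # any inf/nan in the fp16 grads means the step is skipped and the
    # loss scale is reduced.
    return bool(torch.isinf(t).any() or torch.isnan(t).any())

def colo_model_data_tensor_move_sketch(src: torch.Tensor, tgt: torch.Tensor) -> None:
    # Stand-in for colo_model_data_tensor_move: copy src's payload into
    # tgt (copy_ performs the dtype cast and any host-to-device transfer),
    # then drop the source storage so the old payload does not linger.
    tgt.copy_(src.data)
    src.data = torch.empty(0, dtype=src.dtype, device=src.device)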
@@ -161,7 +161,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
 
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.copy_payload(p.data)
+                colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
 
                 if not is_param_sharded:
                     # We gather full fp16 param here
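Below is a hedged before/after of the changed line, with plain tensors standing in for p.data (the fp32 master copy, on CPU per the TODO) and the fp16 payload of p.col_attr.sharded_data_tensor; the shapes and the device fallback are illustrative assumptions.

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
master_fp32 = torch.randn(4, 4, dtype=torch.float32)                 # stands in for p.data
shard_fp16 = torch.empty(4, 4, dtype=torch.float16, device=device)   # stands in for the fp16 shard payload

# Before: copy_payload wrote the master weights straight into the shard;
# copy_ here performs the fp32 -> fp16 cast plus the CPU -> GPU transfer
# that the TODO flags as the expensive step.
shard_fp16.copy_(master_fp32)

# After (this commit): the same update is routed through the colo model
# data API, so the framework's model-data tracker observes the movement:
#   colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)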