refactor code for sync_model_param()

pull/375/head
Wenwen Qu 2023-09-27 18:06:42 +08:00
parent 00478761f7
commit 7e505f3c59
1 changed file with 4 additions and 8 deletions


@@ -12,12 +12,6 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
-def sync_tensor(tensor, parallel_mode):
-    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
-        ranks = gpc.get_ranks_in_group(parallel_mode)
-        dist.broadcast(tensor, src=ranks[0], group=gpc.get_group(parallel_mode))
 def sync_model_param(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
@@ -30,9 +24,11 @@ def sync_model_param(model):
     )
     for param in model.parameters():
         if sync_moe_param and is_moe_param(param):
-            sync_tensor(param, ParallelMode.EXPERT_DATA)
+            ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
+            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
         else:
-            sync_tensor(param, ParallelMode.DATA)
+            ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
+            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
 def sync_model_param_within_tp(model):
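
For context, the pattern this commit inlines is a rank-0 broadcast over each parameter, so every replica in a process group starts from identical weights. Below is a minimal sketch using only torch.distributed; the broadcast_model_params name and the group/src arguments are assumptions for illustration, standing in for InternLM's gpc.get_group(...) and gpc.get_ranks_in_group(...) lookups.

import torch.distributed as dist


def broadcast_model_params(model, group=None, src=0):
    # Broadcast every parameter from rank `src` of `group` so that all ranks
    # in the group hold identical weights; this mirrors the per-mode
    # dist.broadcast calls added in the diff above.
    for param in model.parameters():
        dist.broadcast(param.data, src=src, group=group)

In the commit itself the source rank is the first rank of the EXPERT_DATA or DATA group, chosen per parameter depending on whether is_moe_param(param) holds.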