diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index d90e0e0..3a10227 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -12,12 +12,6 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
 
-def sync_tensor(tensor, parallel_mode):
-    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
-        ranks = gpc.get_ranks_in_group(parallel_mode)
-        dist.broadcast(tensor, src=ranks[0], group=gpc.get_group(parallel_mode))
-
-
 def sync_model_param(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
 
@@ -30,9 +24,11 @@ def sync_model_param(model):
     )
     for param in model.parameters():
         if sync_moe_param and is_moe_param(param):
-            sync_tensor(param, ParallelMode.EXPERT_DATA)
+            ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
+            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
         else:
-            sync_tensor(param, ParallelMode.DATA)
+            ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
+            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
 
 
 def sync_model_param_within_tp(model):
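
For reference, a minimal standalone sketch of the broadcast pattern this diff inlines, assuming a plain torch.distributed setup: the gpc process-group bookkeeping from the real code is replaced with the default group, and sync_params_from_rank0 is a hypothetical helper used only for illustration, not part of the repo.

import torch
import torch.distributed as dist


def sync_params_from_rank0(model: torch.nn.Module) -> None:
    # Hypothetical helper: broadcast every parameter from rank 0 so all
    # replicas hold identical weights, mirroring what sync_model_param
    # does per parallel group in the real code.
    # Guarded like the removed sync_tensor helper: skip when no process
    # group is initialized or when there is only a single process.
    if dist.is_initialized() and dist.get_world_size() > 1:
        for param in model.parameters():
            dist.broadcast(param.data, src=0)

Such a helper would typically be called once after model construction and before optimizer state is created, so every data-parallel replica starts from the same initialization.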