mirror of https://github.com/InternLM/InternLM
refactor code for sync_model_param()
parent 00478761f7
commit 7e505f3c59
@@ -12,12 +12,6 @@ def is_model_parallel_parameter(p):
    return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)


def sync_tensor(tensor, parallel_mode):
    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
        ranks = gpc.get_ranks_in_group(parallel_mode)
        dist.broadcast(tensor, src=ranks[0], group=gpc.get_group(parallel_mode))


def sync_model_param(model):
    r"""Make sure data parameters are consistent during Data Parallel Mode.

@@ -30,9 +24,11 @@ def sync_model_param(model):
    )
    for param in model.parameters():
        if sync_moe_param and is_moe_param(param):
            sync_tensor(param, ParallelMode.EXPERT_DATA)
            ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
        else:
            sync_tensor(param, ParallelMode.DATA)
            ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))


def sync_model_param_within_tp(model):
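For readers outside the InternLM codebase: the pattern in this diff broadcasts every parameter from the first rank of the relevant process group, so all data-parallel replicas start training from identical weights. Below is a minimal sketch of that pattern using plain torch.distributed instead of InternLM's gpc context; the helper name broadcast_model_params is hypothetical and not part of the repository.

# Minimal sketch of the broadcast-from-first-rank pattern, assuming a plain
# torch.distributed setup rather than InternLM's gpc context. The helper name
# `broadcast_model_params` is hypothetical.
import torch.distributed as dist
from torch import nn


def broadcast_model_params(model: nn.Module, group=None) -> None:
    """Broadcast every parameter from the group's first rank so that all
    ranks in the (data-parallel) group start from identical weights."""
    if not dist.is_initialized() or dist.get_world_size(group=group) <= 1:
        return
    # Global rank of the group's first member, mirroring `ranks[0]` in the diff.
    src = dist.get_global_rank(group, 0) if group is not None else 0
    for param in model.parameters():
        dist.broadcast(param.data, src=src, group=group)

On each rank, such a helper would be called once after init_process_group() and before building the optimizer. The diff additionally routes MoE expert parameters through the EXPERT_DATA group rather than DATA, so each expert is synchronized only among the ranks that hold a copy of it.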