mirror of https://github.com/InternLM/InternLM
fix above codes:
*treat optim.zero_world_size and optim.zero_local_rank as list in model_checkpoint.py and test_model_checkpoint.py *add overlap and zero check for moe in args_sanity_check(.)pull/404/head
parent
c018e9216f
commit
e8fcbb1ad5
|
@ -319,6 +319,13 @@ def args_sanity_check():
|
|||
if "moe_loss_coeff" not in gpc.config.loss:
|
||||
gpc.config.loss._add_item("moe_loss_coeff", 1.0)
|
||||
|
||||
# moe not support overlap and zero1.5 for now
|
||||
if hasattr(gpc.config.model, "num_experts"):
|
||||
assert (
|
||||
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
|
||||
), "not support overlap and moe at the same time"
|
||||
assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
|
||||
|
||||
|
||||
def launch(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
|
|
|
@ -288,7 +288,6 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
for j in range(len(param.size())):
|
||||
global_id = "_".join([global_id, str(param.size()[j])])
|
||||
if self._overlap_sync_param:
|
||||
assert not hasattr(gpc.config.model, "num_experts")
|
||||
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
|
||||
else:
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
|
|
|
@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
|
|||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
zero_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
||||
|
||||
states = optim.state_dict()
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
|
||||
if gpc.get_global_rank() < zero_size * tp_size * pp_size:
|
||||
llm_save(os.path.join(state_path, fp), states)
|
||||
if "zero_devide_optim_plan" in states:
|
||||
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
||||
|
|
|
@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
|
|||
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
|
||||
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
|
||||
re &= group_id_1 == group_id_2
|
||||
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||
if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
|
||||
else:
|
||||
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
|
||||
|
|
Loading…
Reference in New Issue