fix above code:

* treat optim.zero_world_size and optim.zero_local_rank as lists in model_checkpoint.py and test_model_checkpoint.py
* add overlap and zero check for MoE in args_sanity_check()
pull/404/head
Qu Wenwen 2023-10-09 16:05:43 +08:00
parent c018e9216f
commit e8fcbb1ad5
4 changed files with 10 additions and 3 deletions


@@ -319,6 +319,13 @@ def args_sanity_check():
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
+    # moe not support overlap and zero1.5 for now
+    if hasattr(gpc.config.model, "num_experts"):
+        assert (
+            not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
+        ), "not support overlap and moe at the same time"
+        assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
+
 
 def launch(
     config: Union[str, Path, Config, Dict],

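For context, a hypothetical config fragment that would pass the new MoE checks. It is not part of this commit: the field names assume the usual InternLM-style Python training config, and optim_ckpt above is assumed to refer to gpc.config.hybrid_zero_optimizer.

# Hypothetical config fragment: MoE model with the settings the new asserts require.
model = dict(
    num_experts=4,               # presence of this key triggers the MoE branch above
)
parallel = dict(
    zero1=-1,                    # "moe only support zero1, set zero1=-1 can fix this"
)
hybrid_zero_optimizer = dict(    # assumed to be what optim_ckpt points at
    overlap_sync_grad=False,     # with both overlap flags off, the first assert holds
    overlap_sync_param=False,
)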

@@ -288,7 +288,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
             if self._overlap_sync_param:
-                assert not hasattr(gpc.config.model, "num_experts")
                 rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
             else:
                 rank_to_go = numel_per_rank.index(min(numel_per_rank))


@@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
     zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    zero_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
     fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
 
     states = optim.state_dict()
     if isinstance(optim, HybridZeroOptimizer):
-        if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
+        if gpc.get_global_rank() < zero_size * tp_size * pp_size:
             llm_save(os.path.join(state_path, fp), states)
             if "zero_devide_optim_plan" in states:
                 params_per_rank_id_dict = states.pop("zero_devide_optim_plan")

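To make the gating above concrete, a small standalone sketch with made-up world sizes (illustrative only, not code from this commit): since optim.zero_world_size is now a per-param-group list, the scalar ZeRO-1 world size is read from gpc instead, and only the first zero_size * tp_size * pp_size global ranks write a shard.

# Illustrative only: which global ranks write an optimizer shard.
zero_size, tp_size, pp_size = 4, 2, 2        # hypothetical ZeRO-1 / tensor / pipeline sizes
for global_rank in range(20):
    if global_rank < zero_size * tp_size * pp_size:
        pass  # ranks 0..15 each save their optimizer_tp{tp}_pp{pp}_zo{zero}.pt shard
    else:
        pass  # ranks 16..19 skip the write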

@@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
         fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
         for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
             re &= group_id_1 == group_id_2
-            if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
+            if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
                 re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
     else:
         for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
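
A minimal toy sketch of the assumption behind this fix (hypothetical values, not the InternLM API): with MoE, each parameter group may live in its own ZeRO process group, so zero_local_rank and param_group_no_params_ranks are consulted per group id, as the updated condition now does.

# Toy values only: per-param-group ZeRO bookkeeping.
zero_local_rank = [1, 0]                            # hypothetical rank-in-group for each param group
param_group_no_params_ranks = {0: [3], 1: [0, 2]}   # hypothetical ranks that own no params, per group
for group_id in range(len(zero_local_rank)):
    if zero_local_rank[group_id] not in param_group_no_params_ranks[group_id]:
        pass  # group 0: rank 1 not in [3] -> this rank holds a flat fp32 shard to compare
    else:
        pass  # group 1: rank 0 in [0, 2] -> nothing to compare for this group on this rank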