From ccad7014c64dcae70e55a88b6bbaa23fe255bee7 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Fri, 29 Dec 2023 16:51:29 +0800
Subject: [PATCH] update optim

---
 applications/ColossalMoE/train.py            | 5 +++--
 colossalai/zero/low_level/low_level_optim.py | 7 ++++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index 79052f910..b2bdd4252 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -209,6 +209,7 @@ def main():
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
     coordinator.print_on_master(f"Finish init booster")
 
     # Load ckpt
@@ -224,7 +225,7 @@ def main():
         with tqdm(
             range(total_len),
             desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
-            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
+            disable=not coordinator.is_master() if use_pipeline == False else not pp_print_rank,
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
@@ -238,7 +239,7 @@ def main():
                         return_outputs=True,
                     )
                     # Backward and optimize
-                    if is_pp_last_stage:
+                    if pp_print_rank:
                         loss = outputs["loss"]
                         pbar.set_postfix({"loss": loss.item()})
                 else:
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 4f841effd..2c17b7da7 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -173,13 +173,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         # if there are moe params, store in addtional group in optim
         if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
-                self.master_moe_params.append(param.to(torch.float32))
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)
 
@@ -602,6 +603,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update param for moe ep
         # move grad to master param and compute norm
         if len(self.working_moe_params) > 0:
+            if self._sync_master_param == False:
+                for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                    master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+                self._sync_master_param = True
             moe_grads = []
             for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
                 if master_moe_param.grad is not None:
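
The optimizer part of this patch follows a common mixed-precision pattern: the MoE parameters get detached fp32 master copies (clone + to(float32) + detach, so the master tensors neither alias the working storage nor stay attached to the autograd graph), and those master copies are re-synced from the working parameters exactly once before the first step, which picks up any in-place changes made after the optimizer was built (for example a checkpoint load). Below is a minimal standalone sketch of that pattern, not ColossalAI code; the names SimpleMasterParamOptimizer, working_params and _synced are illustrative only, and the sketch assumes nothing beyond PyTorch.

    # Minimal sketch of the lazy master-parameter sync pattern in the patch above.
    # All names here are hypothetical; only the technique mirrors the diff.
    import torch


    class SimpleMasterParamOptimizer:
        """Keep detached fp32 master copies of low-precision working params."""

        def __init__(self, working_params, lr=1e-3):
            self.working_params = list(working_params)
            # Detached fp32 clones: independent storage, not tracked by autograd.
            self.master_params = [p.clone().to(torch.float32).detach() for p in self.working_params]
            self.inner = torch.optim.SGD(self.master_params, lr=lr)
            self._synced = False  # plays the role of _sync_master_param in the patch

        def step(self):
            # One-time sync: pick up in-place changes to the working params
            # (e.g. a checkpoint load) made after this wrapper was constructed.
            if not self._synced:
                for master, working in zip(self.master_params, self.working_params):
                    master.data = working.data.clone().to(torch.float32).detach()
                self._synced = True

            # Move the low-precision grads onto the fp32 master params, then step.
            for master, working in zip(self.master_params, self.working_params):
                if working.grad is not None:
                    master.grad = working.grad.to(torch.float32)
            self.inner.step()

            # Copy the updated fp32 values back into the working params.
            with torch.no_grad():
                for master, working in zip(self.master_params, self.working_params):
                    working.copy_(master.to(working.dtype))


    # Tiny usage example on a bf16 layer.
    if __name__ == "__main__":
        layer = torch.nn.Linear(4, 2).to(torch.bfloat16)
        optim = SimpleMasterParamOptimizer(layer.parameters())
        loss = layer(torch.randn(3, 4, dtype=torch.bfloat16)).float().sum()
        loss.backward()
        optim.step()

The lazy one-shot sync is what lets the diff create the master copies eagerly in __init__ while still tolerating a later checkpoint load into the working MoE params; without it, the first step() would update stale fp32 copies and overwrite the freshly loaded weights.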