update optim

pull/5190/head
Xuanlei Zhao 2023-12-29 16:51:29 +08:00
parent 44014faa67
commit ccad7014c6
2 changed files with 9 additions and 3 deletions


@@ -209,6 +209,7 @@ def main():
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
     coordinator.print_on_master(f"Finish init booster")

     # Load ckpt
@@ -224,7 +225,7 @@ def main():
         with tqdm(
             range(total_len),
             desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
-            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
+            disable=not coordinator.is_master() if use_pipeline == False else not pp_print_rank,
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
@@ -238,7 +239,7 @@ def main():
                         return_outputs=True,
                     )
                     # Backward and optimize
-                    if is_pp_last_stage:
+                    if pp_print_rank:
                         loss = outputs["loss"]
                         pbar.set_postfix({"loss": loss.item()})
                     else:
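With pipeline parallelism the loss tensor is only produced on the last pipeline stage, and is_pp_last_stage is true on every process of that stage, so more than one rank could render the same progress bar. The new pp_print_rank additionally requires local rank 0, so a single process drives tqdm and reports the loss. A minimal sketch of that gating logic, not the ColossalAI API; get_print_flag, the environment-variable reads, and the dummy loop are illustrative assumptions:

import os

from tqdm import tqdm


def get_print_flag(use_pipeline: bool, is_pp_last_stage: bool) -> bool:
    # Without pipeline parallelism only the global master (rank 0) prints.
    # With pipeline parallelism the loss exists only on the last stage, so the
    # bar is driven by local rank 0 of that stage, mirroring pp_print_rank above.
    global_rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not use_pipeline:
        return global_rank == 0
    return is_pp_last_stage and local_rank == 0


if __name__ == "__main__":
    print_flag = get_print_flag(use_pipeline=False, is_pp_last_stage=False)
    with tqdm(range(100), disable=not print_flag) as pbar:
        for _ in pbar:
            pbar.set_postfix({"loss": 0.0})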


@@ -173,13 +173,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # if there are moe params, store in addtional group in optim
         if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
-                self.master_moe_params.append(param.to(torch.float32))
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)
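The hunk above turns each MoE parameter into an fp32 master copy stored in an extra optimizer param group. param.to(torch.float32) returns the very same tensor when the parameter is already fp32, so the "master" would alias the working weight; clone().to(torch.float32).detach() always produces an independent leaf tensor. A hedged sketch of the pattern with toy stand-ins, not the real LowLevelZeroOptimizer internals:

import torch

# Toy stand-ins for the working MoE parameters and the wrapped optimizer.
working_moe_params = [torch.nn.Parameter(torch.randn(4, 4, dtype=torch.bfloat16))]
optim = torch.optim.AdamW([torch.nn.Parameter(torch.randn(2, 2))], lr=1e-3)

# Independent fp32 master copies: clone + detach guarantees a separate leaf tensor
# even when the working parameter is already stored in fp32.
master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]

# Copy the hyperparameters of the first group, then register the masters as their own group.
extra_group = {key: value for key, value in optim.param_groups[0].items() if key != "params"}
extra_group["params"] = master_moe_params
optim.param_groups.append(extra_group)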
@@ -602,6 +603,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update param for moe ep
         # move grad to master param and compute norm
         if len(self.working_moe_params) > 0:
+            if self._sync_master_param == False:
+                for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                    master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+                self._sync_master_param = True
             moe_grads = []
             for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
                 if master_moe_param.grad is not None:
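The _sync_master_param flag added here defers one refresh of the master copies until the first step, so weights written into the working parameters after construction (for example by checkpoint loading) are carried over instead of being silently dropped. A rough sketch of that lazy one-time sync, with hypothetical names rather than the real optimizer class:

import torch


class MoeMasterSyncSketch:
    # Hypothetical container, not LowLevelZeroOptimizer.
    def __init__(self, working_params):
        self.working_moe_params = list(working_params)
        self.master_moe_params = [p.clone().to(torch.float32).detach() for p in working_params]
        self._sync_master_param = False  # flipped after the one-time refresh

    def pre_step_sync(self):
        # Copy the (possibly checkpoint-loaded) working weights into the fp32
        # masters exactly once, right before the first optimizer step.
        if not self._sync_master_param:
            for master, working in zip(self.master_moe_params, self.working_moe_params):
                master.data = working.data.clone().to(torch.float32).detach()
            self._sync_master_param = True


sync = MoeMasterSyncSketch([torch.nn.Parameter(torch.randn(4, 4, dtype=torch.bfloat16))])
sync.pre_step_sync()  # would run at the top of step(), before grads reach the masters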