update optim

pull/5190/head
Xuanlei Zhao 2023-12-29 16:51:29 +08:00
parent 44014faa67
commit ccad7014c6
2 changed files with 9 additions and 3 deletions


@@ -209,6 +209,7 @@ def main():
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
     coordinator.print_on_master(f"Finish init booster")
     # Load ckpt
@@ -224,7 +225,7 @@ def main():
         with tqdm(
             range(total_len),
             desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
-            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
+            disable=not coordinator.is_master() if use_pipeline == False else not pp_print_rank,
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
@@ -238,7 +239,7 @@ def main():
                         return_outputs=True,
                     )
                     # Backward and optimize
-                    if is_pp_last_stage:
+                    if pp_print_rank:
                         loss = outputs["loss"]
                         pbar.set_postfix({"loss": loss.item()})
                 else:
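
Note on the change above: with pipeline parallelism the loss is only available on the last pipeline stage, and that stage can still span several processes, so the commit gates the tqdm bar and the loss printout on the new pp_print_rank (last stage and local rank "0") instead of is_pp_last_stage alone. Below is a minimal, self-contained sketch of that gating pattern; the FakeCoordinator class and the hard-coded ranks are illustrative assumptions, not part of the commit or of the coordinator's real API.

from tqdm import tqdm

# Stand-in for the coordinator object used in the training script; only the two
# members the gating logic reads are modelled here.
class FakeCoordinator:
    def __init__(self, local_rank: str, master: bool):
        self.local_rank = local_rank  # kept as a string to match the diff's `== "0"` comparison
        self._master = master

    def is_master(self) -> bool:
        return self._master

def should_drive_pbar(coordinator, use_pipeline: bool, is_pp_last_stage: bool) -> bool:
    # Without pipeline parallelism, only the global master process prints.
    if not use_pipeline:
        return coordinator.is_master()
    # With pipeline parallelism, the loss lives on the last stage; restrict
    # printing further to local rank 0 of that stage to avoid duplicated bars.
    pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
    return pp_print_rank

coordinator = FakeCoordinator(local_rank="0", master=False)
show = should_drive_pbar(coordinator, use_pipeline=True, is_pp_last_stage=True)
with tqdm(range(10), disable=not show) as pbar:
    for _ in pbar:
        pass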


@@ -173,13 +173,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # if there are moe params, store in addtional group in optim
         if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
-                self.master_moe_params.append(param.to(torch.float32))
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)
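
For context on the hunk above: a separate optimizer param group is appended for the MoE params, reusing the hyperparameters of group 0 but holding fp32 master copies. The change from param.to(torch.float32) to param.clone().to(torch.float32).detach() matters because .to(torch.float32) returns the very same tensor when the param is already fp32, and otherwise a copy that still tracks autograd history, whereas clone + detach always yields an independent fp32 leaf. A sketch of the pattern, with an illustrative AdamW optimizer and made-up fp16 params (not the LowLevelZeroOptimizer internals):

import torch

# Dummy working params and a plain optimizer standing in for self.optim.
working_moe_params = [torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16))]
optim = torch.optim.AdamW([torch.nn.Parameter(torch.randn(2, 2))], lr=1e-3)

# Copy every hyperparameter of group 0 except its param list.
param_group = {key: value for key, value in optim.param_groups[0].items() if key != "params"}

# Independent fp32 master copies of the working MoE params.
master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]

param_group["params"] = master_moe_params
optim.param_groups.append(param_group)
print(len(optim.param_groups))  # 2: the original group plus the MoE master group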
@@ -602,6 +603,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update param for moe ep
         # move grad to master param and compute norm
         if len(self.working_moe_params) > 0:
+            if self._sync_master_param == False:
+                for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                    master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+                self._sync_master_param = True
             moe_grads = []
             for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
                 if master_moe_param.grad is not None:
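
The new _sync_master_param flag defers a one-time refresh of the fp32 master copies to the first step(). That matters when the working weights are overwritten after the wrapper is constructed, presumably by the "# Load ckpt" step in the training script above, which runs after booster.boost has already built the master copies. A minimal sketch of this "resync once, then step" control flow follows; the SimpleMoEOptimWrapper class and its method bodies are illustrative assumptions, not the LowLevelZeroOptimizer API.

import torch

class SimpleMoEOptimWrapper:
    def __init__(self, working_moe_params):
        self.working_moe_params = list(working_moe_params)
        # fp32 master copies taken at construction time ...
        self.master_moe_params = [p.clone().to(torch.float32).detach() for p in self.working_moe_params]
        # ... which may be stale by the time of the first step if the working
        # weights were replaced afterwards (e.g. a checkpoint was loaded).
        self._sync_master_param = False

    def step(self):
        if not self._sync_master_param:
            # First step only: refresh masters from the current working weights.
            for master, working in zip(self.master_moe_params, self.working_moe_params):
                master.data = working.data.clone().to(torch.float32).detach()
            self._sync_master_param = True
        # ... gradient copy, the actual optimizer update, and the cast back to
        # the working params would follow here in the real implementation.

params = [torch.nn.Parameter(torch.randn(2, 2, dtype=torch.float16))]
wrapper = SimpleMoEOptimWrapper(params)
params[0].data.fill_(1.0)  # simulate a checkpoint overwriting the weights after init
wrapper.step()
assert bool(torch.all(wrapper.master_moe_params[0] == 1.0))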