mirror of https://github.com/hpcaitech/ColossalAI
update optim
parent 44014faa67
commit ccad7014c6
@@ -209,6 +209,7 @@ def main():
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
     coordinator.print_on_master(f"Finish init booster")

     # Load ckpt
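The new pp_print_rank flag narrows console output to a single process: the last pipeline stage (the only stage that reads the loss from the pipeline outputs) and, within it, local rank 0; note that coordinator.local_rank is compared against the string "0". A minimal sketch of the same gating outside ColossalAI, assuming a plain torch.distributed setup and the LOCAL_RANK variable exported by the launcher (is_last_stage stands in for the stage-manager check and is not part of the original script):

import os

import torch.distributed as dist


def should_print(use_pipeline: bool, is_last_stage: bool) -> bool:
    # Launchers export LOCAL_RANK as an environment variable, i.e. a string,
    # hence the string comparison; adapt if your coordinator stores an int.
    local_rank = os.environ.get("LOCAL_RANK", "0")
    if not use_pipeline:
        # Without pipeline parallelism, only the global master rank prints.
        return dist.get_rank() == 0
    # With pipeline parallelism, print from local rank 0 of the last stage,
    # since only the last stage holds the computed loss.
    return is_last_stage and local_rank == "0"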
@@ -224,7 +225,7 @@ def main():
         with tqdm(
             range(total_len),
             desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
-            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
+            disable=not coordinator.is_master() if use_pipeline == False else not pp_print_rank,
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
@@ -238,7 +239,7 @@ def main():
                         return_outputs=True,
                     )
                     # Backward and optimize
-                    if is_pp_last_stage:
+                    if pp_print_rank:
                         loss = outputs["loss"]
                         pbar.set_postfix({"loss": loss.item()})
                 else:
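The two hunks above route both the progress bar and the loss readout through the same flag: tqdm is disabled on every rank except the chosen one, and only that rank unpacks outputs["loss"] for display. A compact illustration of the pattern with illustrative stand-ins (run_step and print_rank are not names from the example script):

from tqdm import tqdm


def train_epoch(dataloader, run_step, print_rank: bool, epoch: int, num_epoch: int):
    # Only the designated rank draws the bar; all other ranks iterate silently.
    with tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{num_epoch}]", disable=not print_rank) as pbar:
        for batch in pbar:
            loss = run_step(batch)  # scalar tensor on the printing rank, may be None elsewhere
            if print_rank and loss is not None:
                pbar.set_postfix({"loss": loss.item()})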
@@ -173,13 +173,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # if there are moe params, store in addtional group in optim
         if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
-                self.master_moe_params.append(param.to(torch.float32))
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)

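The switch from param.to(torch.float32) to param.clone().to(torch.float32).detach() matters when a working parameter is already fp32: Tensor.to with a matching dtype returns the original tensor, so master and working weights would alias the same storage, whereas clone() plus detach() always yields an independent master copy outside the autograd graph. A minimal sketch of the surrounding pattern (an extra optimizer param group holding fp32 master copies), using a plain torch optimizer rather than the ColossalAI wrapper:

import torch


def add_fp32_master_group(optim: torch.optim.Optimizer, working_params):
    # Reuse the hyperparameters (lr, weight_decay, ...) of the first group.
    group = {k: v for k, v in optim.param_groups[0].items() if k != "params"}
    # clone() + detach() guarantees a separate tensor even if the working
    # parameter is already float32 (.to(torch.float32) would be a no-op then).
    master_params = [p.clone().to(torch.float32).detach() for p in working_params]
    group["params"] = master_params
    optim.param_groups.append(group)  # the diff appends directly; add_param_group also works
    return master_params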
@@ -602,6 +603,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update param for moe ep
         # move grad to master param and compute norm
         if len(self.working_moe_params) > 0:
+            if self._sync_master_param == False:
+                for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                    master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+                self._sync_master_param = True
             moe_grads = []
             for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
                 if master_moe_param.grad is not None:
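The _sync_master_param flag defers a full copy of the working MoE weights into the fp32 master copies until the first step(), presumably so that weights written into the model after optimizer construction (for example the checkpoint loaded right after booster.boost in the script above) are reflected in the master copies before the first update. A self-contained sketch of this one-shot lazy re-sync, with illustrative names rather than the optimizer's own:

import torch


class LazyMasterSync:
    def __init__(self, working_params):
        self.working_params = list(working_params)
        # Initial fp32 master copies, taken at construction time.
        self.master_params = [p.clone().to(torch.float32).detach() for p in self.working_params]
        self._synced = False

    def before_step(self):
        # Refresh the master copies once, right before the first optimizer step,
        # so weights loaded after construction are not silently discarded.
        if not self._synced:
            for master, working in zip(self.master_params, self.working_params):
                master.data = working.data.clone().to(torch.float32).detach()
            self._synced = True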