mirror of https://github.com/hpcaitech/ColossalAI

update optim

parent 44014faa67
commit ccad7014c6
@@ -209,6 +209,7 @@ def main():
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
     coordinator.print_on_master(f"Finish init booster")

     # Load ckpt
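The added pp_print_rank flag narrows progress reporting under pipeline parallelism: only the last pipeline stage sees the loss, and among those ranks only the one whose local rank is "0" should drive the tqdm bar. Below is a standalone sketch of that gating rule; the function name and boolean inputs are illustrative stand-ins for the script's coordinator/booster state, not part of the commit.

def should_drive_progress_bar(use_pipeline: bool, is_master: bool,
                              is_pp_last_stage: bool, local_rank: str) -> bool:
    # Without pipeline parallelism the global master rank owns the bar;
    # with it, only local rank "0" of the last pipeline stage does.
    if not use_pipeline:
        return is_master
    return is_pp_last_stage and local_rank == "0"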
@@ -224,7 +225,7 @@ def main():
         with tqdm(
             range(total_len),
             desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
-            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
+            disable=not coordinator.is_master() if use_pipeline == False else not pp_print_rank,
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
@@ -238,7 +239,7 @@ def main():
                         return_outputs=True,
                     )
                     # Backward and optimize
-                    if is_pp_last_stage:
+                    if pp_print_rank:
                         loss = outputs["loss"]
                         pbar.set_postfix({"loss": loss.item()})
                 else:
@@ -173,13 +173,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # if there are moe params, store in addtional group in optim
         if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
-                self.master_moe_params.append(param.to(torch.float32))
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)

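Building the fp32 master copies with clone().to(torch.float32).detach() gives each master parameter its own storage and cuts it out of the autograd graph, whereas a bare .to(torch.float32) returns the working tensor itself when it is already fp32. A self-contained sketch of that master-weight pattern, assuming low-precision working params (the tensors below are placeholders, not the optimizer's real state):

import torch

# Placeholder "working" params in low precision; the optimizer would step on fp32 copies.
working_moe_params = [torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True) for _ in range(2)]
master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]

# Each master copy owns its storage and carries no autograd history, so optimizer
# updates to it cannot alias or backpropagate into the working parameters.
assert all(m.data_ptr() != w.data_ptr() for m, w in zip(master_moe_params, working_moe_params))
assert all(not m.requires_grad for m in master_moe_params)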
@@ -602,6 +603,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update param for moe ep
         # move grad to master param and compute norm
         if len(self.working_moe_params) > 0:
+            if self._sync_master_param == False:
+                for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                    master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+                self._sync_master_param = True
             moe_grads = []
             for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
                 if master_moe_param.grad is not None:
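The _sync_master_param flag defers one working-to-master copy to the first step() call, so values written into the working parameters after the optimizer was constructed (for example weights loaded from a checkpoint afterwards) are presumably reflected in the fp32 master copies before the first update. A minimal, hypothetical sketch of that lazy one-shot sync; the class and method names are illustrative only, not the optimizer's real API:

import torch

class LazyMasterSync:
    """Sketch, not the real optimizer: hold fp32 master copies of some working
    params and refresh them from the working params once, on first use."""

    def __init__(self, working_params):
        self.working_params = list(working_params)
        self.master_params = [p.clone().to(torch.float32).detach() for p in self.working_params]
        self._synced = False

    def before_first_step(self):
        # One-shot refresh: if the working params changed after __init__
        # (e.g. a checkpoint was loaded), copy them into the master copies
        # before the first optimizer update uses them.
        if not self._synced:
            for master, working in zip(self.master_params, self.working_params):
                master.data = working.data.clone().to(torch.float32).detach()
            self._synced = True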