no overlap for save ckpt

pull/540/head
lijiaxing 2023-12-19 17:45:55 +08:00
parent eae9b97ab2
commit 9bf24d9768
6 changed files with 37 additions and 32 deletions

@@ -102,7 +102,7 @@ class Engine:
         """Sets the gradient of all parameters in the model to zero."""
         self.optimizer.zero_grad()

-    def step(self):
+    def step(self, disable_overlap=False):
         """
         Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
         and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
@@ -115,7 +115,7 @@ class Engine:
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

-        success, grad_norm = self.optimizer.step()
+        success, grad_norm = self.optimizer.step(disable_overlap=disable_overlap)

         if success and self._lr_scheduler is not None:
             self._lr_scheduler.step()

@@ -195,9 +195,9 @@ class Trainer:
         """Sets the gradient of all parameters in the model to zero."""
         self._engine.zero_grad()

-    def step(self):
+    def step(self, disable_overlap=False):
         """Executes the parameter update step."""
-        return self._engine.step()
+        return self._engine.step(disable_overlap=disable_overlap)

     def execute_schedule(self, data_iter: Iterable, **kwargs):
         """Runs the forward, loss computation, and backward for the model.

@@ -607,7 +607,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         return total_zero_grad_count

     @llm_timeout(func_name="optim_step")
-    def step(self, closure=None):
+    def step(self, closure=None, disable_overlap=False):
         """Performs a single optimization step.

         Args:
@@ -715,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._sync_grad()
             timer("sync_grad").stop()

-        state, global_norms = self._step(closure=closure, norms=total_norms)
+        state, global_norms = self._step(closure=closure, norms=total_norms, disable_overlap=disable_overlap)
         if is_profiling:
             if grad_profiling_config.get("grad_norm_profiling", False):
                 global_norms["layer_grad_norm"] = total_layer_grad_norms
@@ -728,7 +728,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         return state, global_norms

-    def _step(self, closure=None, norms=None):
+    def _step(self, closure=None, norms=None, disable_overlap=False):
         assert closure is None, "closure is not supported by step()"

         # check for overflow
@@ -835,7 +835,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp16_param.data.copy_(fp32_param)

         torch.cuda.synchronize()

-        self.broadcast_params()
+        self.broadcast_params(disable_overlap=disable_overlap)

         timer("step").stop()
@@ -845,9 +845,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale

         return True, global_norm_groups

-    def broadcast_params(self):
+    def broadcast_params(self, disable_overlap=False):
         handles = []
+        assert all(
+            isinstance(value, list) and not value for value in self._param_bcast_sync_handler._bcast_handles.values()
+        )
         for group_id in range(self.num_param_groups):
             for rank in range(self._zero_world_size[group_id]):
                 # The following operations are performed only on the rank to which parameters are assigned.
@@ -858,12 +860,13 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]

-                if self._overlap_sync_param:
-                    handle = dict()
-                    handle["tensor"] = fp16_param
-                    handle["src"] = g_rank
-                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
-                    handle["async_op"] = True
+                if self._overlap_sync_param and not disable_overlap:
+                    handle = {
+                        "tensor": fp16_param,
+                        "src": g_rank,
+                        "group": gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        "async_op": True,
+                    }
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
                     handle = dist.broadcast(
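The assert added at the top of broadcast_params spells out the invariant that makes this change safe: when a new round of broadcasts starts, every bucket in the handler's _bcast_handles registry must already be an empty list, i.e. no async handle queued in a previous round is still pending. Below is a minimal, standalone illustration of that emptiness check; the pending registry and the handle strings are hypothetical stand-ins, not the real InternLM objects.

# Standalone sketch of the emptiness invariant checked in broadcast_params.
# `pending` mimics _bcast_handles: a mapping from rank to queued broadcast handles.
from collections import defaultdict

pending = defaultdict(list)
_ = pending[0]  # touching a rank creates an empty list, which still satisfies the check

# The invariant: every value is a list and none of them holds a leftover handle.
assert all(isinstance(v, list) and not v for v in pending.values())

pending[1].append("stale-async-handle")  # simulate a handle that was never consumed
try:
    assert all(isinstance(v, list) and not v for v in pending.values())
except AssertionError:
    print("previous round left an unconsumed broadcast handle")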

@@ -824,18 +824,14 @@ class ParamBcastSyncHandler:
             # BUG: The order of traversal is not necessarily the order of actual fwd/bwd execution!!!
             last_module = None
             for module_idx, children in _chunk.named_children():
-                if gpc.get_global_rank() == 0:
-                    print(f"children: {children.__class__.__name__}", flush=True)
                 # should be the transformer block definaton in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
                     assert module_idx != 0
                     # record the block that a parameter belongs to
-                    for layer_idx, block in enumerate(children):
+                    for _, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
                         self._block_next_block[last_module] = block
-                        if gpc.get_global_rank() == 0:
-                            print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
                         last_module = block
                 else:
                     # record the block that a parameter belongs to
@@ -843,7 +839,6 @@ class ParamBcastSyncHandler:
                     if module_idx == 0:
                         assert "embedding" in f"{children.__class__.__name__}"
                         assert last_module is None
-                        self._block_next_block[children] = children
                     else:
                         self._block_next_block[last_module] = children

@@ -921,7 +921,7 @@ class CheckpointManager:
         if self.enable_save_ckpt:
             self.try_ping_storage()

-    def quit_signal_handler(self, train_state) -> bool:
+    def quit_signal_handler(self, train_state, step_count=None) -> bool:
         """
         Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file,
         all ranks will save ckpt and exit.
@@ -933,6 +933,9 @@ class CheckpointManager:
         Returns:
             bool: whether to quit.
         """
+        if step_count is None:
+            step_count = train_state.step_count
+
         now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT

         if self.stop_file_path is None:
@@ -950,32 +953,34 @@ class CheckpointManager:
             action_step = action_step_t.item()
             del action_step_t

-        if action_step < 0 and abs(action_step) == train_state.step_count:
+        if action_step < 0 and abs(action_step) == step_count:
             now_save_ckpt = True

-        if action_step > 0 and action_step == train_state.step_count:
+        if action_step > 0 and action_step == step_count:
             now_break, now_save_ckpt = True, True

         if action_step != 0 and gpc.is_rank_for_log():
             msg = "Stop" if action_step > 0 else "Save"
             action_step = abs(action_step)
-            if train_state.step_count <= action_step:
+            if step_count <= action_step:
                 if self.feishu_address:
                     send_alert_message(
                         address=self.feishu_address,
                         message=f"training will {msg} at step_count {action_step}!\
-now step_count is {train_state.step_count}",
+now step_count is {step_count}",
                     )

         return now_break, now_save_ckpt, save_type

-    def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
+    def is_now_to_save_ckpt(self, train_state, step_count=None) -> (bool, CheckpointSaveType, bool):
+        if step_count is None:
+            step_count = train_state.step_count
         save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
-        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+        if self.oss_snapshot_freq > 1 and step_count % self.oss_snapshot_freq == 0:
             save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
-        if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
+        if step_count % self.checkpoint_every == 0 or step_count == train_state.total_steps:
             save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
-        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
+        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state, step_count)
         if save_ckpts is False:
             save_ckpts = singal_save_ckpts
             save_type = singal_save_type

@@ -238,7 +238,9 @@ def main(args):
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)
-        trainer_result = trainer.step()
+        trainer_result = trainer.step(
+            disable_overlap=ckpt_manager.is_now_to_save_ckpt(train_state, train_state.step_count + 1)[0]
+        )
         assert trainer_result is not None

         success_update, grad_norm_groups = trainer_result
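Taken together, the hunks wire a single look-ahead decision through the whole stack: train.py asks the checkpoint manager whether step_count + 1 will trigger a save, and if so passes disable_overlap=True down through Trainer.step, Engine.step and HybridZeroOptimizer.step, so that broadcast_params falls back to a synchronous dist.broadcast instead of queueing async handles that could still be in flight while the checkpoint is written. A minimal, self-contained sketch of that look-ahead pattern follows; the stub classes and their checkpoint_every / snapshot_freq attributes are hypothetical simplifications, not the real InternLM classes.

# Minimal sketch of the look-ahead pattern in this commit (stubs, not InternLM code).


class StubCheckpointManager:
    def __init__(self, checkpoint_every, snapshot_freq=0):
        self.checkpoint_every = checkpoint_every
        self.snapshot_freq = snapshot_freq

    def is_now_to_save_ckpt(self, step_count):
        # Mirrors the new optional step_count argument: callers may probe a future step.
        snapshot = self.snapshot_freq > 1 and step_count % self.snapshot_freq == 0
        normal = step_count % self.checkpoint_every == 0
        return snapshot or normal


class StubOptimizer:
    def step(self, disable_overlap=False):
        if disable_overlap:
            # Synchronous broadcast: no async handles are pending when the checkpoint is written.
            return "sync broadcast"
        return "overlapped broadcast"


ckpt_manager = StubCheckpointManager(checkpoint_every=4)
optimizer = StubOptimizer()

for step_count in range(1, 9):
    # Look one step ahead, as train.py now does with train_state.step_count + 1:
    # if completing this step will trigger a checkpoint, disable the overlap now.
    mode = optimizer.step(disable_overlap=ckpt_manager.is_now_to_save_ckpt(step_count + 1))
    print(step_count, mode)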