no overlap for save ckpt

pull/540/head
lijiaxing 2023-12-19 17:45:55 +08:00
parent eae9b97ab2
commit 9bf24d9768
6 changed files with 37 additions and 32 deletions

View File

@@ -102,7 +102,7 @@ class Engine:
         """Sets the gradient of all parameters in the model to zero."""
         self.optimizer.zero_grad()

-    def step(self):
+    def step(self, disable_overlap=False):
         """
         Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
         and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
@@ -115,7 +115,7 @@ class Engine:
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

-        success, grad_norm = self.optimizer.step()
+        success, grad_norm = self.optimizer.step(disable_overlap=disable_overlap)
         if success and self._lr_scheduler is not None:
             self._lr_scheduler.step()

View File

@@ -195,9 +195,9 @@ class Trainer:
         """Sets the gradient of all parameters in the model to zero."""
         self._engine.zero_grad()

-    def step(self):
+    def step(self, disable_overlap=False):
         """Executes the parameter update step."""
-        return self._engine.step()
+        return self._engine.step(disable_overlap=disable_overlap)

     def execute_schedule(self, data_iter: Iterable, **kwargs):
         """Runs the forward, loss computation, and backward for the model.

View File

@@ -607,7 +607,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         return total_zero_grad_count

     @llm_timeout(func_name="optim_step")
-    def step(self, closure=None):
+    def step(self, closure=None, disable_overlap=False):
         """Performs a single optimization step.

         Args:
@@ -715,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._sync_grad()
             timer("sync_grad").stop()

-        state, global_norms = self._step(closure=closure, norms=total_norms)
+        state, global_norms = self._step(closure=closure, norms=total_norms, disable_overlap=disable_overlap)
         if is_profiling:
             if grad_profiling_config.get("grad_norm_profiling", False):
                 global_norms["layer_grad_norm"] = total_layer_grad_norms
@@ -728,7 +728,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         return state, global_norms

-    def _step(self, closure=None, norms=None):
+    def _step(self, closure=None, norms=None, disable_overlap=False):
         assert closure is None, "closure is not supported by step()"

         # check for overflow
@@ -835,7 +835,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp16_param.data.copy_(fp32_param)

         torch.cuda.synchronize()
-        self.broadcast_params()
+        self.broadcast_params(disable_overlap=disable_overlap)

         timer("step").stop()
@@ -845,9 +845,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale
         return True, global_norm_groups

-    def broadcast_params(self):
+    def broadcast_params(self, disable_overlap=False):
         handles = []
+        assert all(
+            isinstance(value, list) and not value for value in self._param_bcast_sync_handler._bcast_handles.values()
+        )

         for group_id in range(self.num_param_groups):
             for rank in range(self._zero_world_size[group_id]):
                 # The following operations are performed only on the rank to which parameters are assigned.
@@ -858,12 +860,13 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]

-                if self._overlap_sync_param:
-                    handle = dict()
-                    handle["tensor"] = fp16_param
-                    handle["src"] = g_rank
-                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
-                    handle["async_op"] = True
+                if self._overlap_sync_param and not disable_overlap:
+                    handle = {
+                        "tensor": fp16_param,
+                        "src": g_rank,
+                        "group": gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        "async_op": True,
+                    }
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
                     handle = dist.broadcast(
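A runnable sketch of the branch `broadcast_params` now takes: when overlap is enabled and not disabled for this step, the broadcast arguments are stashed as a handle to be launched later; with `disable_overlap=True` (the pre-checkpoint step) the broadcast runs immediately. The single-process gloo group and the helper function are assumptions made so the example is standalone, not how InternLM launches training:

import os

import torch
import torch.distributed as dist

# Single-process gloo group so the sketch runs by itself.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29522")
dist.init_process_group("gloo", rank=0, world_size=1)

deferred_handles = []  # stands in for ParamBcastSyncHandler._bcast_handles


def broadcast_param(param, src, overlap_enabled, disable_overlap):
    # Mirrors the if/else above: defer the broadcast kwargs when overlap is
    # on and not disabled for this step, otherwise broadcast right now.
    if overlap_enabled and not disable_overlap:
        deferred_handles.append(
            {"tensor": param, "src": src, "group": dist.group.WORLD, "async_op": True}
        )
        return None
    return dist.broadcast(param, src=src, group=dist.group.WORLD, async_op=False)


param = torch.zeros(4)
# On the step right before a checkpoint, disable_overlap=True forces the
# synchronous path, so no handle is left pending when the ckpt is written.
broadcast_param(param, src=0, overlap_enabled=True, disable_overlap=True)
assert not deferred_handles
dist.destroy_process_group()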

View File

@@ -824,18 +824,14 @@ class ParamBcastSyncHandler:
             # BUG: The order of traversal is not necessarily the order of actual fwd/bwd execution!!!
             last_module = None
             for module_idx, children in _chunk.named_children():
-                if gpc.get_global_rank() == 0:
-                    print(f"children: {children.__class__.__name__}", flush=True)
                 # should be the transformer block definaton in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
                     assert module_idx != 0
                     # record the block that a parameter belongs to
-                    for layer_idx, block in enumerate(children):
+                    for _, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
                         self._block_next_block[last_module] = block
-                        if gpc.get_global_rank() == 0:
-                            print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
                         last_module = block
                 else:
                     # record the block that a parameter belongs to
@@ -843,7 +839,6 @@ class ParamBcastSyncHandler:
                     if module_idx == 0:
                         assert "embedding" in f"{children.__class__.__name__}"
                         assert last_module is None
-                        self._block_next_block[children] = children
                     else:
                         self._block_next_block[last_module] = children
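A simplified re-creation of the traversal in ParamBcastSyncHandler, showing the block-to-next-block chain built over `named_children()`; the toy module layout below is illustrative, not the real model definition, and after this change the embedding no longer points to itself:

import torch.nn as nn


class ToyChunk(nn.Module):
    # Illustrative layout: an embedding, an nn.ModuleList of blocks, a head.
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 8)
        self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
        self.head = nn.Linear(8, 10)


block_to_param, block_next_block = {}, {}
last_module = None
for _, children in ToyChunk().named_children():
    if isinstance(children, nn.ModuleList):
        for _, block in enumerate(children):
            block_to_param[block] = list(block.parameters())
            if last_module is not None:
                block_next_block[last_module] = block
            last_module = block
    else:
        block_to_param[children] = list(children.parameters())
        if last_module is not None:
            block_next_block[last_module] = children
        last_module = children

# embedding -> block0 -> block1 -> block2 -> head: four "next" links, and no
# self-referencing entry for the embedding (the line the diff removes).
print(len(block_next_block))  # 4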

View File

@@ -921,7 +921,7 @@ class CheckpointManager:
         if self.enable_save_ckpt:
             self.try_ping_storage()

-    def quit_signal_handler(self, train_state) -> bool:
+    def quit_signal_handler(self, train_state, step_count=None) -> bool:
         """
         Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file,
         all ranks will save ckpt and exit.
@@ -933,6 +933,9 @@ class CheckpointManager:
         Returns:
             bool: whether to quit.
         """
+        if step_count is None:
+            step_count = train_state.step_count
+
         now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT

         if self.stop_file_path is None:
@@ -950,32 +953,34 @@ class CheckpointManager:
         action_step = action_step_t.item()
         del action_step_t

-        if action_step < 0 and abs(action_step) == train_state.step_count:
+        if action_step < 0 and abs(action_step) == step_count:
             now_save_ckpt = True

-        if action_step > 0 and action_step == train_state.step_count:
+        if action_step > 0 and action_step == step_count:
             now_break, now_save_ckpt = True, True

         if action_step != 0 and gpc.is_rank_for_log():
             msg = "Stop" if action_step > 0 else "Save"
             action_step = abs(action_step)
-            if train_state.step_count <= action_step:
+            if step_count <= action_step:
                 if self.feishu_address:
                     send_alert_message(
                         address=self.feishu_address,
                         message=f"training will {msg} at step_count {action_step}!\
-                        now step_count is {train_state.step_count}",
+                        now step_count is {step_count}",
                     )

         return now_break, now_save_ckpt, save_type

-    def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
+    def is_now_to_save_ckpt(self, train_state, step_count=None) -> (bool, CheckpointSaveType, bool):
+        if step_count is None:
+            step_count = train_state.step_count
         save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
-        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+        if self.oss_snapshot_freq > 1 and step_count % self.oss_snapshot_freq == 0:
             save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
-        if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
+        if step_count % self.checkpoint_every == 0 or step_count == train_state.total_steps:
             save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
-        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
+        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state, step_count)
         if save_ckpts is False:
             save_ckpts = singal_save_ckpts
             save_type = singal_save_type
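A condensed, standalone version of the scheduling logic in `is_now_to_save_ckpt`, with the stop-file and quit-signal handling left out; the point is that `step_count` can now be supplied by the caller (for example the next step) instead of always being read from `train_state`:

from enum import Enum


class CheckpointSaveType(Enum):
    # Simplified stand-in for the real enum.
    NORMAL_CHECKPOINT = 1
    SNAPSHOT_CHECKPOINT = 2


def is_now_to_save_ckpt(step_count, total_steps, checkpoint_every, oss_snapshot_freq):
    # Mirrors the frequency checks above; quit-signal handling is omitted.
    save_ckpts, save_type = False, CheckpointSaveType.NORMAL_CHECKPOINT
    if oss_snapshot_freq > 1 and step_count % oss_snapshot_freq == 0:
        save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
    if step_count % checkpoint_every == 0 or step_count == total_steps:
        save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
    return save_ckpts, save_type


# Asking about step 500 with checkpoint_every=250: a normal checkpoint is due.
print(is_now_to_save_ckpt(step_count=500, total_steps=1000,
                          checkpoint_every=250, oss_snapshot_freq=50))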

View File

@@ -238,7 +238,9 @@ def main(args):
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)
-        trainer_result = trainer.step()
+        trainer_result = trainer.step(
+            disable_overlap=ckpt_manager.is_now_to_save_ckpt(train_state, train_state.step_count + 1)[0]
+        )
         assert trainer_result is not None
         success_update, grad_norm_groups = trainer_result
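Passing `train_state.step_count + 1` looks one step ahead: the counter is bumped only after this update completes, and the checkpoint decision for the iteration is made with the bumped value, so the update right before a save must already broadcast parameters synchronously. A tiny sketch of that look-ahead, with `will_save_ckpt` as a hypothetical stand-in for `ckpt_manager.is_now_to_save_ckpt(...)[0]`:

def will_save_ckpt(step_count, checkpoint_every=250):
    # Hypothetical stand-in for ckpt_manager.is_now_to_save_ckpt(...)[0].
    return step_count % checkpoint_every == 0


step_count = 249  # counter value while the current update is still running
disable_overlap = will_save_ckpt(step_count + 1)
assert disable_overlap  # step 250 writes a ckpt, so broadcast synchronously now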