mirror of https://github.com/InternLM/InternLM
no overlap for save ckpt
parent eae9b97ab2
commit 9bf24d9768
@@ -102,7 +102,7 @@ class Engine:
         """Sets the gradient of all parameters in the model to zero."""
         self.optimizer.zero_grad()

-    def step(self):
+    def step(self, disable_overlap=False):
         """
         Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
         and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
@@ -115,7 +115,7 @@ class Engine:
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

-        success, grad_norm = self.optimizer.step()
+        success, grad_norm = self.optimizer.step(disable_overlap=disable_overlap)

         if success and self._lr_scheduler is not None:
             self._lr_scheduler.step()
@@ -195,9 +195,9 @@ class Trainer:
         """Sets the gradient of all parameters in the model to zero."""
         self._engine.zero_grad()

-    def step(self):
+    def step(self, disable_overlap=False):
         """Executes the parameter update step."""
-        return self._engine.step()
+        return self._engine.step(disable_overlap=disable_overlap)

     def execute_schedule(self, data_iter: Iterable, **kwargs):
         """Runs the forward, loss computation, and backward for the model.
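Note: the new `disable_overlap` keyword simply rides the existing delegation chain. A minimal, self-contained sketch (stand-in classes, not the InternLM implementation) of how `Trainer.step()` forwards it through `Engine.step()` into the optimizer:

    # Sketch only: _Sketch* classes are invented stand-ins for the real ones.
    class _SketchOptimizer:
        def step(self, disable_overlap=False):
            # the real HybridZeroOptimizer decides here whether to overlap the
            # post-update parameter broadcast with the next forward pass
            print(f"optimizer.step(disable_overlap={disable_overlap})")
            return True, 0.0  # (success, grad_norm)


    class _SketchEngine:
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def step(self, disable_overlap=False):
            # gradient all-reduce and clipping elided
            return self.optimizer.step(disable_overlap=disable_overlap)


    class _SketchTrainer:
        def __init__(self, engine):
            self._engine = engine

        def step(self, disable_overlap=False):
            return self._engine.step(disable_overlap=disable_overlap)


    trainer = _SketchTrainer(_SketchEngine(_SketchOptimizer()))
    trainer.step()                      # normal step: overlap stays as configured
    trainer.step(disable_overlap=True)  # e.g. the step right before a checkpoint save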
@@ -607,7 +607,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         return total_zero_grad_count

     @llm_timeout(func_name="optim_step")
-    def step(self, closure=None):
+    def step(self, closure=None, disable_overlap=False):
         """Performs a single optimization step.

         Args:
@@ -715,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._sync_grad()
             timer("sync_grad").stop()

-        state, global_norms = self._step(closure=closure, norms=total_norms)
+        state, global_norms = self._step(closure=closure, norms=total_norms, disable_overlap=disable_overlap)
         if is_profiling:
             if grad_profiling_config.get("grad_norm_profiling", False):
                 global_norms["layer_grad_norm"] = total_layer_grad_norms
@@ -728,7 +728,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         return state, global_norms

-    def _step(self, closure=None, norms=None):
+    def _step(self, closure=None, norms=None, disable_overlap=False):
         assert closure is None, "closure is not supported by step()"

         # check for overflow
@@ -835,7 +835,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp16_param.data.copy_(fp32_param)

         torch.cuda.synchronize()
-        self.broadcast_params()
+        self.broadcast_params(disable_overlap=disable_overlap)

         timer("step").stop()

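The flag has to travel all the way into `_step()` because that is where the updated fp16 shards are broadcast back to the other ZeRO ranks; when a checkpoint will be saved right after this step, that broadcast must complete inside `step()` instead of being overlapped with the next forward pass. A toy illustration (invented names, not InternLM code) of the difference:

    # Toy illustration only: what disable_overlap changes about when the
    # parameter broadcast finishes relative to step() returning.
    pending_broadcasts = []


    def broadcast_updated_param(name, disable_overlap=False):
        work = f"broadcast({name})"          # stands in for an async collective
        if not disable_overlap:
            pending_broadcasts.append(work)  # drained lazily before the param is next used
            return f"{work} queued"
        return f"{work} completed inside step()"


    print(broadcast_updated_param("w0"))
    print(broadcast_updated_param("w1", disable_overlap=True))
    print("still pending after step():", pending_broadcasts)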
@@ -845,9 +845,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale
         return True, global_norm_groups

-    def broadcast_params(self):
+    def broadcast_params(self, disable_overlap=False):
         handles = []

+        assert all(
+            isinstance(value, list) and not value for value in self._param_bcast_sync_handler._bcast_handles.values()
+        )
         for group_id in range(self.num_param_groups):
             for rank in range(self._zero_world_size[group_id]):
                 # The following operations are performed only on the rank to which parameters are assigned.
@@ -858,12 +860,13 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]

-                if self._overlap_sync_param:
-                    handle = dict()
-                    handle["tensor"] = fp16_param
-                    handle["src"] = g_rank
-                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
-                    handle["async_op"] = True
+                if self._overlap_sync_param and not disable_overlap:
+                    handle = {
+                        "tensor": fp16_param,
+                        "src": g_rank,
+                        "group": gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        "async_op": True,
+                    }
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
                     handle = dist.broadcast(
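A standalone sketch of the branch `broadcast_params()` now takes (heavily simplified; `_SketchBcastHandler` and the free function are invented stand-ins): async handles are registered with the broadcast handler only when parameter-sync overlap is both configured and not disabled for this particular step, otherwise a blocking broadcast is issued. The assertion mirrors the one added above, checking that no handles are left over from a previous overlapped step.

    from collections import defaultdict


    class _SketchBcastHandler:
        """Stand-in for ParamBcastSyncHandler: collects per-rank async broadcast handles."""

        def __init__(self):
            self._bcast_handles = defaultdict(list)

        def add_bcast_handle(self, rank, handle):
            self._bcast_handles[rank].append(handle)


    def broadcast_params(handler, fp16_params, overlap_sync_param, disable_overlap=False):
        handles = []
        # no async handle may be left over from a previous overlapped step
        assert all(isinstance(v, list) and not v for v in handler._bcast_handles.values())
        for rank, fp16_param in enumerate(fp16_params):
            if overlap_sync_param and not disable_overlap:
                # deferred: hand the broadcast arguments to the handler for later
                handler.add_bcast_handle(rank, {"tensor": fp16_param, "src": rank, "async_op": True})
            else:
                # immediate: stands in for a blocking dist.broadcast(...)
                handles.append(f"broadcast({fp16_param}) done")
        return handles


    print(broadcast_params(_SketchBcastHandler(), ["p0", "p1"], overlap_sync_param=True))
    print(broadcast_params(_SketchBcastHandler(), ["p0", "p1"], overlap_sync_param=True, disable_overlap=True))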
@@ -824,18 +824,14 @@ class ParamBcastSyncHandler:
         # BUG: The order of traversal is not necessarily the order of actual fwd/bwd execution!!!
         last_module = None
         for module_idx, children in _chunk.named_children():
-            if gpc.get_global_rank() == 0:
-                print(f"children: {children.__class__.__name__}", flush=True)
             # should be the transformer block definaton in modeling_xxx.py
             if isinstance(children, nn.ModuleList):
                 assert module_idx != 0
                 # record the block that a parameter belongs to
-                for layer_idx, block in enumerate(children):
+                for _, block in enumerate(children):
                     # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                     self._block_to_param[block] = list(block.parameters())
                     self._block_next_block[last_module] = block
-                    if gpc.get_global_rank() == 0:
-                        print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
                     last_module = block
             else:
                 # record the block that a parameter belongs to
@@ -843,7 +839,6 @@ class ParamBcastSyncHandler:
                 if module_idx == 0:
                     assert "embedding" in f"{children.__class__.__name__}"
                     assert last_module is None
-                    self._block_next_block[children] = children
                 else:
                     self._block_next_block[last_module] = children

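These two hunks only strip debugging prints, the unused `layer_idx`, and the embedding's self-link; the traversal itself is unchanged. A standalone sketch (simplified, with an invented toy model rather than the real `ParamBcastSyncHandler.__init__`) of what that traversal records in the `_block_next_block` chain:

    import torch.nn as nn

    # toy "chunk": an embedding, a ModuleList of transformer-block stand-ins, a head
    chunk = nn.Module()
    chunk.embedding = nn.Embedding(10, 4)
    chunk.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
    chunk.head = nn.Linear(4, 10)

    block_to_param, block_next_block, last_module = {}, {}, None
    for _name, children in chunk.named_children():
        if isinstance(children, nn.ModuleList):
            # every entry of the ModuleList is treated as one broadcastable block
            for _, block in enumerate(children):
                block_to_param[block] = list(block.parameters())
                block_next_block[last_module] = block
                last_module = block
        else:
            # non-list children (embedding, head, ...) are blocks of their own
            block_to_param[children] = list(children.parameters())
            if last_module is not None:
                block_next_block[last_module] = children
            last_module = children

    for src, dst in block_next_block.items():
        print(src.__class__.__name__, "->", dst.__class__.__name__)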
@@ -921,7 +921,7 @@ class CheckpointManager:
         if self.enable_save_ckpt:
             self.try_ping_storage()

-    def quit_signal_handler(self, train_state) -> bool:
+    def quit_signal_handler(self, train_state, step_count=None) -> bool:
         """
         Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file,
         all ranks will save ckpt and exit.
@@ -933,6 +933,9 @@ class CheckpointManager:
         Returns:
             bool: whether to quit.
         """
+        if step_count is None:
+            step_count = train_state.step_count
+
         now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT

         if self.stop_file_path is None:
@@ -950,32 +953,34 @@ class CheckpointManager:
             action_step = action_step_t.item()
             del action_step_t

-            if action_step < 0 and abs(action_step) == train_state.step_count:
+            if action_step < 0 and abs(action_step) == step_count:
                 now_save_ckpt = True

-            if action_step > 0 and action_step == train_state.step_count:
+            if action_step > 0 and action_step == step_count:
                 now_break, now_save_ckpt = True, True

             if action_step != 0 and gpc.is_rank_for_log():
                 msg = "Stop" if action_step > 0 else "Save"
                 action_step = abs(action_step)
-                if train_state.step_count <= action_step:
+                if step_count <= action_step:
                     if self.feishu_address:
                         send_alert_message(
                             address=self.feishu_address,
                             message=f"training will {msg} at step_count {action_step}!\
-now step_count is {train_state.step_count}",
+now step_count is {step_count}",
                         )

         return now_break, now_save_ckpt, save_type

-    def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
+    def is_now_to_save_ckpt(self, train_state, step_count=None) -> (bool, CheckpointSaveType, bool):
+        if step_count is None:
+            step_count = train_state.step_count
         save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
-        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+        if self.oss_snapshot_freq > 1 and step_count % self.oss_snapshot_freq == 0:
             save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
-        if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
+        if step_count % self.checkpoint_every == 0 or step_count == train_state.total_steps:
             save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
-        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
+        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state, step_count)
         if save_ckpts is False:
             save_ckpts = singal_save_ckpts
             save_type = singal_save_type
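Both checkpoint methods now accept an explicit `step_count`, so a caller can probe an arbitrary step instead of the one currently recorded in `train_state`. A standalone sketch (heavily simplified: no save types, OSS snapshots, or quit-file handling; `_SketchTrainState` is a stand-in) of the defaulting behaviour and the look-ahead question it enables:

    class _SketchTrainState:
        """Stand-in for the real TrainState: only the fields this sketch needs."""

        def __init__(self, step_count, total_steps):
            self.step_count = step_count
            self.total_steps = total_steps


    def is_now_to_save_ckpt(train_state, checkpoint_every, oss_snapshot_freq=0, step_count=None):
        # mirrors the new defaulting: fall back to the live counter unless overridden
        if step_count is None:
            step_count = train_state.step_count
        save_ckpts = False
        if oss_snapshot_freq > 1 and step_count % oss_snapshot_freq == 0:
            save_ckpts = True
        if step_count % checkpoint_every == 0 or step_count == train_state.total_steps:
            save_ckpts = True
        return save_ckpts


    state = _SketchTrainState(step_count=99, total_steps=1000)
    print(is_now_to_save_ckpt(state, checkpoint_every=100))                  # False: step 99 itself
    print(is_now_to_save_ckpt(state, checkpoint_every=100, step_count=100))  # True: probing the next step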
train.py
@@ -238,7 +238,9 @@ def main(args):
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)
-        trainer_result = trainer.step()
+        trainer_result = trainer.step(
+            disable_overlap=ckpt_manager.is_now_to_save_ckpt(train_state, train_state.step_count + 1)[0]
+        )
         assert trainer_result is not None

         success_update, grad_norm_groups = trainer_result
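Passing `train_state.step_count + 1` asks whether a checkpoint will be saved once the step that is about to run has finished; if so, parameter-broadcast overlap is disabled for exactly that step so the saved weights are fully synchronized across ranks. The indexing assumes the first element of the returned tuple is the save flag, consistent with the `(bool, CheckpointSaveType, bool)` annotation shown earlier. A self-contained illustration (invented numbers, simplified save rule) of the look-ahead pattern:

    # Sketch only: which steps run with overlap disabled under a simple
    # "save every checkpoint_every steps" rule.
    def will_save(step_count, checkpoint_every=100, total_steps=1000):
        return step_count % checkpoint_every == 0 or step_count == total_steps

    for step in range(97, 102):
        disable_overlap = will_save(step + 1)  # probe the step that is about to finish
        print(f"step {step}: trainer.step(disable_overlap={disable_overlap})")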