diff --git a/internlm/core/engine.py b/internlm/core/engine.py
index eb33e35..a0e4da5 100644
--- a/internlm/core/engine.py
+++ b/internlm/core/engine.py
@@ -102,7 +102,7 @@ class Engine:
         """Sets the gradient of all parameters in the model to zero."""
         self.optimizer.zero_grad()
 
-    def step(self):
+    def step(self, disable_overlap=False):
         """
         Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
         and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
@@ -115,7 +115,7 @@ class Engine:
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
 
-        success, grad_norm = self.optimizer.step()
+        success, grad_norm = self.optimizer.step(disable_overlap=disable_overlap)
 
         if success and self._lr_scheduler is not None:
             self._lr_scheduler.step()
diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py
index b189031..ff3bb90 100644
--- a/internlm/core/trainer.py
+++ b/internlm/core/trainer.py
@@ -195,9 +195,9 @@ class Trainer:
         """Sets the gradient of all parameters in the model to zero."""
         self._engine.zero_grad()
 
-    def step(self):
+    def step(self, disable_overlap=False):
         """Executes the parameter update step."""
-        return self._engine.step()
+        return self._engine.step(disable_overlap=disable_overlap)
 
     def execute_schedule(self, data_iter: Iterable, **kwargs):
         """Runs the forward, loss computation, and backward for the model.
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index ca7449a..aad9e70 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -607,7 +607,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         return total_zero_grad_count
 
     @llm_timeout(func_name="optim_step")
-    def step(self, closure=None):
+    def step(self, closure=None, disable_overlap=False):
         """Performs a single optimization step.
 
         Args:
@@ -715,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._sync_grad()
             timer("sync_grad").stop()
 
-        state, global_norms = self._step(closure=closure, norms=total_norms)
+        state, global_norms = self._step(closure=closure, norms=total_norms, disable_overlap=disable_overlap)
         if is_profiling:
             if grad_profiling_config.get("grad_norm_profiling", False):
                 global_norms["layer_grad_norm"] = total_layer_grad_norms
@@ -728,7 +728,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return state, global_norms
 
-    def _step(self, closure=None, norms=None):
+    def _step(self, closure=None, norms=None, disable_overlap=False):
         assert closure is None, "closure is not supported by step()"
 
         # check for overflow
@@ -835,7 +835,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp16_param.data.copy_(fp32_param)
 
         torch.cuda.synchronize()
-        self.broadcast_params()
+        self.broadcast_params(disable_overlap=disable_overlap)
 
         timer("step").stop()
 
@@ -845,9 +845,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale
         return True, global_norm_groups
 
-    def broadcast_params(self):
+    def broadcast_params(self, disable_overlap=False):
         handles = []
-
+        assert all(
+            isinstance(value, list) and not value for value in self._param_bcast_sync_handler._bcast_handles.values()
+        )
         for group_id in range(self.num_param_groups):
             for rank in range(self._zero_world_size[group_id]):
                 # The following operations are performed only on the rank to which parameters are assigned.
@@ -858,12 +860,13 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
-                if self._overlap_sync_param:
-                    handle = dict()
-                    handle["tensor"] = fp16_param
-                    handle["src"] = g_rank
-                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
-                    handle["async_op"] = True
+                if self._overlap_sync_param and not disable_overlap:
+                    handle = {
+                        "tensor": fp16_param,
+                        "src": g_rank,
+                        "group": gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        "async_op": True,
+                    }
 
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
                     handle = dist.broadcast(
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 0c60329..35ce63d 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -824,18 +824,14 @@ class ParamBcastSyncHandler:
         # BUG: The order of traversal is not necessarily the order of actual fwd/bwd execution!!!
         last_module = None
         for module_idx, children in _chunk.named_children():
-            if gpc.get_global_rank() == 0:
-                print(f"children: {children.__class__.__name__}", flush=True)
             # should be the transformer block definaton in modeling_xxx.py
             if isinstance(children, nn.ModuleList):
                 assert module_idx != 0
                 # record the block that a parameter belongs to
-                for layer_idx, block in enumerate(children):
+                for _, block in enumerate(children):
                     # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                     self._block_to_param[block] = list(block.parameters())
                     self._block_next_block[last_module] = block
-                    if gpc.get_global_rank() == 0:
-                        print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
                     last_module = block
             else:
                 # record the block that a parameter belongs to
@@ -843,7 +839,6 @@ class ParamBcastSyncHandler:
                 if module_idx == 0:
                     assert "embedding" in f"{children.__class__.__name__}"
                     assert last_module is None
-                    self._block_next_block[children] = children
                 else:
                     self._block_next_block[last_module] = children
 
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 234944c..6c41f7e 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -921,7 +921,7 @@ class CheckpointManager:
         if self.enable_save_ckpt:
             self.try_ping_storage()
 
-    def quit_signal_handler(self, train_state) -> bool:
+    def quit_signal_handler(self, train_state, step_count=None) -> bool:
         """
         Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, all ranks
         will save ckpt and exit.
@@ -933,6 +933,9 @@ class CheckpointManager:
         Returns:
             bool: whether to quit.
""" + if step_count is None: + step_count = train_state.step_count + now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT if self.stop_file_path is None: @@ -950,32 +953,34 @@ class CheckpointManager: action_step = action_step_t.item() del action_step_t - if action_step < 0 and abs(action_step) == train_state.step_count: + if action_step < 0 and abs(action_step) == step_count: now_save_ckpt = True - if action_step > 0 and action_step == train_state.step_count: + if action_step > 0 and action_step == step_count: now_break, now_save_ckpt = True, True if action_step != 0 and gpc.is_rank_for_log(): msg = "Stop" if action_step > 0 else "Save" action_step = abs(action_step) - if train_state.step_count <= action_step: + if step_count <= action_step: if self.feishu_address: send_alert_message( address=self.feishu_address, message=f"training will {msg} at step_count {action_step}!\ -now step_count is {train_state.step_count}", +now step_count is {step_count}", ) return now_break, now_save_ckpt, save_type - def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool): + def is_now_to_save_ckpt(self, train_state, step_count=None) -> (bool, CheckpointSaveType, bool): + if step_count is None: + step_count = train_state.step_count save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False - if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: + if self.oss_snapshot_freq > 1 and step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT - if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps: + if step_count % self.checkpoint_every == 0 or step_count == train_state.total_steps: save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT - now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) + now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state, step_count) if save_ckpts is False: save_ckpts = singal_save_ckpts save_type = singal_save_type diff --git a/train.py b/train.py index 6874f9e..a5447a2 100644 --- a/train.py +++ b/train.py @@ -238,7 +238,9 @@ def main(args): timer("fwd-bwd").stop() # update parameters, and returns (success_update, grad_norm) - trainer_result = trainer.step() + trainer_result = trainer.step( + disable_overlap=ckpt_manager.is_now_to_save_ckpt(train_state, train_state.step_count + 1)[0] + ) assert trainer_result is not None success_update, grad_norm_groups = trainer_result