diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d1edb4f..108bb58 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -611,14 +611,15 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
 
         # compute norm for gradients in the before bucket
+        grad_profiling_config = gpc.config.get("grad_profiling", {})
         groups_norms = []
         groups_param_norms = []
         group_param_zero_grad_count = []
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if gpc.config.get("grad_norm_profiling", False):
+            if grad_profiling_config.get("grad_norm_profiling", False):
                 groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if gpc.config.get("zero_grad_profiling", False):
+            if grad_profiling_config.get("zero_grad_profiling", False):
                 group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
 
         # clear reduced grads
@@ -645,7 +646,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if gpc.config.get("grad_norm_profiling", False):
+            if grad_profiling_config.get("grad_norm_profiling", False):
                 param_norms = self._compute_param_norm_stage(
                     group_id=group_id,
                     last_bucket=True,
@@ -655,7 +656,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 total_layer_norms[group_name], total_param_norms[group_name] = compute_layer_norm(
                     param_norms=param_norms, loss_scale=self.loss_scale.item()
                 )
-            if gpc.config.get("zero_grad_profiling", False):
+            if grad_profiling_config.get("zero_grad_profiling", False):
                 zero_grad_count = self._count_zero_grads_stage(
                     group_id=group_id,
                     last_bucket=True,
@@ -672,10 +673,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if gpc.config.get("grad_norm_profiling", False):
+        if grad_profiling_config.get("grad_norm_profiling", False):
             global_norms["layer_norm"] = total_layer_norms
             global_norms["param_norm"] = total_param_norms
-        if gpc.config.get("zero_grad_profiling", False):
+        if grad_profiling_config.get("zero_grad_profiling", False):
             global_norms["layer_zero_grad"] = total_layer_zero_grad_count
             global_norms["param_zero_grad"] = total_param_zero_grad_count
 
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 4f5f1bb..dd8e190 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -158,7 +158,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     Returns:
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
""" - if gpc.config.get("grad_norm_profiling", False) or gpc.config.get("zero_grad_profiling", False): + grad_profiling_config = gpc.config.get("grad_profiling", {}) + if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get( + "zero_grad_profiling", False + ): # set the layer name as an attribute of the model parameters set_model_params_layer_name(model) @@ -526,26 +529,42 @@ def record_current_batch_training_metrics( for key, value in acc_perplex.items(): infos[key] = value - if gpc.config.get("grad_norm_profiling", False) or gpc.config.get("zero_grad_profiling", False): + grad_profiling_config = gpc.config.get("grad_profiling", {}) + if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get( + "zero_grad_profiling", False + ): layer_metrics = ["layer_norm", "layer_zero_grad"] param_metrics = ["param_norm", "param_zero_grad"] - + layer_names = grad_profiling_config.get("layers", []) for layer_metric_name in layer_metrics: layer_metric = grad_norm.get(layer_metric_name, {}) if layer_metric: - for group_name, value in layer_metric.items(): - if value: + for group_name, layer_group in layer_metric.items(): + if layer_group: title = f"{layer_metric_name}/{group_name}" - writer.add_scalars(key=title, value=value, step=train_state.step_count) + if layer_names: + filter_layer_metrics = {} + for layer_name, metric_value in layer_group.items(): + if layer_name in layer_names: + filter_layer_metrics[layer_name] = metric_value + writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count) + else: + writer.add_scalars(key=title, value=layer_group, step=train_state.step_count) del grad_norm[layer_metric_name] for param_metric_name in param_metrics: param_metric = grad_norm.get(param_metric_name, {}) if param_metric: for group_name, layer_group in param_metric.items(): - if layer_group: - for param_name, param_group in layer_group.items(): - title = f"{param_name}/{group_name}_{param_metric_name}" + for param_name, param_group in layer_group.items(): + title = f"{param_name}/{group_name}_{param_metric_name}" + if layer_names: + filter_param_group = {} + for layer_name, metric_value in param_group.items(): + if layer_name in layer_names: + filter_param_group[layer_name] = param_group[layer_name] + writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count) + else: writer.add_scalars(key=title, value=param_group, step=train_state.step_count) del grad_norm[param_metric_name]