diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index e96d2d9..5bcd5ca 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -355,6 +355,9 @@ def args_sanity_check():
             f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
         )
 
+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
+
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 98ecf84..3fc3338 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -639,14 +639,18 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_param_norms = []
         group_param_zero_grad_count = []
         group_vocab_norms = []
+        batch_count = gpc.config.batch_count
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
-            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-                group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
@@ -673,54 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                     last_stage=True,
                     previous_norm=groups_norms[group_id],
                 )
-                if grad_profiling_config.get("grad_norm_profiling", False):
-                    param_norms = self._compute_param_norm_stage(
-                        group_id=group_id,
-                        last_bucket=True,
-                        last_stage=True,
-                        previous_param_norms=groups_param_norms[group_id],
-                    )
-                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
-                        param_norms=param_norms, loss_scale=self.loss_scale.item()
-                    )
-                if grad_profiling_config.get("zero_grad_profiling", False):
-                    zero_grad_count = self._count_zero_grads_stage(
-                        group_id=group_id,
-                        last_bucket=True,
-                        last_stage=True,
-                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
-                    )
-                    (
-                        total_layer_zero_grad_count[group_name],
-                        total_param_zero_grad_count[group_name],
-                    ) = compute_layer_zero_grad_count(zero_grad_count)
-                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
-                        group_id=group_id,
-                        last_bucket=True,
-                        last_stage=True,
-                        previous_vocab_grad_norm=group_vocab_norms[group_id],
-                    )
-                    inf_mask = vocab_grad_norms == -1
-                    nan_mask = vocab_grad_norms == -2
-                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
-                    vocab_grad_norms[inf_mask] = -1
-                    vocab_grad_norms[nan_mask] = -2
-                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
+                if is_profiling:
+                    if grad_profiling_config.get("grad_norm_profiling", False):
+                        param_norms = self._compute_param_norm_stage(
+                            group_id=group_id,
+                            last_bucket=True,
+                            last_stage=True,
+                            previous_param_norms=groups_param_norms[group_id],
+                        )
+                        total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
+                            param_norms=param_norms, loss_scale=self.loss_scale.item()
+                        )
+                    if grad_profiling_config.get("zero_grad_profiling", False):
+                        zero_grad_count = self._count_zero_grads_stage(
+                            group_id=group_id,
+                            last_bucket=True,
+                            last_stage=True,
+                            previous_zero_grad_count=group_param_zero_grad_count[group_id],
+                        )
+                        (
+                            total_layer_zero_grad_count[group_name],
+                            total_param_zero_grad_count[group_name],
+                        ) = compute_layer_zero_grad_count(zero_grad_count)
+                    if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                        vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                            group_id=group_id,
+                            last_bucket=True,
+                            last_stage=True,
+                            previous_vocab_grad_norm=group_vocab_norms[group_id],
+                        )
+                        inf_mask = vocab_grad_norms == -1
+                        nan_mask = vocab_grad_norms == -2
+                        vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                        vocab_grad_norms[inf_mask] = -1
+                        vocab_grad_norms[nan_mask] = -2
+                        total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
-        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-            global_norms["vocab_grad_norm"] = total_vocab_grad_norms
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms
 
         return state, global_norms
 
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index b8d005a..db9eefa 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -241,9 +241,11 @@ def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
            parallel_grads[param_name].append(g.data.float())
        elif only_output:
            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
-           grad_profiling_config = gpc.config.get("grad_profiling", {})
-           layer_names = grad_profiling_config.get("layers", [])
-           if gpc.config.model["vocab_size"] == g.shape[0] and param_name.split("-")[0] in layer_names:
+           if (
+               gpc.config.model["vocab_size"] == g.shape[0]
+               and gpc.config.model["hidden_size"] == g.shape[1]
+               and "embedding" not in param_name.lower()
+           ):
                parallel_grads.append(g.data.float())
        else:
            parallel_grads.append(g.data.float())
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index fab779a..9953b5b 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -521,55 +521,39 @@ def record_current_batch_training_metrics(
            infos[key] = value
 
        grad_profiling_config = gpc.config.get("grad_profiling", {})
-       if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-           "zero_grad_profiling", False
-       ):
-           layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-           param_metrics = ["param_grad_norm", "param_zero_grad"]
+       interval_steps = grad_profiling_config.get("interval_steps", 1)
+       if batch_count % interval_steps == 0:
+           layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+           param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
            layer_names = grad_profiling_config.get("layers", [])
-           for layer_metric_name in layer_metrics:
-               layer_metric = grad_norm.get(layer_metric_name, {})
-               if layer_metric:
-                   for group_name, layer_group in layer_metric.items():
-                       if layer_group:
-                           title = f"{layer_metric_name}/{group_name}"
-                           if layer_names:
-                               filter_layer_metrics = {}
-                               for layer_name, metric_value in layer_group.items():
-                                   if layer_name in layer_names:
-                                       filter_layer_metrics[layer_name] = metric_value
-                               if filter_layer_metrics:
-                                   writer.add_scalars(
-                                       key=title, value=filter_layer_metrics, step=train_state.step_count
-                                   )
-                           else:
-                               writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                   del grad_norm[layer_metric_name]
-
-           for param_metric_name in param_metrics:
-               param_metric = grad_norm.get(param_metric_name, {})
-               if param_metric:
-                   for group_name, layer_group in param_metric.items():
-                       for param_name, param_group in layer_group.items():
-                           title = f"{param_name}/{group_name}_{param_metric_name}"
-                           if layer_names:
-                               filter_param_group = {}
-                               for layer_name, metric_value in param_group.items():
-                                   if layer_name in layer_names:
-                                       filter_param_group[layer_name] = param_group[layer_name]
-                               if filter_param_group:
-                                   writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                           else:
-                               writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-                   del grad_norm[param_metric_name]
-       if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-           local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
-           os.makedirs(local_save_path, exist_ok=True)
-           local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
-           vocab_grad_norms = grad_norm.get("vocab_grad_norm", {})
-           if vocab_grad_norms:
-               with open(local_save_file, "ab+") as vocab_f:
-                   pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+           for metric_name in layer_metrics:
+               metric = grad_norm.get(metric_name)
+               for group_name, layer_group in metric.items():
+                   title = f"{metric_name}/{group_name}"
+                   metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                   if metrics:
+                       writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+               del grad_norm[metric_name]
+           for metric_name in param_metrics:
+               metric = grad_norm.get(metric_name)
+               for group_name, layer_group in metric.items():
+                   for param_name, param_group in layer_group.items():
+                       title = f"{param_name}/{group_name}_{metric_name}"
+                       metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                       if metrics:
+                           writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+               del grad_norm[metric_name]
+           if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+               local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
+               os.makedirs(local_save_path, exist_ok=True)
+               local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
+               vocab_grad_norms = grad_norm.get("vocab_grad_norm")
+               if vocab_grad_norms:
+                   try:
+                       with open(local_save_file, "ab+") as vocab_f:
+                           pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                   except IOError as e:
+                       logger.warning(f"Error saving vocab_grad_norm: {e}")
                del grad_norm["vocab_grad_norm"]
 
        line = ""
diff --git a/train.py b/train.py
index 35e39fa..e90fdb8 100644
--- a/train.py
+++ b/train.py
@@ -197,6 +197,7 @@ def main(args):
        empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
        start_time = time.time()
        timer("one-batch").start()
+       gpc.config.batch_count = batch_count
 
        # load batch data
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)