diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 82a4a21..491e2b0 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -356,6 +356,9 @@ def args_sanity_check():
             f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
         )

+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
+
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 491e59c..01b40ab 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -39,6 +39,7 @@ from .utils import (
     compute_layer_zero_grad_count,
     compute_norm,
     compute_param_norm,
+    compute_vocab_grad_norm,
     compute_zero_grad_count,
 )

@@ -563,6 +564,28 @@ class HybridZeroOptimizer(BaseOptimizer):
         )
         return total_param_norms

+    def _compute_vocab_grad_norm_stage(
+        self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_vocab_grad_norm=None
+    ):
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+        if len(params) == 0:
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
+
+        vocab_grad_norm = None
+
+        if self._clip_grad_norm > 0:
+            vocab_grad_norm = compute_vocab_grad_norm(
+                grads,
+                params,
+                last_stage=last_stage,
+                previous_vocab_grad_norm=previous_vocab_grad_norm,
+                zero_mode=self._broadcast_parallel_mode[group_id],
+            )
+
+        return vocab_grad_norm
+
     def _count_zero_grads_stage(
         self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_zero_grad_count=None
     ):
@@ -615,12 +638,19 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_norms = []
         groups_param_norms = []
         group_param_zero_grad_count = []
+        group_vocab_norms = []
+        batch_count = gpc.config.get("batch_count")
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0 if batch_count is not None else False
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))

         # clear reduced grads
         # grads in the last bucket is reduced
@@ -637,6 +667,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         total_layer_grad_norms = {}
         total_param_zero_grad_count = {}
         total_layer_zero_grad_count = {}
+        total_vocab_grad_norms = {}
         for group_id in range(self.num_param_groups):
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
@@ -646,39 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_param_norms=groups_param_norms[group_id],
-                )
-                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
-                    param_norms=param_norms, loss_scale=self.loss_scale.item()
-                )
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                zero_grad_count = self._count_zero_grads_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_zero_grad_count=group_param_zero_grad_count[group_id],
-                )
-                (
-                    total_layer_zero_grad_count[group_name],
-                    total_param_zero_grad_count[group_name],
-                ) = compute_layer_zero_grad_count(zero_grad_count)
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_param_norms=groups_param_norms[group_id],
+                    )
+                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
+                        param_norms=param_norms, loss_scale=self.loss_scale.item()
+                    )
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    zero_grad_count = self._count_zero_grads_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
+                    )
+                    (
+                        total_layer_zero_grad_count[group_name],
+                        total_param_zero_grad_count[group_name],
+                    ) = compute_layer_zero_grad_count(zero_grad_count)
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_vocab_grad_norm=group_vocab_norms[group_id],
+                    )
+                    inf_mask = vocab_grad_norms == -1
+                    nan_mask = vocab_grad_norms == -2
+                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                    vocab_grad_norms[inf_mask] = -1
+                    vocab_grad_norms[nan_mask] = -2
+                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")

         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms

         return state, global_norms
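The profiling behaviour above is gated by the optional `grad_profiling` section of the training config (read via `gpc.config.get("grad_profiling", {})`). A minimal sketch of what that section might look like, assuming the usual Python-dict config style; the key names (`grad_norm_profiling`, `zero_grad_profiling`, `vocab_grad_norm_profiling`, `interval_steps`, `layers`) come from this diff, while the values and layer names are illustrative placeholders:

```python
# Hypothetical excerpt of a training config; only the key names are taken from this diff.
grad_profiling = dict(
    grad_norm_profiling=True,        # per-layer / per-param gradient norms
    zero_grad_profiling=True,        # per-layer / per-param zero-gradient counts
    vocab_grad_norm_profiling=True,  # per-vocab-token gradient norms (added by this diff)
    interval_steps=5,                # profile every N batches; defaults to 1
    layers=["blocks.0", "blocks.1"],  # optional layer-name filter for TensorBoard output
)
```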
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 0cc7451..db9eefa 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -218,7 +218,17 @@ def calc_zero_grad(grads):
     return torch.tensor([zero_count, grad_size])


-def reduce_grads(gradients, parameters, fine_grained=False):
+def get_norm(grads, norm_type, enable_cuda_kernels):
+    if norm_type == inf:
+        grad_norm = max(g.data.abs().max() for g in grads)
+    elif norm_type == 2.0 and enable_cuda_kernels:
+        grad_norm = calc_l2_norm(grads) ** norm_type
+    else:
+        grad_norm = calc_lp(grads, norm_type)
+    return grad_norm
+
+
+def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
     parallel_grads = []
     if fine_grained:
         parallel_grads = {}
@@ -229,6 +239,14 @@ def reduce_grads(gradients, parameters, fine_grained=False):
             if param_name not in parallel_grads:
                 parallel_grads[param_name] = []
             parallel_grads[param_name].append(g.data.float())
+        elif only_output:
+            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
+                parallel_grads.append(g.data.float())
         else:
             parallel_grads.append(g.data.float())

@@ -306,10 +324,7 @@ def compute_norm(
     else:
         tensor_parallel_grads = reduce_grads(gradients, parameters)

-        if norm_type == 2.0 and enable_cuda_kernels:
-            tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
-        else:
-            tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
+        tensor_parallel_norm = get_norm(tensor_parallel_grads, norm_type, enable_cuda_kernels)

     # If norm is type of float, then we convert them into torch.Tensor.
     tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
@@ -359,6 +374,59 @@ def compute_norm(
     return total_norm


+def compute_vocab_grad_norm(
+    gradients,
+    parameters,
+    last_stage=False,
+    previous_vocab_grad_norm=None,
+    norm_type=2,
+    zero_mode=ParallelMode.ZERO1,
+):
+    enable_cuda_kernels = gradients[0].device.type == "cuda"
+    # Norm parameters.
+    norm_type = float(norm_type)
+    vocab_size = gpc.config.model["vocab_size"]
+
+    param_grads = reduce_grads(gradients, parameters, only_output=True)
+
+    vocab_grad_norm = torch.zeros((vocab_size,), dtype=torch.float32).to(get_current_device())
+    if param_grads:
+        for grad in param_grads:
+            # get grad norm of each vocab
+            for i in range(vocab_size):
+                cur_vocab_grad_norm = get_norm([grad[i, :]], norm_type, enable_cuda_kernels)[0]
+                vocab_grad_norm[i] += get_tensor_norm(cur_vocab_grad_norm, move_to_cuda=True)
+
+    if last_stage is False:
+        return vocab_grad_norm
+
+    if previous_vocab_grad_norm is not None:
+        vocab_grad_norm = vocab_grad_norm + previous_vocab_grad_norm
+
+    if gpc.is_initialized(ParallelMode.MODEL):
+        dist.all_reduce(
+            vocab_grad_norm,
+            op=dist.ReduceOp.SUM,
+            group=gpc.get_group(ParallelMode.MODEL),
+        )
+
+    dist.all_reduce(vocab_grad_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
+
+    if zero_mode == ParallelMode.EXPERT_DATA:
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = vocab_grad_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        vocab_grad_norm = scaled_norm_tensor.item()
+
+    # Scale.
+    vocab_grad_norm[vocab_grad_norm == float("inf")] = -1
+    vocab_grad_norm[vocab_grad_norm == -float("inf")] = -1
+    vocab_grad_norm[torch.isnan(vocab_grad_norm)] = -2
+
+    return vocab_grad_norm
+
+
 def compute_param_metric(
     gradients,
     parameters,
@@ -384,12 +452,7 @@ def compute_param_metric(

     for param_name, grads in param_grads.items():
         if metric_type == "norm":
-            if norm_type == inf:
-                param_metric = max(g.data.abs().max() for g in grads)
-            elif norm_type == 2.0 and enable_cuda_kernels:
-                param_metric = calc_l2_norm(grads) ** norm_type
-            else:
-                param_metric = calc_lp(grads, norm_type)
+            param_metric = get_norm(grads, norm_type, enable_cuda_kernels)
             param_metrics[param_name] = param_metric.item() if torch.is_tensor(param_metric) else param_metric
         elif metric_type == "zero_grad":
             param_zero_grad_count = calc_zero_grad(grads)
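To make the per-vocab quantity concrete: `reduce_grads(..., only_output=True)` keeps only gradients shaped `(vocab_size, hidden_size)` whose parameter name does not contain "embedding" (in practice the output head), and `compute_vocab_grad_norm` then accumulates, for each vocab id `i`, the squared L2 norm of row `i` of that gradient; the optimizer later applies `** 0.5` and divides by the loss scale. A standalone sketch of that per-row quantity in plain PyTorch, illustrative only and not the fused-kernel path used above:

```python
# Illustrative sketch: per-vocab-token squared gradient norms for an
# output-projection gradient of shape (vocab_size, hidden_size).
import torch

vocab_size, hidden_size = 8, 4               # toy sizes
grad = torch.randn(vocab_size, hidden_size)  # stands in for the output weight's gradient

# Squared L2 norm of each row, i.e. what vocab_grad_norm[i] accumulates
# before the later ** 0.5 / loss_scale step in the optimizer.
per_vocab_sq_norm = grad.float().pow(2).sum(dim=1)
print(per_vocab_sq_norm.shape)               # torch.Size([8])
```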
""" grad_profiling_config = gpc.config.get("grad_profiling", {}) - if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get( - "zero_grad_profiling", False + if ( + grad_profiling_config.get("grad_norm_profiling", False) + or grad_profiling_config.get("zero_grad_profiling", False) + or grad_profiling_config.get("vocab_grad_norm_profiling", False) ): # set the layer name as an attribute of the model parameters set_model_params_layer_name(model) @@ -522,43 +526,40 @@ def record_current_batch_training_metrics( infos[key] = value grad_profiling_config = gpc.config.get("grad_profiling", {}) - if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get( - "zero_grad_profiling", False - ): - layer_metrics = ["layer_grad_norm", "layer_zero_grad"] - param_metrics = ["param_grad_norm", "param_zero_grad"] + interval_steps = grad_profiling_config.get("interval_steps", 1) + if batch_count % interval_steps == 0: + layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm] + param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm] layer_names = grad_profiling_config.get("layers", []) - for layer_metric_name in layer_metrics: - layer_metric = grad_norm.get(layer_metric_name, {}) - if layer_metric: - for group_name, layer_group in layer_metric.items(): - if layer_group: - title = f"{layer_metric_name}/{group_name}" - if layer_names: - filter_layer_metrics = {} - for layer_name, metric_value in layer_group.items(): - if layer_name in layer_names: - filter_layer_metrics[layer_name] = metric_value - writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count) - else: - writer.add_scalars(key=title, value=layer_group, step=train_state.step_count) - del grad_norm[layer_metric_name] - - for param_metric_name in param_metrics: - param_metric = grad_norm.get(param_metric_name, {}) - if param_metric: - for group_name, layer_group in param_metric.items(): - for param_name, param_group in layer_group.items(): - title = f"{param_name}/{group_name}_{param_metric_name}" - if layer_names: - filter_param_group = {} - for layer_name, metric_value in param_group.items(): - if layer_name in layer_names: - filter_param_group[layer_name] = param_group[layer_name] - writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count) - else: - writer.add_scalars(key=title, value=param_group, step=train_state.step_count) - del grad_norm[param_metric_name] + for metric_name in layer_metrics: + metric = grad_norm.get(metric_name) + for group_name, layer_group in metric.items(): + title = f"{metric_name}/{group_name}" + metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names} + if metrics: + writer.add_scalars(key=title, value=metrics, step=train_state.step_count) + del grad_norm[metric_name] + for metric_name in param_metrics: + metric = grad_norm.get(metric_name) + for group_name, layer_group in metric.items(): + for param_name, param_group in layer_group.items(): + title = f"{param_name}/{group_name}_{metric_name}" + metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names} + if metrics: + writer.add_scalars(key=title, value=metrics, step=train_state.step_count) + del grad_norm[metric_name] + if grad_profiling_config.get("vocab_grad_norm_profiling", False): + local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm" + os.makedirs(local_save_path, exist_ok=True) + 
diff --git a/train.py b/train.py
index 9f0c1ac..6874f9e 100644
--- a/train.py
+++ b/train.py
@@ -197,6 +197,7 @@ def main(args):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
         start_time = time.time()
         timer("one-batch").start()
+        gpc.config.batch_count = batch_count

         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)