diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 9bead52..62970d9 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -24,6 +24,14 @@ def module_has_fp32_attr(module: nn.Module):
     return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
 
 
+def set_output_attr_to_module(module: nn.Module):
+    setattr(module, "is_output", True)
+
+
+def module_is_output(module: nn.Module):
+    return hasattr(module, "is_output") and getattr(module, "is_output")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -189,3 +197,8 @@ class NaiveAMPModel(nn.Module):
                 sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
+            if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
+                sub_module.to(fp32_dtype)
+                torch.backends.cudnn.allow_tf32 = True
+                torch.backends.cuda.matmul.allow_tf32 = True
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
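For context on the `output_tf32` hook registered above: when the flag is set, the tagged output module is kept in fp32 while PyTorch's TF32 paths are allowed for matmul and cuDNN kernels. The following is a minimal standalone sketch, not part of the patch, of what that combination amounts to for an output projection; the layer, sizes, and dtypes are illustrative and a CUDA device is assumed.

```python
import torch
from torch import nn

# assumption: a CUDA device is available; TF32 only takes effect on Ampere-or-newer GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

head = nn.Linear(4096, 32000).cuda().float()  # output projection kept in fp32
hidden = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)

# the registered pre-forward hook casts inputs up to fp32 before the fp32/TF32 matmul
logits = head(hidden.float())
print(logits.dtype)  # torch.float32
```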
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index fbcb6f7..92d08f3 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path):
     return match_idxes[0]
 
 
-def unpack_data(input_ids, cu_seqlens):
-    """
-    input_ids: (n, packed_length)
-    Return:
-    output: (batch_size, max_length)
+def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
     """
+    input_ids: if input_ids is not type_ids, its shape is (1, packed_length);
+               otherwise, its shape is (micro_num, packed_length)
+    is_type_ids: whether input_ids holds type_ids
+    Return:
+    output: if input_ids is not type_ids, the shape is (micro_bsz, max_length);
+            otherwise, the shape is (micro_num, micro_bsz, max_length)
+    """
 
     bsz = input_ids.shape[0]
 
     num_sequence = gpc.config.data["micro_bsz"]
@@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens):
             output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
         outputs[i] = output
 
-    if bsz == 1:
+    # if input_ids is not type_ids, squeeze the first dimension when it equals 1
+    if bsz == 1 and not is_type_ids:
         outputs = outputs.squeeze(0)
 
     return outputs
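The revised docstring is easier to check against a concrete case. The sketch below is standalone and does not call `unpack_data` itself (which reads `micro_bsz` from the global config); the values are illustrative. It shows how one packed row is scattered into padded per-sequence rows using `cu_seqlens`.

```python
import torch

packed = torch.tensor([[1, 2, 3, 4, 5, 6]])  # (1, packed_length)
cu_seqlens = torch.tensor([0, 2, 6])         # two sequences of lengths 2 and 4
micro_bsz, max_length = 2, 4

out = torch.zeros(micro_bsz, max_length, dtype=packed.dtype)
for j in range(micro_bsz):
    start, end = int(cu_seqlens[j]), int(cu_seqlens[j + 1])
    out[j, : end - start] = packed[0, start:end]

print(out)
# tensor([[1, 2, 0, 0],
#         [3, 4, 5, 6]])
```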
grad_profiling_config.get("zero_grad_profiling", False): - group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id)) + if is_profiling: + if grad_profiling_config.get("grad_norm_profiling", False): + groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id)) + if grad_profiling_config.get("zero_grad_profiling", False): + group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id)) + if grad_profiling_config.get("vocab_grad_norm_profiling", False): + group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id)) # clear reduced grads # grads in the last bucket is reduced @@ -637,6 +667,7 @@ class HybridZeroOptimizer(BaseOptimizer): total_layer_grad_norms = {} total_param_zero_grad_count = {} total_layer_zero_grad_count = {} + total_vocab_grad_norms = {} for group_id in range(self.num_param_groups): group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default" group_name = f"{group_id}_{group_name}" @@ -646,39 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer): last_stage=True, previous_norm=groups_norms[group_id], ) - if grad_profiling_config.get("grad_norm_profiling", False): - param_norms = self._compute_param_norm_stage( - group_id=group_id, - last_bucket=True, - last_stage=True, - previous_param_norms=groups_param_norms[group_id], - ) - total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm( - param_norms=param_norms, loss_scale=self.loss_scale.item() - ) - if grad_profiling_config.get("zero_grad_profiling", False): - zero_grad_count = self._count_zero_grads_stage( - group_id=group_id, - last_bucket=True, - last_stage=True, - previous_zero_grad_count=group_param_zero_grad_count[group_id], - ) - ( - total_layer_zero_grad_count[group_name], - total_param_zero_grad_count[group_name], - ) = compute_layer_zero_grad_count(zero_grad_count) + if is_profiling: + if grad_profiling_config.get("grad_norm_profiling", False): + param_norms = self._compute_param_norm_stage( + group_id=group_id, + last_bucket=True, + last_stage=True, + previous_param_norms=groups_param_norms[group_id], + ) + total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm( + param_norms=param_norms, loss_scale=self.loss_scale.item() + ) + if grad_profiling_config.get("zero_grad_profiling", False): + zero_grad_count = self._count_zero_grads_stage( + group_id=group_id, + last_bucket=True, + last_stage=True, + previous_zero_grad_count=group_param_zero_grad_count[group_id], + ) + ( + total_layer_zero_grad_count[group_name], + total_param_zero_grad_count[group_name], + ) = compute_layer_zero_grad_count(zero_grad_count) + if grad_profiling_config.get("vocab_grad_norm_profiling", False): + vocab_grad_norms = self._compute_vocab_grad_norm_stage( + group_id=group_id, + last_bucket=True, + last_stage=True, + previous_vocab_grad_norm=group_vocab_norms[group_id], + ) + inf_mask = vocab_grad_norms == -1 + nan_mask = vocab_grad_norms == -2 + vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item() + vocab_grad_norms[inf_mask] = -1 + vocab_grad_norms[nan_mask] = -2 + total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu") timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() state, global_norms = self._step(closure=closure, norms=total_norms) - if grad_profiling_config.get("grad_norm_profiling", False): - global_norms["layer_grad_norm"] = total_layer_grad_norms - global_norms["param_grad_norm"] = 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 491e59c..01b40ab 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -39,6 +39,7 @@ from .utils import (
     compute_layer_zero_grad_count,
     compute_norm,
     compute_param_norm,
+    compute_vocab_grad_norm,
     compute_zero_grad_count,
 )
 
@@ -563,6 +564,28 @@ class HybridZeroOptimizer(BaseOptimizer):
         )
         return total_param_norms
 
+    def _compute_vocab_grad_norm_stage(
+        self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_vocab_grad_norm=None
+    ):
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+        if len(params) == 0:
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
+
+        vocab_grad_norm = None
+
+        if self._clip_grad_norm > 0:
+            vocab_grad_norm = compute_vocab_grad_norm(
+                grads,
+                params,
+                last_stage=last_stage,
+                previous_vocab_grad_norm=previous_vocab_grad_norm,
+                zero_mode=self._broadcast_parallel_mode[group_id],
+            )
+
+        return vocab_grad_norm
+
     def _count_zero_grads_stage(
         self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_zero_grad_count=None
     ):
@@ -615,12 +638,19 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_norms = []
         groups_param_norms = []
         group_param_zero_grad_count = []
+        group_vocab_norms = []
+        batch_count = gpc.config.get("batch_count")
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0 if batch_count is not None else False
        for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
@@ -637,6 +667,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         total_layer_grad_norms = {}
         total_param_zero_grad_count = {}
         total_layer_zero_grad_count = {}
+        total_vocab_grad_norms = {}
         for group_id in range(self.num_param_groups):
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
@@ -646,39 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_param_norms=groups_param_norms[group_id],
-                )
-                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
-                    param_norms=param_norms, loss_scale=self.loss_scale.item()
-                )
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                zero_grad_count = self._count_zero_grads_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_zero_grad_count=group_param_zero_grad_count[group_id],
-                )
-                (
-                    total_layer_zero_grad_count[group_name],
-                    total_param_zero_grad_count[group_name],
-                ) = compute_layer_zero_grad_count(zero_grad_count)
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_param_norms=groups_param_norms[group_id],
+                    )
+                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
+                        param_norms=param_norms, loss_scale=self.loss_scale.item()
+                    )
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    zero_grad_count = self._count_zero_grads_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
+                    )
+                    (
+                        total_layer_zero_grad_count[group_name],
+                        total_param_zero_grad_count[group_name],
+                    ) = compute_layer_zero_grad_count(zero_grad_count)
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_vocab_grad_norm=group_vocab_norms[group_id],
+                    )
+                    inf_mask = vocab_grad_norms == -1
+                    nan_mask = vocab_grad_norms == -2
+                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                    vocab_grad_norms[inf_mask] = -1
+                    vocab_grad_norms[nan_mask] = -2
+                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms
 
         return state, global_norms
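After the all-reduces, `_compute_vocab_grad_norm_stage` returns accumulated squared per-token norms in which inf/nan entries have already been replaced by the sentinels -1/-2; the block above then takes the square root, removes the loss scale, and restores the sentinels. A small standalone sketch of that post-processing, with an illustrative loss scale and values:

```python
import torch

loss_scale = 65536.0
vocab_grad_norms = torch.tensor([4.0, 0.0, -1.0, -2.0])  # sum of squared grads per token id

inf_mask = vocab_grad_norms == -1
nan_mask = vocab_grad_norms == -2
vocab_grad_norms = vocab_grad_norms**0.5 / loss_scale  # real norm, unscaled
vocab_grad_norms[inf_mask] = -1                         # keep the sentinels intact
vocab_grad_norms[nan_mask] = -2
print(vocab_grad_norms)  # tensor([3.0518e-05, 0.0000e+00, -1.0000e+00, -2.0000e+00])
```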
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 0cc7451..db9eefa 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -218,7 +218,17 @@ def calc_zero_grad(grads):
     return torch.tensor([zero_count, grad_size])
 
 
-def reduce_grads(gradients, parameters, fine_grained=False):
+def get_norm(grads, norm_type, enable_cuda_kernels):
+    if norm_type == inf:
+        grad_norm = max(g.data.abs().max() for g in grads)
+    elif norm_type == 2.0 and enable_cuda_kernels:
+        grad_norm = calc_l2_norm(grads) ** norm_type
+    else:
+        grad_norm = calc_lp(grads, norm_type)
+    return grad_norm
+
+
+def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
     parallel_grads = []
     if fine_grained:
         parallel_grads = {}
@@ -229,6 +239,14 @@ def reduce_grads(gradients, parameters, fine_grained=False):
             if param_name not in parallel_grads:
                 parallel_grads[param_name] = []
             parallel_grads[param_name].append(g.data.float())
+        elif only_output:
+            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
+                parallel_grads.append(g.data.float())
         else:
             parallel_grads.append(g.data.float())
 
@@ -306,10 +324,7 @@ def compute_norm(
     else:
         tensor_parallel_grads = reduce_grads(gradients, parameters)
 
-        if norm_type == 2.0 and enable_cuda_kernels:
-            tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
-        else:
-            tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
+        tensor_parallel_norm = get_norm(tensor_parallel_grads, norm_type, enable_cuda_kernels)
 
     # If norm is type of float, then we convert them into torch.Tensor.
     tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
@@ -359,6 +374,59 @@ def compute_norm(
     return total_norm
 
 
+def compute_vocab_grad_norm(
+    gradients,
+    parameters,
+    last_stage=False,
+    previous_vocab_grad_norm=None,
+    norm_type=2,
+    zero_mode=ParallelMode.ZERO1,
+):
+    enable_cuda_kernels = gradients[0].device.type == "cuda"
+    # Norm parameters.
+    norm_type = float(norm_type)
+    vocab_size = gpc.config.model["vocab_size"]
+
+    param_grads = reduce_grads(gradients, parameters, only_output=True)
+
+    vocab_grad_norm = torch.zeros((vocab_size,), dtype=torch.float32).to(get_current_device())
+    if param_grads:
+        for grad in param_grads:
+            # accumulate the grad norm of each vocab token id
+            for i in range(vocab_size):
+                cur_vocab_grad_norm = get_norm([grad[i, :]], norm_type, enable_cuda_kernels)[0]
+                vocab_grad_norm[i] += get_tensor_norm(cur_vocab_grad_norm, move_to_cuda=True)
+
+    if last_stage is False:
+        return vocab_grad_norm
+
+    if previous_vocab_grad_norm is not None:
+        vocab_grad_norm = vocab_grad_norm + previous_vocab_grad_norm
+
+    if gpc.is_initialized(ParallelMode.MODEL):
+        dist.all_reduce(
+            vocab_grad_norm,
+            op=dist.ReduceOp.SUM,
+            group=gpc.get_group(ParallelMode.MODEL),
+        )
+
+    dist.all_reduce(vocab_grad_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
+
+    if zero_mode == ParallelMode.EXPERT_DATA:
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = vocab_grad_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        vocab_grad_norm = scaled_norm_tensor.item()
+
+    # mark inf/nan entries with sentinels (-1 for inf, -2 for nan); scaling happens in the optimizer
+    vocab_grad_norm[vocab_grad_norm == float("inf")] = -1
+    vocab_grad_norm[vocab_grad_norm == -float("inf")] = -1
+    vocab_grad_norm[torch.isnan(vocab_grad_norm)] = -2
+
+    return vocab_grad_norm
+
+
 def compute_param_metric(
     gradients,
     parameters,
@@ -384,12 +452,7 @@ def compute_param_metric(
 
     for param_name, grads in param_grads.items():
         if metric_type == "norm":
-            if norm_type == inf:
-                param_metric = max(g.data.abs().max() for g in grads)
-            elif norm_type == 2.0 and enable_cuda_kernels:
-                param_metric = calc_l2_norm(grads) ** norm_type
-            else:
-                param_metric = calc_lp(grads, norm_type)
+            param_metric = get_norm(grads, norm_type, enable_cuda_kernels)
             param_metrics[param_name] = param_metric.item() if torch.is_tensor(param_metric) else param_metric
         elif metric_type == "zero_grad":
            param_zero_grad_count = calc_zero_grad(grads)
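The per-token-id quantity that `compute_vocab_grad_norm` accumulates is a row-wise norm over the (vocab_size, hidden_size) gradient of the output projection. The standalone sketch below illustrates the idea with made-up shapes; the patch computes it row by row through `get_norm`, which for the default norm type accumulates squared norms that are square-rooted later in the optimizer, equivalent to the vectorized form shown here.

```python
import torch

vocab_size, hidden_size = 8, 4
head_grad = torch.randn(vocab_size, hidden_size)  # gradient of the output projection

# row by row, mirroring the loop in the patch (squared L2 norm per vocab row)
per_token_sq = torch.stack([head_grad[i, :].pow(2).sum() for i in range(vocab_size)])

# the same quantity computed in one shot
assert torch.allclose(per_token_sq, head_grad.pow(2).sum(dim=1))
```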
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 1e36a21..474bfd2 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -2,6 +2,8 @@
 # -*- encoding: utf-8 -*-
 
 import functools
+import os
+import pickle
 import time
 from functools import partial
 from typing import Callable, Iterable, Optional, Union
@@ -49,10 +51,11 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, launch_time
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +114,10 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check whether the norm modules carry the IS_SEQUENCE_PARALLEL attribute
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
+
     return model
 
 
@@ -159,8 +166,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
     grad_profiling_config = gpc.config.get("grad_profiling", {})
-    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-        "zero_grad_profiling", False
+    if (
+        grad_profiling_config.get("grad_norm_profiling", False)
+        or grad_profiling_config.get("zero_grad_profiling", False)
+        or grad_profiling_config.get("vocab_grad_norm_profiling", False)
     ):
         # set the layer name as an attribute of the model parameters
         set_model_params_layer_name(model)
@@ -359,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
     if batch[0].get("type_ids", None) is not None:
         # if use_flash_attn is False, we need to unpack type_ids
         if not gpc.config.model.use_flash_attn:
-            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
+            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
 
     return batch, train_iter
 
@@ -517,43 +526,40 @@ def record_current_batch_training_metrics(
            infos[key] = value
 
        grad_profiling_config = gpc.config.get("grad_profiling", {})
-       if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-           "zero_grad_profiling", False
-       ):
-           layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-           param_metrics = ["param_grad_norm", "param_zero_grad"]
+       interval_steps = grad_profiling_config.get("interval_steps", 1)
+       if batch_count % interval_steps == 0:
+           layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+           param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
            layer_names = grad_profiling_config.get("layers", [])
-           for layer_metric_name in layer_metrics:
-               layer_metric = grad_norm.get(layer_metric_name, {})
-               if layer_metric:
-                   for group_name, layer_group in layer_metric.items():
-                       if layer_group:
-                           title = f"{layer_metric_name}/{group_name}"
-                           if layer_names:
-                               filter_layer_metrics = {}
-                               for layer_name, metric_value in layer_group.items():
-                                   if layer_name in layer_names:
-                                       filter_layer_metrics[layer_name] = metric_value
-                               writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count)
-                           else:
-                               writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                   del grad_norm[layer_metric_name]
-
-           for param_metric_name in param_metrics:
-               param_metric = grad_norm.get(param_metric_name, {})
-               if param_metric:
-                   for group_name, layer_group in param_metric.items():
-                       for param_name, param_group in layer_group.items():
-                           title = f"{param_name}/{group_name}_{param_metric_name}"
-                           if layer_names:
-                               filter_param_group = {}
-                               for layer_name, metric_value in param_group.items():
-                                   if layer_name in layer_names:
-                                       filter_param_group[layer_name] = param_group[layer_name]
-                               writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                           else:
-                               writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-               del grad_norm[param_metric_name]
+           for metric_name in layer_metrics:
+               metric = grad_norm.get(metric_name)
+               for group_name, layer_group in metric.items():
+                   title = f"{metric_name}/{group_name}"
+                   metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                   if metrics:
+                       writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+               del grad_norm[metric_name]
+           for metric_name in param_metrics:
+               metric = grad_norm.get(metric_name)
+               for group_name, layer_group in metric.items():
+                   for param_name, param_group in layer_group.items():
+                       title = f"{param_name}/{group_name}_{metric_name}"
+                       metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                       if metrics:
+                           writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+               del grad_norm[metric_name]
+           if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+               local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
+               os.makedirs(local_save_path, exist_ok=True)
+               local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
+               vocab_grad_norms = grad_norm.get("vocab_grad_norm")
+               if vocab_grad_norms:
+                   try:
+                       with open(local_save_file, "ab+") as vocab_f:
+                           pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                   except IOError as e:
+                       logger.warning(f"Error saving vocab_grad_norm: {e}")
+                   del grad_norm["vocab_grad_norm"]
 
        line = ""
        for key, value in infos.items():
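`vocab_grad_norm.pt` above is not a torch checkpoint but a stream of pickled `(step_count, {group_name: tensor})` records, one appended per profiled step. A hedged sketch for reading the records back offline; the file path is whatever `RUN/<job_name>/<launch_time>/grad_norm/` resolved to at training time and is shown here only as a placeholder.

```python
import pickle

records = []
with open("vocab_grad_norm.pt", "rb") as f:  # placeholder path
    while True:
        try:
            records.append(pickle.load(f))  # one (step, {group_name: tensor}) per dump
        except EOFError:
            break

for step, norms_per_group in records:
    print(step, {name: tuple(t.shape) for name, t in norms_per_group.items()})
```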
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 9b70fc8..e6bb18f 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
@@ -98,3 +101,26 @@ def set_model_params_layer_name(model):
                 layer_param_name = f"{layer_name}-{param_name}"
                 param.__setattr__("layer_name", layer_name)
                 param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    check whether the parameters of norm modules have the IS_SEQUENCE_PARALLEL attribute.
+    when sequence_parallel is True, the norm parameters should carry the IS_SEQUENCE_PARALLEL
+    attribute to indicate that their gradients need an all-reduce across the sequence-parallel group.
+    """
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)):
+                for param in module.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when gpc.config.parallel.sequence_parallel is True, "
+                        "the parameters of norm modules should have the IS_SEQUENCE_PARALLEL attribute"
+                    )
diff --git a/train.py b/train.py
index 9f0c1ac..6874f9e 100644
--- a/train.py
+++ b/train.py
@@ -197,6 +197,7 @@ def main(args):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
         start_time = time.time()
         timer("one-batch").start()
+        gpc.config.batch_count = batch_count
 
         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
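Taken together, the patch is driven by a handful of config switches. Below is an illustrative snippet of how they might be set in a python-style training config; the keys are the ones read by the code in this patch, while the values and layer names are made up.

```python
grad_profiling = dict(
    grad_norm_profiling=True,          # per-layer / per-parameter gradient norms
    zero_grad_profiling=True,          # zero-gradient counts
    vocab_grad_norm_profiling=True,    # per-token-id gradient norm of the output projection
    interval_steps=100,                # profile every 100th batch only
    layers=["blocks.0", "blocks.31"],  # optional filter for the tensorboard scalars
)

output_tf32 = True  # keep the tagged output head in fp32 and allow TF32 matmuls
```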