From 6c0ff4820fa98abf5834d8c34b2b4ba6ebfa18e0 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Mon, 11 Dec 2023 16:25:24 +0800 Subject: [PATCH 1/5] feat(model): support llama model with checkpoint loading (#532) * support hf llama * support hf llama * support hf llama * support hf llama * importerror * importerror * modeling * modeling --- configs/7B_sft.py | 2 +- internlm/model/__init__.py | 2 + internlm/model/modeling_llama.py | 1142 +++++++++++++++++ internlm/utils/model_checkpoint.py | 206 ++- .../internlm_model/modeling_internlm.py | 7 +- 5 files changed, 1352 insertions(+), 7 deletions(-) create mode 100644 internlm/model/modeling_llama.py diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 1cbb5e7..7d945b4 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -28,7 +28,7 @@ ckpt = dict( # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama". load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) diff --git a/internlm/model/__init__.py b/internlm/model/__init__.py index cb7dd8e..a4efc03 100644 --- a/internlm/model/__init__.py +++ b/internlm/model/__init__.py @@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear from .metrics import AccPerplex from .modeling_internlm import build_model_with_cfg +from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg from .modeling_moe import build_model_with_moe_cfg from .moe import MoE from .multi_head_attention import MHA @@ -22,4 +23,5 @@ __all__ = [ "gather_forward_split_backward", "build_model_with_cfg", "build_model_with_moe_cfg", + "build_model_with_llama_cfg", ] diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py new file mode 100644 index 0000000..16af3c2 --- /dev/null +++ b/internlm/model/modeling_llama.py @@ -0,0 +1,1142 @@ +# Copyright (c) InternLM. All rights reserved. 
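+#
+# Usage sketch (illustrative only; the folder path is a placeholder): this module
+# registers itself under MODEL_TYPE = "LLAMA", so a training config selects it with
+# `model_type = "LLAMA"` and can point `load_ckpt_info` at a llama checkpoint:
+#
+#     model_type = "LLAMA"
+#     ckpt = dict(
+#         load_ckpt_info=dict(
+#             path="<path/to/llama_checkpoint>",  # placeholder path
+#             content=("model",),
+#             ckpt_type="hf_llama",  # HF pytorch_model*.bin files; use "llama" for .pth TP shards
+#         ),
+#     )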
+import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.embedding import Embedding1D, RotaryEmbedding +from internlm.model.linear import ( + ColumnParallelLinearTorch, + FeedForward, + RewardModelLinear, + RowParallelLinearTorch, + ScaleColumnParallelLinear, +) +from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm +from internlm.solver.pipeline_utils import partition_uniform +from internlm.utils.checkpoint import activation_checkpoint +from internlm.utils.common import filter_kwargs +from internlm.utils.logger import get_logger +from internlm.utils.registry import MODEL_INITIALIZER + +try: + from flash_attn import flash_attn_varlen_kvpacked_func + from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc + from flash_attn.modules.embedding import ParallelGPT2Embeddings + from flash_attn.modules.mha import ( + CrossAttention, + FlashCrossAttention, + FlashSelfAttention, + SelfAttention, + _update_kv_cache, + ) + from flash_attn.modules.mlp import ParallelFusedMLP + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + pass + +MODEL_TYPE = "LLAMA" + +logger = get_logger(__file__) +RMSNorm = try_import_RMSNorm() + + +class MHA(nn.Module): + """ + Multi-head self-attention and cross-attention. + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. + num_kv_heads (int): The number of kv attention heads. + process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. + bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and + output projection. True by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. + True by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + rot_embed_HF_impl: rotary embedding hf implementation. False by default. 
+ + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + process_group: Optional[torch.distributed.ProcessGroup], + bias: bool = True, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + rope_base: int = 10000, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + use_flash_attn: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + rot_embed_HF_impl: Optional[bool] = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + + self.head_dim = self.embed_dim // num_heads + self.num_kv_heads = num_kv_heads + self.kv_dim = self.head_dim * num_kv_heads + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.dtype = dtype + + self.rot_embed_HF_impl = rot_embed_HF_impl + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + + if self.rotary_emb_dim > 0: + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device + ) + + # notice here should change bias=True + self.wq = ColumnParallelLinearTorch( + embed_dim, + embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.wk = ColumnParallelLinearTorch( + embed_dim, + self.kv_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.wv = ColumnParallelLinearTorch( + embed_dim, + self.kv_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + + self.inner_cross_attn_causal = causal + self.inner_cross_attn_softmax_scale = softmax_scale + self.inner_cross_attn_dropout = dropout + + # output projection always have the bias (for now) + self.wo = RowParallelLinearTorch( + embed_dim, + embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + # need to assign tp attribute so that internlm know it is tensor parallel module + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + for name in ["wo", "wq", "wk", "wv"]: + for param in getattr(self, name).parameters(): + setattr(param, IS_TENSOR_PARALLEL, True) + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + if kwargs.get("indexes", None) is not None: + return self._packed_forward(x=x, inference_params=inference_params, **kwargs) + else: + return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) + + def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). 
+ """ + bsz, _, _ = x.shape + q, k, v = self.wq(x), self.wk(x), self.wv(x) + if seqlen is None: + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + else: + q = rearrange(q, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) + k = rearrange(k, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) + v = rearrange(v, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) + + if not self.rot_embed_HF_impl: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + if inference_params is None: + if self.rotary_emb_dim > 0: + q = self.rotary_emb._single_eval_forward(q) + k = self.rotary_emb._single_eval_forward(k) + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + if self.dtype is torch.float32 and self.use_flash_attn: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = self.inner_cross_attn(q, kv).to(self.dtype) + else: + context = self.inner_cross_attn(q, kv) + + else: + assert self.rotary_emb_dim > 0 + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + empties = inference_params.attention_mask[..., -1].sum(dim=-1) + moved_q = q.clone() + moved_k = k.clone() + if inference_params.sequence_len_offset == 0: + for i in range(len(empties)): + if empties[i] != 0: + moved_q[i][: -empties[i]] = q[i][empties[i] :] + moved_k[i][: -empties[i]] = k[i][empties[i] :] + moved_q = self.rotary_emb._single_eval_forward( + moved_q, seqlen_offset=inference_params.sequence_len_offset + ) + moved_k = self.rotary_emb._single_eval_forward( + moved_k, seqlen_offset=inference_params.sequence_len_offset + ) + for i in range(len(empties)): + if empties[i] != 0: + q[i][empties[i] :] = moved_q[i][: -empties[i]] + k[i][empties[i] :] = moved_k[i][: -empties[i]] + else: + q[i] = moved_q[i] + k[i] = moved_k[i] + else: + q = q.squeeze(1) + k = k.squeeze(1) + q = self.rotary_emb._single_forward( + q, + inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) + - empties, + ).unsqueeze(1) + k = self.rotary_emb._single_forward( + k, + inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) + - empties, + ).unsqueeze(1) + else: + raise NotImplementedError( + "You should make sure you are aware that you are changing the method of generating." + "According to your generation function instead of inference/seq_generator_module.py, " + "You may implement here for normal running." + ) + + kv = torch.stack([k, v], dim=2) + + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + if hasattr(inference_params, "window_size") and inference_params.window_size is not None: + if inference_params.window_size <= inference_params.sequence_len_offset: + assert kv.size(1) == 1, "update kv lenth more than 1" + inference_params.key_value_memory_dict[self.layer_idx][ + :, inference_params.keep_first : inference_params.window_size - 1, ... + ] = inference_params.key_value_memory_dict[self.layer_idx][ + :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... 
+ ].clone() + inference_params.real_sequence_len_offset = inference_params.sequence_len_offset + inference_params.sequence_len_offset = inference_params.window_size - 1 + + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + inference_params.sequence_len_offset = inference_params.real_sequence_len_offset + else: + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + else: + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + # When using FP16, there is a high probability of NAN in the KV. + # Since NAN cannot be removed by multiplying with and 0, it needs + # to be removed manually here. + kv = torch.where(torch.isnan(kv), 0, kv) + + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + assert self.use_flash_attn is True + if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = inference_params.attention_mask[:, None, ...] + attn_mask = torch.logical_or( + torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask + ) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), + attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ) + cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) + max_seqlen_q = attn_mask4flsh.shape[-1] + max_seqlen_k = attn_mask4flsh.shape[-1] + total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) + total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( + -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] + ) + if self.dtype is torch.float32: + if total_q.dtype not in [torch.float16, torch.bfloat16]: + total_q = total_q.to(torch.bfloat16) + if total_kv.dtype not in [torch.float16, torch.bfloat16]: + total_kv = total_kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + output = FlashAttnVarlenKVPackedFunc.apply( + total_q, + total_kv, + cu_seqlens, + cu_seqlens, + max_seqlen_q, + max_seqlen_k, + 0.0, + None, + True, + False, + ).to(self.dtype) + else: + output = FlashAttnVarlenKVPackedFunc.apply( + total_q, + total_kv, + cu_seqlens, + cu_seqlens, + max_seqlen_q, + max_seqlen_k, + 0.0, + None, + True, + False, + ) + + context = torch.zeros_like(q) + context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) + + else: + attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) + if hasattr(inference_params, "window_size") and inference_params.window_size is not None: + if inference_params.window_size <= inference_params.sequence_len_offset: + attn_mask = torch.concat( + [ + attn_mask[..., : inference_params.keep_first], + attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], + ], + dim=-1, + ) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + expansion = q.size(2) // k.size(2) + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + else: + if self.dtype is torch.float32 and self.use_flash_attn: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = 
q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = self.inner_cross_attn(q, kv, causal=True).to(self.dtype) + else: + context = self.inner_cross_attn(q, kv, causal=True) + if seqlen is None: + context = rearrange(context, "b s h d -> b s (h d)") + else: + context = rearrange(context, "b s h d -> (b s) (h d)") + out = self.wo(context) + return out + + def _packed_forward(self, x, inference_params=None, **kwargs): + """ + we delete seqlen=None for lint check, cause this arg is not used. + + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + assert self.use_flash_attn is True + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "t (h d) -> t h d", d=self.head_dim) + k = rearrange(k, "t (h d) -> t h d", d=self.head_dim) + v = rearrange(v, "t (h d) -> t h d", d=self.head_dim) + + # qkv shift + # the rotary embedding in flash attention module in performed by separating the front and back parts, while + # most of others are done by odd-even methods. + if not self.rot_embed_HF_impl: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + + indexes = kwargs.pop("indexes") + q = self.rotary_emb._single_forward(q, indexes=indexes) + k = self.rotary_emb._single_forward(k, indexes=indexes) + + if inference_params is None: + kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1) + if self.dtype is torch.float32: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=kwargs["cu_seqlens"], + cu_seqlens_k=kwargs["cu_seqlens"], + max_seqlen_q=kwargs["max_seqlen"], + max_seqlen_k=kwargs["max_seqlen"], + dropout_p=self.inner_cross_attn_dropout, + softmax_scale=self.inner_cross_attn_softmax_scale, + causal=self.inner_cross_attn_causal, + ).to(self.dtype) + else: + context = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=kwargs["cu_seqlens"], + cu_seqlens_k=kwargs["cu_seqlens"], + max_seqlen_q=kwargs["max_seqlen"], + max_seqlen_k=kwargs["max_seqlen"], + dropout_p=self.inner_cross_attn_dropout, + softmax_scale=self.inner_cross_attn_softmax_scale, + causal=self.inner_cross_attn_causal, + ) + else: + raise RuntimeError("Not support this right now") + + context = rearrange(context, "b h d -> b (h d)") # recover shape + out = self.wo(context) + return out + + +class PackedFlashLlamaLayer1D(nn.Module): + """ + 1D Packed Flash Llama Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of kv attention heads. 12 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. 
+ checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether use post layer norm. False by default. + fused_dropout_add_ln (bool): Whether use fused dropout add ln. True by default. + no_bias (bool): Whether remove bias. False by default. + norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + adapt_hf (bool): Whether adapt hf. False by default. + dropout_selective_checkpoint (bool): Whether use dropout selective checkpoint. True by default. + use_scaled_init (bool): Whether use scaled init. True by default. + use_swiglu (bool): Whether use swiglu. True by default. + use_flash_attn (bool): Whether use flash-attn. True by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + adapt_hf: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
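+        # (If `checkpoint` is enabled, the whole layer is recomputed in `forward` via
+        # `activation_checkpoint`, so the finer-grained dropout/norm checkpointing below
+        # would be redundant; hence the mutual exclusion.)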
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.use_flash_attn = use_flash_attn + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + head_dim = hidden_size // num_attention_heads + self.attention = MHA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + process_group=gpc.get_group(ParallelMode.TENSOR), + dropout=attn_drop_rate, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + use_flash_attn=use_flash_attn, + device=device, + dtype=dtype, + rot_embed_HF_impl=adapt_hf, + bias=not no_bias, + rope_base=rope_base, + ) + + self.dropout1 = nn.Dropout(drop_rate) + if norm_type == "rmsnorm": + self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" + assert isinstance(self.attention_norm, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout) + + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + if use_swiglu: + self.feed_forward = FeedForward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + ) + else: + self.feed_forward = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(ParallelMode.TENSOR), + bias1=False, + bias2=False, + sequence_parallel=sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + + for _, param in self.feed_forward.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + for param in self.attention_norm.parameters(): + if gpc.config.parallel.sequence_parallel is True: + setattr(param, IS_SEQUENCE_PARALLEL, True) + for param in self.ffn_norm.parameters(): + if gpc.config.parallel.sequence_parallel is True: + setattr(param, IS_SEQUENCE_PARALLEL, True) + + self.dropout2 = nn.Dropout(drop_rate) + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 
1)(param.data) + else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_swiglu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward( + self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None + ): + if self.checkpoint and self.training: + return activation_checkpoint( + self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen + ) + else: + return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + + def _forward( + self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(torch.float32)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + + return hidden_states + residual + else: + assert residual is None + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + mixer_out = self.attention(hidden_states, **mixer_kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = 
self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + mlp_out = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.ffn_norm.weight.dtype + ) + return hidden_states + + +class PackedFlashLlama1D(nn.Module): + """ + 1D Packed Flash Llama. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + num_kv_attention_heads (int): The number of kv attention head. 12 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 1.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + True by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + apply_post_layer_norm (bool): Whether use post layer norm. False by default. + no_bias (bool): Whether remove bias. False by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + adapt_hf (bool): Whether adapt hf. False by default. + is_reward (bool): Whether use is_reward. False by default. + dropout_selective_checkpoint (bool): Whether dropout selective checkpoint. True by default. + use_scaled_init (bool): Whether use scaled init. True by default. + use_swiglu (bool): Whether use swiglu. True by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. 
+ """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + checkpoint_fraction: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_split_hidden: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + adapt_hf: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + ): + super().__init__() + + self.use_flash_attn = use_flash_attn + if checkpoint_fraction <= 0: + checkpoint = False + if not checkpoint: + checkpoint_fraction = 0 + checkpoint_layer_num = num_layers * checkpoint_fraction + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + if is_reward: + head_cls = RewardModelLinear + else: + head_cls = ScaleColumnParallelLinear + if first: + if embed_split_hidden: + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + else: + + self.tok_embeddings = ParallelGPT2Embeddings( + embed_dim=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=-1, + process_group=gpc.get_group(ParallelMode.TENSOR), + padding_idx=None, + sequence_parallel=sequence_parallel, + device=device, + dtype=dtype, + ) + for _, param in self.tok_embeddings.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + self.embed_grad_scale = embed_grad_scale + self.layers = nn.ModuleList( + [ + PackedFlashLlamaLayer1D( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + use_flash_attn=use_flash_attn, + adapt_hf=adapt_hf, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_base=rope_base, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + if norm_type == "rmsnorm": + self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.norm = nn.LayerNorm(hidden_size, 
eps=layer_norm_epsilon) + for param in self.norm.parameters(): + if gpc.config.parallel.sequence_parallel is True: + setattr(param, IS_SEQUENCE_PARALLEL, True) + + self.output = head_cls( + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + weight_scale=embed_grad_scale, + ) + + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + + self.parallel_output = parallel_output + + def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "tok_embeddings"): + hidden_states = self.tok_embeddings(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + if isinstance(cu_seqlens, list): + assert len(cu_seqlens) == 1 + cu_seqlens = cu_seqlens[0].to(hidden_states.device) + + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.squeeze(0) + hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state, + # the batch dimension with a size of 1 should be directly squeezed off. + + if indexes is not None: + assert len(indexes) == 1 + # The indexes are used to indicate the actual position IDs of each token in the packed input. + indexes = indexes[0] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None + + for _, block in enumerate(self.layers): + hidden_states = block( + hidden_states, + residual=None, + cu_seqlens=cu_seqlens, + indexes=indexes, + inference_params=inference_params, + max_seqlen=max_seqlen, + ) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.float()) + + extra_hidden_states_list = None + if hasattr(self, "output"): + hidden_states = self.output(hidden_states) + + if not self.parallel_output: + hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + if extra_hidden_states_list is not None: + extra_hidden_states_list = [ + gather_forward_split_backward(extra_hidden_states, ParallelMode.TENSOR, dim=-1) + for extra_hidden_states in extra_hidden_states_list + ] + + if extra_hidden_states_list is not None: + return (hidden_states, extra_hidden_states_list) + + return hidden_states + + +def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): + """ + build generic model 1d + + Args: + num_layers (int): The number of layer. + num_chunks (int): The number of partitions in pipeline parallel. + device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. 
+ + """ + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + if gpc.is_rank_for_log(): + logger.info(f"The layer sharding is {all_parts}.") + + models = [] + kwargs["checkpoint_fraction"] = 1.0 + start_idx, end_idx = 0, 0 + for start, end in parts: + start_idx, end_idx = start, end + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + # If there is no content in the final layer, assign the last layer. + kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 + kwargs["device"] = device + kwargs["start_layer_idx"] = start + chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) + + models.append(chunk) + torch.distributed.barrier() + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + setattr(model, "first_layer", start_idx) + setattr(model, "last_layer", end_idx) + return model + + +@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) +def build_model_with_cfg( + num_chunks=1, + checkpoint=False, + dtype=torch.float, + embed_split_hidden=False, + num_layers=48, + hidden_size=2048, + vocab_size=50304, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=None, + mlp_ratio=4.0, + residual_in_fp32=False, + norm_type="rmsnorm", + adapt_hf=False, + drop_rate=0, + attn_drop_rate=0, + apply_post_layer_norm=False, # pylint: disable=W0613 + no_bias=False, + deepnorm=False, + layer_norm_epsilon=1e-5, + is_reward=False, + dropout_selective_checkpoint=True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, +): + """ + Builde model with config + + Args: + num_chunks (int): The number of partitions in pipeline parallel. 1 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. + dtype (torch.dtype): The type of data. torch.float by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + False by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + num_attention_heads (int): The number of attention head. 32 by default. + num_kv_attention_heads (int): The number of kv attention head. None by default. + mlp_ratio (int): The ratio of MLP layers. 4.0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily + because this parameter requires inconsistent data types to be passed between pipelines, + which requires significant modifications to internlm. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + adapt_hf (bool): Whether adapt hf. False by default. + drop_rate (float): The dropout rate of input hidden state. 0 by default. 
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default. + apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. + no_bias (bool): Whether remove bias. False by default. + deepnorm (bool): Whether us deepnorm. False by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + is_reward (bool): Whether to use reward model. False by default. + dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. + use_scaled_init (bool): Whether to use scaled init. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + """ + if deepnorm: + raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") + + cfg = dict( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + vocab_size=vocab_size, + embed_grad_scale=embed_grad_scale, + parallel_output=parallel_output, + mlp_ratio=mlp_ratio, + apply_post_layer_norm=apply_post_layer_norm, + no_bias=no_bias, + residual_in_fp32=residual_in_fp32, + norm_type=norm_type, + adapt_hf=adapt_hf, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + layer_norm_epsilon=layer_norm_epsilon, + is_reward=is_reward, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + use_flash_attn=use_flash_attn, + embedding_init_std=embedding_init_std, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + out_head_init_std=out_head_init_std, + init_type=init_type, + rope_base=rope_base, + ) + + return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 87a303c..c4b0c3c 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -50,12 +50,16 @@ class CheckpointSaveType(Enum): class CheckpointLoadType(Enum): INTERNLM = "internlm" + HF_LLAMA = "hf_llama" + LLAMA = "llama" # The load method implemented by internlm by default does not use string representation types, # but uses enumeration types defined in advance. 
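# For example (illustrative of the mapping below, not an exhaustive contract):
#   CheckpointLoadMethod.convert_load_type("internlm") -> CheckpointLoadType.INTERNLM
#   CheckpointLoadMethod.convert_load_type("hf_llama") -> CheckpointLoadType.HF_LLAMA
#   CheckpointLoadMethod.convert_load_type("llama")    -> CheckpointLoadType.LLAMA
# Strings not present in LOAD_TYPE_DICT keep their string form (the Union[CheckpointLoadType, str]
# return annotation suggests user-registered load types are simply passed through by name).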
LOAD_TYPE_DICT = { "internlm": CheckpointLoadType.INTERNLM, + "hf_llama": CheckpointLoadType.HF_LLAMA, + "llama": CheckpointLoadType.LLAMA, } @@ -74,7 +78,7 @@ class CheckpointLoadMethod: LOAD_TYPE_FUNC = {} @staticmethod - def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]: + def convert_load_type(load_type: str) -> Union[CheckpointLoadType, str]: if load_type.lower() in LOAD_TYPE_DICT: # The ckpt load method implemented by internlm by default. return LOAD_TYPE_DICT[load_type.lower()] @@ -90,7 +94,11 @@ class CheckpointLoadMethod: CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) - if load_type == CheckpointLoadType.INTERNLM: + if load_type in ( + CheckpointLoadType.INTERNLM, + CheckpointLoadType.HF_LLAMA, + CheckpointLoadType.LLAMA, + ): CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) else: if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log(): @@ -188,13 +196,33 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs): return (missing_k, unexpected_keys) -def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): +def process_load_info(load_info): load_content_str = "" load_ckpt_folder = load_info["path"] load_content: CheckpointLoadMask = load_info["content"] if gpc.is_rank_for_log(): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") + return load_content_str, load_ckpt_folder, load_content + + +def try_load_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState): # pylint: disable=W0613 + load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) + if load_content.need_load(CheckpointLoadContent.MODEL): + load_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model) + load_content_str += f"{CheckpointLoadContent.MODEL}, " + + +def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState): # pylint: disable=W0613 + load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) + if load_content.need_load(CheckpointLoadContent.MODEL): + load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model) + load_content_str += f"{CheckpointLoadContent.MODEL}, " + + +def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): + load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) + if load_content.need_load(CheckpointLoadContent.MODEL): load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) load_content_str += f"{CheckpointLoadContent.MODEL}, " @@ -314,6 +342,170 @@ def save_model_checkpoint(folder, model): torch.distributed.barrier() +def load_llama_pretrained_weights(folder, model): + model = model.model + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")] + model_fns.sort() + + old_tp = len(model_fns) + cur_tp = gpc.get_world_size(ParallelMode.TENSOR) + # If the two tp are inconsistent, you need to consider the merge before splitting + if old_tp != cur_tp: + raise RuntimeError( + f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first" + ) + + states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu") + + current_states = {} + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + if gpc.config.model_type == 
"LLAMA": + # LLAMA's w2 and w3 are in reverse order + w2 = states.pop(f"layers.{i}.feed_forward.w2.weight") + w3 = states.pop(f"layers.{i}.feed_forward.w3.weight") + states[f"layers.{i}.feed_forward.w2.weight"] = w3 + states[f"layers.{i}.feed_forward.w3.weight"] = w2 + if "rope.freqs" in states: + states[f"layers.{i}.attention.rotary_emb.inv_freq"] = states["rope.freqs"] + for name in list(states.keys()): + if f".{i}." in name: + current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) + + model_state_keys = set(list(model.state_dict().keys())) + + if "tok_embeddings.weight" in model_state_keys: + current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"] + assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}" + if "output.weight" in model_state_keys: + current_states["norm.weight"] = states["norm.weight"] + current_states["output.weight"] = states["output.weight"] + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + del states + del current_states + torch.cuda.empty_cache() + + +def load_hf_llama_pretrained_weights(folder, model): + model = model.model + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") and fn.startswith("pytorch_model")] + model_fns.sort() + + states = {} + + for model_fn in model_fns: + states.update(llm_load(model_fn, map_location="cpu")) + + deep_split = getattr(model, "deep_split", False) + if deep_split: + print("using deep split when loading pretrained weights!") + + current_states = {} + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + if gpc.config.model_type == "LLAMA": + if deep_split: + layer_ids = i // 2 + else: + layer_ids = i + + if not deep_split or (i + 2) % 2 == 0: + states[f"layers.{i}.attention.wq.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.attention.wk.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.attention.wv.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.attention.wo.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=1, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.attention_norm.weight"] = states.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + + if not deep_split or (i + 2) % 2 == 1: + states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + 
)[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=0, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + gpc.get_world_size(ParallelMode.TENSOR), + dim=1, + )[gpc.get_local_rank(ParallelMode.TENSOR)] + + states[f"layers.{i}.ffn_norm.weight"] = states.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states: + states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") + for name in list(states.keys()): + if name.startswith(f"layers.{i}"): + current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) + + model_state_keys = set(list(model.state_dict().keys())) + + if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys: + if gpc.config.model.get("embed_split_hidden", True): + current_states["tok_embeddings.weight"] = torch.chunk( + states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + else: + current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk( + states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}" + if "output.weight" in model_state_keys: + current_states["norm.weight"] = states["model.norm.weight"] + current_states["output.weight"] = torch.chunk( + states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0 + )[gpc.get_local_rank(ParallelMode.TENSOR)] + + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + torch.cuda.empty_cache() + + def load_model_checkpoint(folder, model): """ There should be weights with names similar to the following under the folder. @@ -682,7 +874,11 @@ class CheckpointManager: self.model_config_file = model_config_file # Register defalut internlm ckpt load type. - self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt} + self.defalut_load_type_func = { + CheckpointLoadType.INTERNLM: try_load_internlm_ckpt, + CheckpointLoadType.HF_LLAMA: try_load_hf_LLAMA_ckpt, + CheckpointLoadType.LLAMA: try_load_LLAMA_ckpt, + } for ckpt_load_type in CheckpointLoadType: CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type]) @@ -718,7 +914,7 @@ class CheckpointManager: # replace load_ckpt self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) - self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"]) + self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"]) torch.distributed.barrier() # test storage setting is ok. 
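The HF-llama loader above maps HuggingFace parameter names onto InternLM's and shards them
across the tensor-parallel group with a fixed convention: column-parallel projections
(q/k/v_proj, gate_proj/up_proj) are split along dim 0, row-parallel projections (o_proj,
down_proj) along dim 1. Below is a minimal standalone sketch of that convention; `tp_shard`,
the rank constants, and the llama-7B-like shapes are illustrative stand-ins, not part of the patch.

import torch

TP_WORLD_SIZE, TP_RANK = 4, 1  # stand-ins for gpc.get_world_size / gpc.get_local_rank over ParallelMode.TENSOR

def tp_shard(weight: torch.Tensor, dim: int) -> torch.Tensor:
    # Take this rank's slice of a full HF weight, mirroring the torch.chunk calls above.
    return torch.chunk(weight, TP_WORLD_SIZE, dim=dim)[TP_RANK]

hidden, intermediate = 4096, 11008  # llama-7B-like sizes, for illustration only
hf_layer = {
    "self_attn.q_proj.weight": torch.empty(hidden, hidden),
    "self_attn.o_proj.weight": torch.empty(hidden, hidden),
    "mlp.gate_proj.weight": torch.empty(intermediate, hidden),
    "mlp.down_proj.weight": torch.empty(hidden, intermediate),
}
sharded = {
    "attention.wq.weight": tp_shard(hf_layer["self_attn.q_proj.weight"], dim=0),    # column parallel
    "attention.wo.weight": tp_shard(hf_layer["self_attn.o_proj.weight"], dim=1),    # row parallel
    "feed_forward.w1.weight": tp_shard(hf_layer["mlp.gate_proj.weight"], dim=0),    # column parallel
    "feed_forward.w3.weight": tp_shard(hf_layer["mlp.down_proj.weight"], dim=1),    # row parallel
}
assert sharded["attention.wq.weight"].shape == (hidden // TP_WORLD_SIZE, hidden)
assert sharded["attention.wo.weight"].shape == (hidden, hidden // TP_WORLD_SIZE)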
diff --git a/tools/transformers/internlm_model/modeling_internlm.py b/tools/transformers/internlm_model/modeling_internlm.py index e2d52ed..9ea7f17 100644 --- a/tools/transformers/internlm_model/modeling_internlm.py +++ b/tools/transformers/internlm_model/modeling_internlm.py @@ -28,7 +28,6 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN -from transformers.generation.streamers import BaseStreamer from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -42,6 +41,11 @@ from transformers.utils import ( replace_return_docstrings, ) +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + from .configuration_internlm import InternLMConfig logger = logging.get_logger(__name__) @@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module): base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000. device (Any, optional): Running device. Defaults to None. """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) From cc5b15349da0b01f5d1307c3284ce6a0d9ca17e7 Mon Sep 17 00:00:00 2001 From: Pryest <54388244+Pryest@users.noreply.github.com> Date: Mon, 11 Dec 2023 19:36:31 +0800 Subject: [PATCH 2/5] fix(metric): add metric dtype control (#533) * fix(metric): add metric dtype control * fix demo config to avoid implicity * fix default behavior --- configs/7B_MoE4_sft.py | 30 +++++++++++++++++------------- configs/7B_sft.py | 28 ++++++++++++++++------------ internlm/model/metrics.py | 11 ++++++++--- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 92a93d0..cc94cdc 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -145,18 +145,18 @@ model = dict( moe_use_residual=False, moe_gate_k=2, ) -""" -zero1 parallel: - 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. -""" + +# zero1 parallel: +# 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, +# so parameters will be divided within the range of dp. +# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. +# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. +# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +# pipeline parallel (dict): +# 1. size: int, the size of pipeline parallel. +# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. 
+# tensor parallel: tensor parallel size, usually the number of GPUs per node. + parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=1, @@ -176,4 +176,8 @@ monitor = dict( ), ) -model_type = "INTERNLM_MoE" \ No newline at end of file +model_type = "INTERNLM_MoE" + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 7d945b4..c0a9bc8 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -146,18 +146,18 @@ model = dict( use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. ) -""" -zero1 parallel: - 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. -""" + +# zero1 parallel: +# 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, +# so parameters will be divided within the range of dp. +# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. +# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. +# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +# pipeline parallel (dict): +# 1. size: int, the size of pipeline parallel. +# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +# tensor parallel: tensor parallel size, usually the number of GPUs per node. + parallel = dict( zero1=dict(size=8, fsdp=False), tensor=1, @@ -177,3 +177,7 @@ monitor = dict( alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", ), ) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 1f54d06..704d2d6 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -26,7 +26,11 @@ class AccPerplex: self.device = device self.right = torch.Tensor([0]).to(device=device) self.total = torch.Tensor([0]).to(device=device) - self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float) + self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None + if self.metric_dtype is not None: + self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype) + else: + self.total_log_probs = torch.Tensor([0]).to(device=device) self.tp_pg = tp_pg self.dp_pg = dp_pg self.tp_local_rank = torch.distributed.get_rank(self.tp_pg) @@ -128,8 +132,9 @@ class AccPerplex: # All reduce is needed to get the chunks from other GPUs. 
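# A standalone sketch of the metric-dtype switch used by AccPerplex in this patch: only when
# the top-level config sets metric_dtype = "fp32" are the accumulators and logits promoted to
# float32; otherwise they stay in the working dtype. The plain dict below is an illustrative
# stand-in for gpc.config.
import torch

def resolve_metric_dtype(config: dict):
    return torch.float if config.get("metric_dtype", None) == "fp32" else None

metric_dtype = resolve_metric_dtype({"metric_dtype": "fp32"})   # torch.float32
predicted_logits = torch.randn(4, dtype=torch.bfloat16)
if metric_dtype is not None:
    predicted_logits = predicted_logits.to(dtype=metric_dtype)  # cast only when requested
total_log_probs = torch.zeros(1, dtype=metric_dtype or predicted_logits.dtype)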
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) - predicted_logits = predicted_logits.to(dtype=torch.float) - shift_logits = shift_logits.to(dtype=torch.float) + if self.metric_dtype is not None: + predicted_logits = predicted_logits.to(dtype=self.metric_dtype) + shift_logits = shift_logits.to(dtype=self.metric_dtype) pred_exp_logits = torch.exp(predicted_logits) # Sum of exponential of logits along vocab dimension across all GPUs. From d904730be7abad1d9bd1028c6e0f67f9a8ac0d4c Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:27:24 +0800 Subject: [PATCH 3/5] feat(ckpt): support auto resume in Volc and Ali (#529) * multipart upload * upload * storage * storage * storage * storage * change ak sk name * change ak sk name * change ak sk name * change ak sk name * storage * storage * auto resume * auto resume * auto resume * bug --- internlm/utils/model_checkpoint.py | 18 ++++++++++++------ internlm/utils/storage_manager.py | 3 +-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index c4b0c3c..bf0b9e9 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -1016,7 +1016,8 @@ now step_count is {train_state.step_count}", torch.distributed.barrier() def query_latest_snapshot_step_boto3(self): - """query_latest_snapshot_step_boto3 + """Query the latest snapshot step from the storage backend. + Currently, we only support the following storage backends: boto3, oss2 and volc. Returns: Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. """ @@ -1074,6 +1075,7 @@ now step_count is {train_state.step_count}", return load_path, max(snap_step, max_normal_step) def query_latest_snapshot_step_local(self): + """Query the latest snapshot step from the local file system.""" max_step, max_step_path = 0, None save_ckpt_folder = self.save_ckpt_folder.split(":")[1] for root, _, files in os.walk(save_ckpt_folder, followlinks=True): @@ -1090,18 +1092,22 @@ now step_count is {train_state.step_count}", return max_step_path, max_step def query_lastest_ckpt(self): + """Query the latest ckpt via the storage backend.""" latest_ckpt, step = None, -1 # Training was automatically restarted by the process, forcing the latest snapshot to be read. 
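# A small sketch of the path normalization applied at the end of query_lastest_ckpt in this
# patch: remote backends (boto3/oss2/volc) share one snapshot query, "local" has its own, and
# whichever path comes back is prefixed with its backend name if it is not already. The sample
# paths below are illustrative only.
def normalize_ckpt_path(backend: str, latest_ckpt: str) -> str:
    if latest_ckpt and not latest_ckpt.startswith(backend + ":"):
        latest_ckpt = ":".join([backend, latest_ckpt])
    return latest_ckpt

assert normalize_ckpt_path("volc", "bucket/ckpt/snapshot/0") == "volc:bucket/ckpt/snapshot/0"
assert normalize_ckpt_path("local", "local:/data/ckpt/20") == "local:/data/ckpt/20"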
if self.save_ckpt_folder: backend, _ = try_get_storage_backend(self.save_ckpt_folder) - if backend == "boto3": + if backend in ["boto3", "oss2", "volc"]: latest_ckpt, step = self.query_latest_snapshot_step_boto3() - if latest_ckpt and not latest_ckpt.startswith("boto3:"): - latest_ckpt = ":".join(["boto3", latest_ckpt]) elif backend == "local": latest_ckpt, step = self.query_latest_snapshot_step_local() - if latest_ckpt and not latest_ckpt.startswith("local:"): - latest_ckpt = ":".join(["local", latest_ckpt]) + else: + raise NotImplementedError( + f"Unsupported backend: {backend}, " "Currently only support `boto3`, `oss2`, `volc` and `local`" + ) + + if latest_ckpt and not latest_ckpt.startswith(backend + ":"): + latest_ckpt = ":".join([backend, latest_ckpt]) if gpc.is_rank_for_log(): logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...") diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 151af04..53a4e37 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -739,10 +739,9 @@ class AliClient(StorageClient): if AliClient.is_fp_exists(handler, fp): folder_name_list = [] for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp): - folder_name_list.append(obj.key.split("/")[-1]) + folder_name_list.append(obj.key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) return list(set(folder_name_list)) - else: if is_rank_for_log(): logger.warning(f"'{fp}' not found!") From 432bd5ee9ffb8ff2dcb744cb50ef8530d6509f80 Mon Sep 17 00:00:00 2001 From: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:22:39 +0800 Subject: [PATCH 4/5] fix the bug so that the sequence parallel norm is all-reduced when overlap is False (#534) --- internlm/solver/optimizer/hybrid_zero_optim.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 01b40ab..eb7aae3 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer): # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled. 
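# A simplified sketch of the hook-registration rule fixed in this hunk, using Tensor.register_hook
# in place of the AccumulateGrad-object hooks the optimizer actually attaches: the sequence-parallel
# norm hook depends only on the sequence_parallel setting plus the IS_SEQUENCE_PARALLEL flag, while
# the ordinary bucket-reduce hook is attached only when gradient communication overlap is enabled.
# The string constant and the lambda hooks are illustrative stand-ins.
import torch

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"

def attach_reduction_hooks(param, sequence_parallel, overlap_sync_grad,
                           reduce_grad_hook, reduce_grad_hook_sp):
    if sequence_parallel and getattr(param, IS_SEQUENCE_PARALLEL, False):
        param.register_hook(reduce_grad_hook_sp)  # norm grads: always all-reduced across tp
    if overlap_sync_grad:
        param.register_hook(reduce_grad_hook)     # bucket reduce: only with overlap enabled

norm_weight = torch.nn.Parameter(torch.randn(4))
setattr(norm_weight, IS_SEQUENCE_PARALLEL, True)
attach_reduction_hooks(norm_weight, sequence_parallel=True, overlap_sync_grad=False,
                       reduce_grad_hook=lambda g: g, reduce_grad_hook_sp=lambda g: g)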
self.skip_grad_reduce = False - # reduction hook is only used if overlapping communication - # if it is stage 1 without overlapping, no hook will be attached - if self._overlap_sync_grad: - self._attach_reduction_hook() + self._attach_reduction_hook() @property def zero_local_rank(self): @@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer): # if sequence_parallel is True, # the grad of norm should be all-reduce across the tp process group - if gpc.config.parallel.sequence_parallel is True: - if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True: - accum_grad_obj_sp = get_grad_accumulate_object(param) - accum_grad_obj_sp.register_hook(reduce_grad_hook_sp) + if ( + gpc.config.parallel.sequence_parallel is True + and hasattr(param, IS_SEQUENCE_PARALLEL) + and getattr(param, IS_SEQUENCE_PARALLEL) is True + ): + accum_grad_obj.register_hook(reduce_grad_hook_sp) - accum_grad_obj.register_hook(reduce_grad_hook) + if self._overlap_sync_grad: + accum_grad_obj.register_hook(reduce_grad_hook) _define_and_attach(param, reduce_rank) From 5ecb6aa7124a2e06ce718a8831ceeb3e62f071c9 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Wed, 13 Dec 2023 14:48:32 +0800 Subject: [PATCH 5/5] fix(pp): fix no-packed dataset load micro batch error (#538) * fix(pp): fix no-packed dataset load micro batch error * fix based on comment --- internlm/core/engine.py | 5 ++ internlm/core/scheduler/base_scheduler.py | 16 +++-- .../core/scheduler/no_pipeline_scheduler.py | 4 +- internlm/core/scheduler/pipeline_scheduler.py | 28 +++++++-- internlm/utils/common.py | 11 ++++ tests/test_data/test_batch_sampler.py | 60 ++++++++++++++++++- 6 files changed, 109 insertions(+), 15 deletions(-) diff --git a/internlm/core/engine.py b/internlm/core/engine.py index a372b9e..eb33e35 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -185,6 +185,11 @@ class Engine: if to_gpu: batch_data = move_to_device(batch_data) + + # For packed-dataset, batch_data is (micro_num, micro_bsz*seq_len), + # therefore 'batch_size' is equal to 'micro_num' + # For nopacked-dataset, batch_data is (micro_num*micro_bsz, seq_len), + # therefore 'batch_size' is equal to 'micro_num*micro_bsz' batch_size = get_batch_size(batch_data) return batch_data, batch_size diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 14c3457..6e19425 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -4,7 +4,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Dict, Iterable, Optional import torch @@ -36,10 +36,18 @@ class BaseScheduler(ABC): """ pass - def _load_micro_batch(self, data, label, offset): + def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_stride: int): + """ + For pp, it will cut one fully batch into micro batch in pipeline concept. + For nopp, it will cut one fully batch into small batch based on gradient accumulate size. + + A special case is that pp uses a 'non-packed-dateset' (such as evaluation dataset), + so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'. + In all other cases 'bsz_stride' should be equal to 1. 
+ """ assert isinstance(data, dict) and isinstance(label, torch.Tensor) - micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()} - micro_batch_label = label[offset : offset + 1] + micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()} + micro_batch_label = label[offset : offset + bsz_stride] return micro_batch_data, micro_batch_label diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 24d94ef..79a6f62 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -72,7 +72,7 @@ class NonPipelineScheduler(BaseScheduler): label (Any): The label to be loaded. """ - _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset) + _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset, bsz_stride=1) self._grad_accum_offset += 1 if self.data_process_func: @@ -167,7 +167,7 @@ class NonPipelineScheduler(BaseScheduler): forward_only or return_loss ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." - batch_data, actual_batch_size = engine.load_batch(data_iter) + batch_data, actual_batch_size = engine.load_batch(data_iter) # actual_batch_size is micro_num self._grad_accum_size = actual_batch_size # Rampup or variable bsz size. diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 550584e..5b864ff 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -14,7 +14,11 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel -from internlm.utils.common import get_current_device, move_to_device +from internlm.utils.common import ( + check_data_is_packed, + get_current_device, + move_to_device, +) from internlm.utils.logger import get_logger from internlm.utils.timeout import llm_timeout @@ -186,17 +190,28 @@ class PipelineScheduler(BaseScheduler): raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") def load_batch(self, engine, data_iter): - # Pipeline schedule just puts data in memory + # Pipeline schedule just puts data in memory, batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False) - self.num_microbatches = actual_batch_size # Rampup or variable bsz size. + # Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed, + # because internlm's current train dataset is packed, even using dummy data. + # The unpack operation is performed in load_micro_batch(). + if check_data_is_packed(batch_data): + micro_num = actual_batch_size + else: + micro_num = actual_batch_size // gpc.config.data["micro_bsz"] + self.microbatch_offset = 0 self.batch_size = actual_batch_size self.batch_data, self.batch_label = batch_data + self.bsz_stride = self.batch_size // micro_num + # 'num_microbatches' is no longer an initialization parameter, + # but is determined on the fly by the Scheduler. + self.num_microbatches = micro_num # Rampup or variable bsz size. 
def load_micro_batch(self): micro_batch_data, micro_batch_label = self._load_micro_batch( - data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset + data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride ) if self.data_process_func: micro_batch_data["input_ids"] = self.data_process_func( @@ -208,7 +223,7 @@ class PipelineScheduler(BaseScheduler): micro_batch_data.pop("indexes") micro_batch_data["label"] = micro_batch_label - self.microbatch_offset += 1 + self.microbatch_offset += self.bsz_stride return move_to_device(micro_batch_data) @@ -787,9 +802,10 @@ class InterleavedPipelineScheduler(PipelineScheduler): data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset[model_chunk_id], + bsz_stride=self.bsz_stride, ) micro_batch_data["label"] = micro_batch_label - self.microbatch_offset[model_chunk_id] += 1 + self.microbatch_offset[model_chunk_id] += self.bsz_stride return move_to_device(micro_batch_data) def _forward_step(self, engine, chunk_id): diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 6c9cc68..a20b61d 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -110,6 +110,17 @@ def get_batch_size(data): return data[list(data.keys())[0]].size(0) +def check_data_is_packed(data): + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (list, tuple)): + if isinstance(data[0], dict): + return "indexes" in data[0] + return False + elif isinstance(data, dict): + return "indexes" in data[0] + + def filter_kwargs(func, kwargs): sig = inspect.signature(func) return {k: v for k, v in kwargs.items() if k in sig.parameters} diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index 2ad10c0..eb835b2 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -10,7 +10,15 @@ from internlm.core.context import global_context as gpc # from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState -from internlm.train import get_train_data_loader, load_new_batch +from internlm.train import ( + get_train_data_loader, + get_validation_data_loader, + load_new_batch, +) +from internlm.utils.evaluation import ( + switch_evaluation_no_pipeline_scheduler, + switch_evaluation_pipeline_scheduler, +) # from internlm.core.context.parallel_context import global_context as gpc from tests.test_core.utils import build_environment, init_model_and_optim @@ -20,7 +28,7 @@ use_flash_attens = [True, False] answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]] test_case_group = [ # format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len - # (1, "1 1 1", True, answers[0], 1, 8), + (1, "1 1 1", True, answers[0], 1, 8), (4, "1 1 4", True, answers[1], 1, 8), (4, None, True, answers[2], 1, 8), (8, "2 2 2", True, answers[3], 1, 8), @@ -28,6 +36,11 @@ test_case_group = [ ] +class DummyTrainer: + def __init__(self, scheduler) -> None: + self.schedule = scheduler + + def do_warmup(args): rank, worldsize, init_config, should_sccuess, answer = args build_environment(rank, worldsize, init_config) @@ -44,9 +57,11 @@ def do_warmup(args): ) scheduler.pre_processing(engine) engine.train() + trainer = DummyTrainer(scheduler) try: train_dl, _ = get_train_data_loader(num_worker=0) + val_dls = get_validation_data_loader(num_worker=0) except Exception as e: assert should_sccuess is False, 
f"{e}" else: @@ -105,6 +120,38 @@ def do_warmup(args): tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz ), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}" + # test no-packed datasets. + for _, val_dl in val_dls.items(): + for _, batch in enumerate(val_dl): + if gpc.is_using_pp(): + total_val_bsz = len(batch[1]) + batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16) + assert total_val_bsz % micro_bsz == 0 + num_microbatches = total_val_bsz // micro_bsz + tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]]) # toy model hidden size is 8. + with switch_evaluation_pipeline_scheduler( + trainer=trainer, + num_microbatches=num_microbatches, + tensor_shape=tensor_shape, + metric_hook_list=[], + ): + scheduler.forward_backward_step( + engine, batch, forward_only=True, return_loss=False, return_output_label=False + ) + else: + total_val_bsz = len(batch[1]) + batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16) + assert total_val_bsz % micro_bsz == 0 + grad_accum_size = total_val_bsz // micro_bsz + with switch_evaluation_no_pipeline_scheduler( + trainer=trainer, + grad_accum_size=grad_accum_size, + metric_hook_list=[], + ): + scheduler.forward_backward_step( + engine, batch, forward_only=True, return_loss=False, return_output_label=False + ) + @pytest.mark.parametrize("use_flash_atten_case", use_flash_attens) @pytest.mark.parametrize("group_case", test_case_group) @@ -121,7 +168,14 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case): sequence_parallel=False, tensor=1, ), - data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8), + data=dict( + train_folder=None, + valid_folder=None, + valid_micro_num=4, + pack_sample_into_one=False, + min_length=0, + total_steps=8, + ), model=dict( dtype=torch.bfloat16, ),