mirror of https://github.com/InternLM/InternLM

feat(model): support llama model with checkpoint loading (#532)

* support hf llama
* importerror
* modeling

pull/538/head
parent 81ffb3d824
commit 6c0ff4820f
@@ -28,7 +28,7 @@ ckpt = dict(
     # 'load_ckpt_info' setting guide:
     # 1. the 'path' indicates the ckpt path,
     # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
-    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported.
+    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama".
     load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
     # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
     # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
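For reference, a minimal sketch of pointing the same setting at a Hugging Face LLaMA checkpoint; the folder path below is hypothetical, and "hf_llama"/"llama" are the new load types introduced by this change:

    load_ckpt_info=dict(
        path="/path/to/llama-7b-hf",  # hypothetical folder holding pytorch_model*.bin shards
        content=("model",),           # load model weights only
        ckpt_type="hf_llama",         # use "llama" instead for original Meta-format .pth/.pt shards
    ),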
@@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding
 from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
 from .metrics import AccPerplex
 from .modeling_internlm import build_model_with_cfg
+from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg
 from .modeling_moe import build_model_with_moe_cfg
 from .moe import MoE
 from .multi_head_attention import MHA
@@ -22,4 +23,5 @@ __all__ = [
     "gather_forward_split_backward",
     "build_model_with_cfg",
     "build_model_with_moe_cfg",
+    "build_model_with_llama_cfg",
 ]
File diff suppressed because it is too large
@@ -50,12 +50,16 @@ class CheckpointSaveType(Enum):

 class CheckpointLoadType(Enum):
     INTERNLM = "internlm"
+    HF_LLAMA = "hf_llama"
+    LLAMA = "llama"


 # The load method implemented by internlm by default does not use string representation types,
 # but uses enumeration types defined in advance.
 LOAD_TYPE_DICT = {
     "internlm": CheckpointLoadType.INTERNLM,
+    "hf_llama": CheckpointLoadType.HF_LLAMA,
+    "llama": CheckpointLoadType.LLAMA,
 }
@@ -74,7 +78,7 @@ class CheckpointLoadMethod:
     LOAD_TYPE_FUNC = {}

     @staticmethod
-    def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
+    def convert_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
         if load_type.lower() in LOAD_TYPE_DICT:
             # The ckpt load method implemented by internlm by default.
             return LOAD_TYPE_DICT[load_type.lower()]
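A small usage sketch of the renamed helper; the first two lines follow directly from LOAD_TYPE_DICT above (the lookup lower-cases its argument, so it is case-insensitive), and the fallback for unknown strings is handled by code outside this hunk:

    assert CheckpointLoadMethod.convert_load_type("internlm") == CheckpointLoadType.INTERNLM
    assert CheckpointLoadMethod.convert_load_type("HF_LLAMA") == CheckpointLoadType.HF_LLAMA
    # strings not present in LOAD_TYPE_DICT are treated as user-registered load types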
@@ -90,7 +94,11 @@ class CheckpointLoadMethod:

         CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})

-        if load_type == CheckpointLoadType.INTERNLM:
+        if load_type in (
+            CheckpointLoadType.INTERNLM,
+            CheckpointLoadType.HF_LLAMA,
+            CheckpointLoadType.LLAMA,
+        ):
             CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
         else:
             if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
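Registration of a user-defined checkpoint format goes through the same path. The loader below is a hypothetical sketch (its name, body and the "my_ckpt" type string are not part of this PR, and how a custom type string is keyed into the registry is outside this hunk), but it keeps the (ckpt_mm, load_info, train_state) signature of the built-in loaders so the signature comparison above does not emit a warning:

    def try_load_my_ckpt(ckpt_mm, load_info, train_state: TrainState):  # hypothetical custom loader
        load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
        if load_content.need_load(CheckpointLoadContent.MODEL):
            ...  # read weights from load_ckpt_folder and copy them into ckpt_mm.model

    CheckpointLoadMethod.register_ckpt_load_type("my_ckpt", try_load_my_ckpt)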
@@ -188,13 +196,33 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):
     return (missing_k, unexpected_keys)


-def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
+def process_load_info(load_info):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")

+    return load_content_str, load_ckpt_folder, load_content
+
+
+def try_load_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+    if load_content.need_load(CheckpointLoadContent.MODEL):
+        load_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model)
+        load_content_str += f"{CheckpointLoadContent.MODEL}, "
+
+
+def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+    if load_content.need_load(CheckpointLoadContent.MODEL):
+        load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model)
+        load_content_str += f"{CheckpointLoadContent.MODEL}, "
+
+
+def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+
     if load_content.need_load(CheckpointLoadContent.MODEL):
         load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
         load_content_str += f"{CheckpointLoadContent.MODEL}, "
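For orientation, by the time one of these try_load_* functions runs, CheckpointManager has already normalised the raw config (see the hunks further down): "content" has been wrapped in a CheckpointLoadMask and "ckpt_type" converted by convert_load_type, so load_info looks roughly like this (the path is hypothetical):

    load_info = {
        "path": "/path/to/llama-7b-hf",             # hypothetical checkpoint folder
        "content": CheckpointLoadMask(("model",)),  # wrapped by CheckpointManager
        "ckpt_type": CheckpointLoadType.HF_LLAMA,   # converted from the "hf_llama" string
    }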
@@ -314,6 +342,170 @@ def save_model_checkpoint(folder, model):
     torch.distributed.barrier()


+def load_llama_pretrained_weights(folder, model):
+    model = model.model
+    assert folder is not None, "Please specify the folder of the pretrained model"
+    if gpc.is_rank_for_log():
+        logger.info(f"Loading pretrained model from {folder}")
+
+    fns = get_fns(folder)
+    model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]
+    model_fns.sort()
+
+    old_tp = len(model_fns)
+    cur_tp = gpc.get_world_size(ParallelMode.TENSOR)
+    # If the two tp are inconsistent, you need to consider the merge before splitting
+    if old_tp != cur_tp:
+        raise RuntimeError(
+            f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first"
+        )
+
+    states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu")
+
+    current_states = {}
+    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+        if gpc.config.model_type == "LLAMA":
+            # LLAMA's w2 and w3 are in reverse order
+            w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
+            w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
+            states[f"layers.{i}.feed_forward.w2.weight"] = w3
+            states[f"layers.{i}.feed_forward.w3.weight"] = w2
+        if "rope.freqs" in states:
+            states[f"layers.{i}.attention.rotary_emb.inv_freq"] = states["rope.freqs"]
+        for name in list(states.keys()):
+            if f".{i}." in name:
+                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
+
+    model_state_keys = set(list(model.state_dict().keys()))
+
+    if "tok_embeddings.weight" in model_state_keys:
+        current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"]
+        assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
+    if "output.weight" in model_state_keys:
+        current_states["norm.weight"] = states["norm.weight"]
+        current_states["output.weight"] = states["output.weight"]
+    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
+
+    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+        pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+        logger.info(
+            f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+        )
+
+    del states
+    del current_states
+    torch.cuda.empty_cache()
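The loop above remaps global layer indices to stage-local ones: a pipeline stage only owns layers first_layer..last_layer-1, and their weights are stored locally as layers 0..N-1. An illustrative sketch with hypothetical stage boundaries:

    # illustrative only: a stage owning global layers 4..7 stores them locally as 0..3
    first_layer, last_layer = 4, 8
    for idx, i in enumerate(range(first_layer, last_layer)):
        print(f"layers.{i}.attention.wq.weight -> layers.{idx}.attention.wq.weight")
    # layers.4.* -> layers.0.*, layers.5.* -> layers.1.*, ...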
+def load_hf_llama_pretrained_weights(folder, model):
+    model = model.model
+    assert folder is not None, "Please specify the folder of the pretrained model"
+    if gpc.is_rank_for_log():
+        logger.info(f"Loading pretrained model from {folder}")
+
+    fns = get_fns(folder)
+    model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") and fn.startswith("pytorch_model")]
+    model_fns.sort()
+
+    states = {}
+
+    for model_fn in model_fns:
+        states.update(llm_load(model_fn, map_location="cpu"))
+
+    deep_split = getattr(model, "deep_split", False)
+    if deep_split:
+        print("using deep split when loading pretrained weights!")
+
+    current_states = {}
+    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+        if gpc.config.model_type == "LLAMA":
+            if deep_split:
+                layer_ids = i // 2
+            else:
+                layer_ids = i
+
+            if not deep_split or (i + 2) % 2 == 0:
+                states[f"layers.{i}.attention.wq.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention.wk.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention.wv.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention.wo.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=1,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention_norm.weight"] = states.pop(
+                    f"model.layers.{layer_ids}.input_layernorm.weight"
+                )
+
+            if not deep_split or (i + 2) % 2 == 1:
+                states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=1,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+
+                states[f"layers.{i}.ffn_norm.weight"] = states.pop(
+                    f"model.layers.{layer_ids}.post_attention_layernorm.weight"
+                )
+
+            if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states:
+                states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")
+        for name in list(states.keys()):
+            if name.startswith(f"layers.{i}"):
+                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
+
+    model_state_keys = set(list(model.state_dict().keys()))
+
+    if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys:
+        if gpc.config.model.get("embed_split_hidden", True):
+            current_states["tok_embeddings.weight"] = torch.chunk(
+                states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+        else:
+            current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
+                states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+        assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
+    if "output.weight" in model_state_keys:
+        current_states["norm.weight"] = states["model.norm.weight"]
+        current_states["output.weight"] = torch.chunk(
+            states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
+        )[gpc.get_local_rank(ParallelMode.TENSOR)]
+
+    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
+
+    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+        pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+        logger.info(
+            f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+        )
+    torch.cuda.empty_cache()
+
+
 def load_model_checkpoint(folder, model):
     """
     There should be weights with names similar to the following under the folder.
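The HF loader above shards every Hugging Face weight across tensor-parallel ranks with torch.chunk: dim=0 for the column-parallel weights (q/k/v, gate_proj, up_proj, lm_head) and dim=1 for the row-parallel ones (o_proj, down_proj) and the token embedding. A standalone sketch of that split, with hypothetical shapes and ranks:

    import torch

    w = torch.randn(4096, 4096)     # stands in for a q_proj-style weight
    tp_world_size, tp_rank = 4, 1   # hypothetical tensor-parallel layout
    shard = torch.chunk(w, tp_world_size, dim=0)[tp_rank]
    print(shard.shape)              # torch.Size([1024, 4096]); o_proj/down_proj would use dim=1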
@@ -682,7 +874,11 @@ class CheckpointManager:
         self.model_config_file = model_config_file

         # Register defalut internlm ckpt load type.
-        self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
+        self.defalut_load_type_func = {
+            CheckpointLoadType.INTERNLM: try_load_internlm_ckpt,
+            CheckpointLoadType.HF_LLAMA: try_load_hf_LLAMA_ckpt,
+            CheckpointLoadType.LLAMA: try_load_LLAMA_ckpt,
+        }
         for ckpt_load_type in CheckpointLoadType:
             CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])
@@ -718,7 +914,7 @@ class CheckpointManager:

             # replace load_ckpt
             self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
-            self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
+            self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])

         torch.distributed.barrier()
         # test storage setting is ok.
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )

+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig

 logger = logging.get_logger(__name__)
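The try/except keeps this module importable on transformers versions that predate BaseStreamer; downstream code can then guard on the fallback value. A minimal sketch of that guard (the helper function is hypothetical, not part of this file):

    def emit(tokens, streamer=None):  # hypothetical helper
        if streamer is not None and BaseStreamer is not None:
            streamer.put(tokens)      # stream tokens incrementally
        # otherwise the caller simply consumes the full output at the end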
@@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """

     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))