From bbb5651582fbf1e07fa394c55804429e1b719da7 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:24:45 +0800 Subject: [PATCH] fix(model): change model_type `LLAMA` to `LLAMA2` (#539) * support hf llama * support hf llama * support hf llama * support hf llama * importerror * importerror * modeling * modeling * fix bug --- internlm/model/modeling_llama.py | 2 +- internlm/utils/model_checkpoint.py | 4 ++-- internlm/utils/registry.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py index 16af3c2..5b394b9 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/modeling_llama.py @@ -46,7 +46,7 @@ try: except ImportError: pass -MODEL_TYPE = "LLAMA" +MODEL_TYPE = "LLAMA2" logger = get_logger(__file__) RMSNorm = try_import_RMSNorm() diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index bf0b9e9..234944c 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -364,7 +364,7 @@ def load_llama_pretrained_weights(folder, model): current_states = {} for idx, i in enumerate(range(model.first_layer, model.last_layer)): - if gpc.config.model_type == "LLAMA": + if gpc.config.model_type == "LLAMA2": # LLAMA's w2 and w3 are in reverse order w2 = states.pop(f"layers.{i}.feed_forward.w2.weight") w3 = states.pop(f"layers.{i}.feed_forward.w3.weight") @@ -419,7 +419,7 @@ def load_hf_llama_pretrained_weights(folder, model): current_states = {} for idx, i in enumerate(range(model.first_layer, model.last_layer)): - if gpc.config.model_type == "LLAMA": + if gpc.config.model_type == "LLAMA2": if deep_split: layer_ids = i // 2 else: diff --git a/internlm/utils/registry.py b/internlm/utils/registry.py index 7cbfcc5..3ac1445 100644 --- a/internlm/utils/registry.py +++ b/internlm/utils/registry.py @@ -29,7 +29,7 @@ class Registry: AssertionError: Raises an AssertionError 
if the module has already been registered before. """ - assert module_name not in self._registry, f"{module_name} not found in {self.name}" + assert module_name not in self._registry, f"{module_name} already registered in {self.name}" def decorator_wrapper(original_func): self._registry[module_name] = original_func