diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py
index 16af3c2..5b394b9 100644
--- a/internlm/model/modeling_llama.py
+++ b/internlm/model/modeling_llama.py
@@ -46,7 +46,7 @@ try:
 except ImportError:
     pass
 
-MODEL_TYPE = "LLAMA"
+MODEL_TYPE = "LLAMA2"
 
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index bf0b9e9..234944c 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -364,7 +364,7 @@ def load_llama_pretrained_weights(folder, model):
 
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        if gpc.config.model_type == "LLAMA":
+        if gpc.config.model_type == "LLAMA2":
             # LLAMA's w2 and w3 are in reverse order
             w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
             w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
@@ -419,7 +419,7 @@ def load_hf_llama_pretrained_weights(folder, model):
 
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        if gpc.config.model_type == "LLAMA":
+        if gpc.config.model_type == "LLAMA2":
             if deep_split:
                 layer_ids = i // 2
             else:
diff --git a/internlm/utils/registry.py b/internlm/utils/registry.py
index 7cbfcc5..3ac1445 100644
--- a/internlm/utils/registry.py
+++ b/internlm/utils/registry.py
@@ -29,7 +29,7 @@ class Registry:
             AssertionError: Raises an AssertionError if the module has already been registered before.
         """
 
-        assert module_name not in self._registry, f"{module_name} not found in {self.name}"
+        assert module_name not in self._registry, f"{module_name} already registered in {self.name}"
 
         def decorator_wrapper(original_func):
             self._registry[module_name] = original_func
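
For context, a minimal sketch of the behavior the registry.py hunk corrects, assuming a simplified Registry with the same register_module decorator shape as in the diff; the class body and the MODEL_INITIALIZER instance below are illustrative stand-ins, not code taken from the repo. Before the fix, a duplicate registration failed with the misleading message "not found in"; after it, the message describes the actual problem.

# Hypothetical, simplified reconstruction for illustration only;
# only the assert line mirrors the diff above.
class Registry:
    def __init__(self, name: str):
        self.name = name
        self._registry = {}

    def register_module(self, module_name: str):
        # Fixed message: duplicate registration now reads "already registered"
        # rather than the old, misleading "not found".
        assert module_name not in self._registry, f"{module_name} already registered in {self.name}"

        def decorator_wrapper(original_func):
            self._registry[module_name] = original_func
            return original_func

        return decorator_wrapper


MODEL_INITIALIZER = Registry("model_initializer")  # illustrative instance name

@MODEL_INITIALIZER.register_module("LLAMA2")
def build_llama2():
    pass

# A second @MODEL_INITIALIZER.register_module("LLAMA2") now raises:
# AssertionError: LLAMA2 already registered in model_initializer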