mirror of https://github.com/InternLM/InternLM

fix(model): change model_type `LLAMA` to `LLAMA2` (#539)

* support hf llama
* importerror
* modeling
* fix bug

pull/542/head
parent 5ecb6aa712
commit bbb5651582
@@ -46,7 +46,7 @@ try:
 except ImportError:
     pass
 
-MODEL_TYPE = "LLAMA"
+MODEL_TYPE = "LLAMA2"
 
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()
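The renamed constant is what the model registry and the weight loaders below key on, so a training config that still sets model_type = "LLAMA" would no longer select this modeling code. A minimal sketch of the corresponding config fragment, assuming InternLM's Python-style config files; every field except model_type is an illustrative placeholder, not taken from this commit:

# Hypothetical config fragment (illustration only).
model_type = "LLAMA2"  # was "LLAMA" before this change

model = dict(
    num_layers=32,           # placeholder value
    hidden_size=4096,        # placeholder value
    num_attention_heads=32,  # placeholder value
)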
@@ -364,7 +364,7 @@ def load_llama_pretrained_weights(folder, model):
 
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        if gpc.config.model_type == "LLAMA":
+        if gpc.config.model_type == "LLAMA2":
             # LLAMA's w2 and w3 are in reverse order
             w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
             w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
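The comment in this hunk notes that the checkpoint's feed_forward.w2 and w3 are stored in the opposite order from this model's MLP, so the loader pops both and writes them back swapped. A standalone sketch of that pattern; the target key names are hypothetical stand-ins, since the real ones come from the unchanged remainder of load_llama_pretrained_weights:

def remap_ffn_weights(states: dict, current_states: dict, i: int, idx: int) -> None:
    """Sketch only: copy the checkpoint's swapped w2/w3 into the target naming."""
    w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
    w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
    # The checkpoint's w2/w3 are in reverse order relative to this model's MLP,
    # so they are assigned crosswise (hypothetical target keys).
    current_states[f"layers.{idx}.feed_forward.w2.weight"] = w3
    current_states[f"layers.{idx}.feed_forward.w3.weight"] = w2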
@@ -419,7 +419,7 @@ def load_hf_llama_pretrained_weights(folder, model):
 
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        if gpc.config.model_type == "LLAMA":
+        if gpc.config.model_type == "LLAMA2":
             if deep_split:
                 layer_ids = i // 2
             else:
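The deep_split branch visible here appears to map pipeline chunk indices back to checkpoint layers: each transformer layer is split into two chunks, so chunk i reads from checkpoint layer i // 2. A purely explanatory illustration of that arithmetic (not part of the diff):

# With deep_split, two consecutive chunks share one checkpoint layer.
for chunk_index in range(6):
    layer_ids = chunk_index // 2
    print(f"chunk {chunk_index} -> checkpoint layer {layer_ids}")
# chunk 0 -> checkpoint layer 0
# chunk 1 -> checkpoint layer 0
# chunk 2 -> checkpoint layer 1
# chunk 3 -> checkpoint layer 1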
@@ -29,7 +29,7 @@ class Registry:
             AssertionError: Raises an AssertionError if the module has already been registered before.
         """
 
-        assert module_name not in self._registry, f"{module_name} not found in {self.name}"
+        assert module_name not in self._registry, f"{module_name} already registered in {self.name}"
 
         def decorator_wrapper(original_func):
             self._registry[module_name] = original_func
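The last hunk only corrects the assertion message: the check fires when a name is already registered, so "not found in" described the opposite condition. A simplified, self-contained sketch of the register-by-decorator pattern this Registry implements, using the corrected message; the registry instance and the registered name below are hypothetical usage, not code from the repo:

class Registry:
    """Simplified stand-in for the project's Registry class."""

    def __init__(self, name: str):
        self.name = name
        self._registry = {}

    def register_module(self, module_name: str):
        # Fail loudly on duplicate registration, with the corrected message.
        assert module_name not in self._registry, f"{module_name} already registered in {self.name}"

        def decorator_wrapper(original_func):
            self._registry[module_name] = original_func
            return original_func

        return decorator_wrapper


MODEL_INITIALIZER = Registry("model_initializer")  # hypothetical instance name

@MODEL_INITIALIZER.register_module("LLAMA2")  # mirrors the MODEL_TYPE rename above
def build_llama2_model(*args, **kwargs):
    ...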