mirror of https://github.com/InternLM/InternLM
fix(model): change model_type `LLAMA` to `LLAMA2` (#539)
* support hf llama * support hf llama * support hf llama * support hf llama * importerror * importerror * modeling * modeling * fix bugpull/542/head
parent
5ecb6aa712
commit
bbb5651582
|
@ -46,7 +46,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
MODEL_TYPE = "LLAMA"
|
MODEL_TYPE = "LLAMA2"
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
RMSNorm = try_import_RMSNorm()
|
RMSNorm = try_import_RMSNorm()
|
||||||
|
|
|
@ -364,7 +364,7 @@ def load_llama_pretrained_weights(folder, model):
|
||||||
|
|
||||||
current_states = {}
|
current_states = {}
|
||||||
for idx, i in enumerate(range(model.first_layer, model.last_layer)):
|
for idx, i in enumerate(range(model.first_layer, model.last_layer)):
|
||||||
if gpc.config.model_type == "LLAMA":
|
if gpc.config.model_type == "LLAMA2":
|
||||||
# LLAMA's w2 and w3 are in reverse order
|
# LLAMA's w2 and w3 are in reverse order
|
||||||
w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
|
w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
|
||||||
w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
|
w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
|
||||||
|
@ -419,7 +419,7 @@ def load_hf_llama_pretrained_weights(folder, model):
|
||||||
|
|
||||||
current_states = {}
|
current_states = {}
|
||||||
for idx, i in enumerate(range(model.first_layer, model.last_layer)):
|
for idx, i in enumerate(range(model.first_layer, model.last_layer)):
|
||||||
if gpc.config.model_type == "LLAMA":
|
if gpc.config.model_type == "LLAMA2":
|
||||||
if deep_split:
|
if deep_split:
|
||||||
layer_ids = i // 2
|
layer_ids = i // 2
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -29,7 +29,7 @@ class Registry:
|
||||||
AssertionError: Raises an AssertionError if the module has already been registered before.
|
AssertionError: Raises an AssertionError if the module has already been registered before.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert module_name not in self._registry, f"{module_name} not found in {self.name}"
|
assert module_name not in self._registry, f"{module_name} already registered in {self.name}"
|
||||||
|
|
||||||
def decorator_wrapper(original_func):
|
def decorator_wrapper(original_func):
|
||||||
self._registry[module_name] = original_func
|
self._registry[module_name] = original_func
|
||||||
|
|
Loading…
Reference in New Issue