fix(model): change model_type `LLAMA` to `LLAMA2` (#539)

* support hf llama

* importerror

* modeling

* fix bug
jiaxingli 2023-12-13 17:24:45 +08:00 committed by GitHub
parent 5ecb6aa712
commit bbb5651582
3 changed files with 4 additions and 4 deletions


@@ -46,7 +46,7 @@ try:
 except ImportError:
     pass
-MODEL_TYPE = "LLAMA"
+MODEL_TYPE = "LLAMA2"
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()

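For context, a minimal sketch of how a module-level MODEL_TYPE string is typically consumed: it serves as the registry key for the model builder, and training configs select the model by the same string via model_type. The registry and builder names below are illustrative stand-ins, not identifiers taken from this commit (the real project wires this up through its own Registry class, shown in the last hunk further down).

```python
# Illustrative stand-ins only: a plain dict plays the role of the model registry.
MODEL_INITIALIZER = {}  # hypothetical registry

MODEL_TYPE = "LLAMA2"  # the renamed identifier introduced by this commit


def register_module(module_name):
    """Register a builder under `module_name` so configs can refer to it by string."""
    def decorator_wrapper(original_func):
        MODEL_INITIALIZER[module_name] = original_func
        return original_func
    return decorator_wrapper


@register_module(MODEL_TYPE)
def build_model_with_cfg(**kwargs):
    """Hypothetical builder; a config declaring model_type = "LLAMA2" resolves here."""
    return kwargs


# A trainer-side lookup would then do something like:
builder = MODEL_INITIALIZER["LLAMA2"]
model_args = builder(num_layers=32, hidden_size=4096)
```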

@@ -364,7 +364,7 @@ def load_llama_pretrained_weights(folder, model):
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        if gpc.config.model_type == "LLAMA":
+        if gpc.config.model_type == "LLAMA2":
             # LLAMA's w2 and w3 are in reverse order
             w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
             w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
@@ -419,7 +419,7 @@ def load_hf_llama_pretrained_weights(folder, model):
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        if gpc.config.model_type == "LLAMA":
+        if gpc.config.model_type == "LLAMA2":
             if deep_split:
                 layer_ids = i // 2
             else:

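The "w2 and w3 are in reverse order" comment in the first hunk above refers to remapping Meta-format LLaMA feed-forward weights into the local layout. A hedged sketch of that swap, assuming the target module keeps the same key names with w2 and w3 transposed (the exact target key names are not shown in this diff):

```python
# Hedged sketch: Meta's LLaMA FFN computes w2(silu(w1(x)) * w3(x)); if the
# local FeedForward treats w3 as the down projection instead, the two weights
# trade places when a checkpoint is loaded. Target key names are assumptions.
def remap_feed_forward(states: dict, layer_idx: int) -> dict:
    w2 = states.pop(f"layers.{layer_idx}.feed_forward.w2.weight")
    w3 = states.pop(f"layers.{layer_idx}.feed_forward.w3.weight")
    return {
        f"layers.{layer_idx}.feed_forward.w2.weight": w3,  # swapped
        f"layers.{layer_idx}.feed_forward.w3.weight": w2,  # swapped
    }
```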

@@ -29,7 +29,7 @@ class Registry:
             AssertionError: Raises an AssertionError if the module has already been registered before.
         """
-        assert module_name not in self._registry, f"{module_name} not found in {self.name}"
+        assert module_name not in self._registry, f"{module_name} already registered in {self.name}"
         def decorator_wrapper(original_func):
             self._registry[module_name] = original_func