update tools/load_internlm_model

pull/478/head
YWMditto 2023-11-09 16:39:23 +08:00
parent 7a462c7d3f
commit 3dab742b75
2 changed files with 4 additions and 2 deletions

View File

@@ -472,6 +472,8 @@ def _streaming_no_beam_search_generate(
if eos_token_id is not None and add_eos_when_return:
    token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1)
yield token_ids
@torch.no_grad()
def _no_beam_search_generate(

View File

@@ -139,14 +139,13 @@ def initialize_internlm_model(
ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
    (a) local path, such as: "local:/mnt/petrelfs/share_data/llm_llama/codellama_raw/codellama-7b";
    (b) boto3 path, such as: "boto3:s3://checkpoints_ssd_02.10.135.7.249/0831/origin_llama/7B".
model_config (Optional[Union[Dict, str]], optional): Configuration of models. Defaults to None.
del_model_prefix (bool, optional): Whether to remove the "model." string in the key in state_dict.
    Defaults to False.
param_dtype (torch.dtype, optional): The dtype of the model at inference time. This value can be a string.
    Use "torch.tf32" when you want to use tf32 to do the inference. Defaults to torch.bfloat16.
training (bool, optional): model.train() or model.eval(). Defaults to False.
seed (int, optional): Defaults to 1024.
model_config (Optional[Union[Dict, str]], optional): Configuration of models.
    Defaults to None.
"""
if gpc.is_rank_for_log():
@@ -211,6 +210,7 @@ def initialize_internlm_model(
def get_model_device(model):
    """Return the torch.device that *model*'s parameters live on.

    Only the first parameter is inspected; this assumes all parameters of
    the model share a single device (true for non-sharded inference models).

    Args:
        model: A torch.nn.Module (anything exposing ``parameters()``).

    Returns:
        torch.device: Device of the model's first parameter.

    Raises:
        ValueError: If the model has no parameters, so no device can be
            determined (the original code would hit an unbound-local
            ``NameError`` here, which was confusing to debug).
    """
    for param in model.parameters():
        # First parameter is representative of the whole model's device.
        return param.device
    raise ValueError("model has no parameters; cannot determine its device")