update tools/load_internlm_model

pull/478/head
YWMditto 2023-11-09 16:39:23 +08:00
parent 7a462c7d3f
commit 3dab742b75
2 changed files with 4 additions and 2 deletions

View File

@@ -472,6 +472,8 @@ def _streaming_no_beam_search_generate(
         if eos_token_id is not None and add_eos_when_return:
             token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1)
         yield token_ids
 
+
+@torch.no_grad()
 def _no_beam_search_generate(
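
The @torch.no_grad() decorator added here disables autograd graph construction for the entire generation call, so intermediate activations are not kept alive for a backward pass; for token-by-token decoding this noticeably lowers memory use. A minimal sketch of the effect (the linear layer below is a hypothetical stand-in, not the repository's model):

    import torch

    model = torch.nn.Linear(4096, 4096)
    x = torch.randn(1, 4096)

    y1 = model(x)            # graph is recorded; activations retained for backward
    print(y1.requires_grad)  # True

    with torch.no_grad():    # what the decorator wraps around the function body
        y2 = model(x)
    print(y2.requires_grad)  # False: no graph, lower memory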

View File

@@ -139,14 +139,13 @@ def initialize_internlm_model(
         ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
             (a) local path, such as: "local:/mnt/petrelfs/share_data/llm_llama/codellama_raw/codellama-7b";
             (b) boto3 path, such as: "boto3:s3://checkpoints_ssd_02.10.135.7.249/0831/origin_llama/7B".
+        model_config (Optional[Union[Dict, str]], optional): Configuration of models. Defaults to None.
         del_model_prefix (bool, optional): Whether to remove the "model." prefix from the keys of state_dict.
             Defaults to False.
         param_dtype (torch.dtype, optional): The dtype of the model at inference time. This value can be a string.
             Use "torch.tf32" when you want to use tf32 for inference. Defaults to torch.bfloat16.
         training (bool, optional): model.train() or model.eval(). Defaults to False.
         seed (int, optional): Random seed. Defaults to 1024.
-        model_config (Optional[Union[Dict, str]], optional): Configuration of models.
-            Defaults to None.
     """
     if gpc.is_rank_for_log():
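
With this move, the docstring lists the arguments in the order they are passed. A hedged call sketch based only on the documented parameters (the import path follows the commit title, the checkpoint path is a placeholder, and the real signature may take additional arguments):

    import torch
    from tools.load_internlm_model import initialize_internlm_model  # assumed module path

    model = initialize_internlm_model(
        ckpt_dir="local:/path/to/ckpts/internlm-7b",  # or "boto3:s3://bucket/path"
        model_config=None,            # None: use the config shipped with the checkpoint
        del_model_prefix=True,        # strip the leading "model." from state_dict keys
        param_dtype=torch.bfloat16,   # or the string "torch.tf32" for tf32 inference
        training=False,               # eval mode for inference
        seed=1024,
    )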
@@ -211,6 +210,7 @@ def initialize_internlm_model(
 
 
 def get_model_device(model):
     for param in model.parameters():
         device = param.device
+        break
     return device
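
The added break makes get_model_device read the device of the first parameter and stop, instead of iterating over every parameter just to keep the last one; for a model resident on a single device the result is identical. An equivalent one-liner, shown only as a sketch (note it raises StopIteration on a model with no parameters):

    def get_model_device(model):
        # Device of the first parameter; on a single-device model this is
        # the device of all parameters.
        return next(model.parameters()).device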