From 3dab742b7578a8660cc30e31446687c629fea170 Mon Sep 17 00:00:00 2001 From: YWMditto <862779238@qq.com> Date: Thu, 9 Nov 2023 16:39:23 +0800 Subject: [PATCH] update tools/load_internlm_model --- internlm/apis/inference.py | 2 ++ tools/load_internlm_model.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py index 5c70c31..307d24b 100644 --- a/internlm/apis/inference.py +++ b/internlm/apis/inference.py @@ -472,6 +472,8 @@ def _streaming_no_beam_search_generate( if eos_token_id is not None and add_eos_when_return: token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1) + yield token_ids + @torch.no_grad() def _no_beam_search_generate( diff --git a/tools/load_internlm_model.py b/tools/load_internlm_model.py index 5cf9cda..7e21b71 100644 --- a/tools/load_internlm_model.py +++ b/tools/load_internlm_model.py @@ -139,14 +139,13 @@ def initialize_internlm_model( ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this: (a) local path, such as: "local:/mnt/petrelfs/share_data/llm_llama/codellama_raw/codellama-7b"; (b) boto3 path, such as: "boto3:s3://checkpoints_ssd_02.10.135.7.249/0831/origin_llama/7B". + model_config (Optional[Union[Dict, str]], optional): Configuration of models. Defaults to None. del_model_prefix (bool, optional): Whether to remove the "model." string in the key in state_dict. Defaults to False. param_dtype (torch.dtype, optional): The dtype of the model at inference time. This value can be a string. Use "torch.tf32" when you want to use tf32 to do the inference. Defaults to torch.bfloat16. training (bool, optional): model.train() or model.eval(). Defaults to False. seed (int, optional): Defaults to 1024. - model_config (Optional[Union[Dict, str]], optional): Configuration of models. - Defaults to None. """ if gpc.is_rank_for_log(): @@ -211,6 +210,7 @@ def initialize_internlm_model( def get_model_device(model): for param in model.parameters(): device = param.device + break return device