mirror of https://github.com/InternLM/InternLM
update tools/load_internlm_model
parent 7a462c7d3f
commit 3dab742b75
@@ -472,6 +472,8 @@ def _streaming_no_beam_search_generate(
         if eos_token_id is not None and add_eos_when_return:
             token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1)
+
+        yield token_ids


 @torch.no_grad()
 def _no_beam_search_generate(
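For readers unfamiliar with the tensor manipulation in the context lines above, here is a minimal, self-contained sketch of how the EOS id gets appended to every sequence in the batch before the final yield; the example tensors and the [2] EOS id are made up for illustration, not taken from the repository.

import torch

# Stand-in values; in the real generator these come from the decoding loop.
token_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])  # (batch, seq_len)
eos_token_id = [2]
add_eos_when_return = True

if eos_token_id is not None and add_eos_when_return:
    # new_full builds a (batch, 1) column filled with the first EOS id,
    # and torch.cat appends it along the sequence dimension.
    token_ids = torch.cat(
        [token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1
    )

print(token_ids)  # tensor([[ 5,  6,  7,  2], [ 8,  9, 10,  2]])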
@@ -139,14 +139,13 @@ def initialize_internlm_model(
         ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
             (a) local path, such as: "local:/mnt/petrelfs/share_data/llm_llama/codellama_raw/codellama-7b";
             (b) boto3 path, such as: "boto3:s3://checkpoints_ssd_02.10.135.7.249/0831/origin_llama/7B".
+        model_config (Optional[Union[Dict, str]], optional): Configuration of models. Defaults to None.
         del_model_prefix (bool, optional): Whether to remove the "model." string in the key in state_dict.
             Defaults to False.
         param_dtype (torch.dtype, optional): The dtype of the model at inference time. This value can be a string.
             Use "torch.tf32" when you want to use tf32 to do the inference. Defaults to torch.bfloat16.
         training (bool, optional): model.train() or model.eval(). Defaults to False.
         seed (int, optional): Defaults to 1024.
-        model_config (Optional[Union[Dict, str]], optional): Configuration of models.
-            Defaults to None.
         """

     if gpc.is_rank_for_log():
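A hypothetical call sketch based only on the docstring entries above; the import path tools.load_internlm_model and the example ckpt_dir value are assumptions, and any required arguments the docstring does not list (for example a model type) are omitted here.

import torch
from tools.load_internlm_model import initialize_internlm_model  # assumed import path

model = initialize_internlm_model(
    ckpt_dir="local:/path/to/ckpt",   # hypothetical; may also be a "boto3:s3://..." path
    model_config=None,                # None falls back to the default configuration
    del_model_prefix=True,            # strip the leading "model." from state_dict keys
    param_dtype=torch.bfloat16,       # or the string "torch.tf32" for tf32 inference
    training=False,                   # model.eval() for inference
    seed=1024,
)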
@@ -211,6 +210,7 @@ def initialize_internlm_model(
 def get_model_device(model):
     for param in model.parameters():
         device = param.device
+        break
     return device

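The added break is a small efficiency fix: only the first parameter's device is needed, so the loop should not walk the entire parameter list. A runnable sketch of the helper as patched; the nn.Linear model is just a stand-in for illustration.

from torch import nn

def get_model_device(model):
    for param in model.parameters():
        device = param.device
        break  # the first parameter is enough; all parameters are assumed to share a device
    return device

model = nn.Linear(4, 4)
print(get_model_device(model))  # cpu (or whichever device the model was moved to)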