mirror of https://github.com/InternLM/InternLM
update tools/load_internlm_model
parent 7a462c7d3f
commit 3dab742b75
@@ -472,6 +472,8 @@ def _streaming_no_beam_search_generate(

    if eos_token_id is not None and add_eos_when_return:
        token_ids = torch.cat([token_ids, token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])], dim=1)
    yield token_ids


@torch.no_grad()
def _no_beam_search_generate(
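For readers skimming the hunk: the changed lines pad every sequence in the batch with an explicit EOS column before yielding. A minimal standalone sketch of that pattern (not from this repo), using toy values; eos_token_id is assumed to be a list, since the diff indexes eos_token_id[0]:

import torch

token_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])  # (batch=2, seq_len=3)
eos_token_id = [2]

# new_full builds a (batch, 1) column with the same device/dtype as token_ids;
# torch.cat along dim=1 appends that column to every sequence in the batch.
eos_col = token_ids.new_full((token_ids.size(0), 1), eos_token_id[0])
token_ids = torch.cat([token_ids, eos_col], dim=1)
print(token_ids)  # tensor([[11, 12, 13,  2], [21, 22, 23,  2]])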
@@ -139,14 +139,13 @@ def initialize_internlm_model(

        ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
            (a) local path, such as: "local:/mnt/petrelfs/share_data/llm_llama/codellama_raw/codellama-7b";
            (b) boto3 path, such as: "boto3:s3://checkpoints_ssd_02.10.135.7.249/0831/origin_llama/7B".
        model_config (Optional[Union[Dict, str]], optional): Configuration of models. Defaults to None.
        del_model_prefix (bool, optional): Whether to remove the "model." prefix from keys in the state_dict.
            Defaults to False.
        param_dtype (torch.dtype, optional): The dtype of the model at inference time. This value can be a string.
            Use "torch.tf32" when you want to run inference in tf32. Defaults to torch.bfloat16.
        training (bool, optional): model.train() or model.eval(). Defaults to False.
        seed (int, optional): Random seed. Defaults to 1024.
        model_config (Optional[Union[Dict, str]], optional): Configuration of models.
            Defaults to None.
    """

    if gpc.is_rank_for_log():
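Pulling the documented parameters together, a hypothetical call might look like the sketch below. Only the parameter names visible in the docstring are used; the import path is inferred from the commit title (tools/load_internlm_model), the checkpoint path is a placeholder, and the real signature may take additional arguments:

import torch
from tools.load_internlm_model import initialize_internlm_model  # path assumed from the commit title

model = initialize_internlm_model(
    ckpt_dir="local:/path/to/checkpoints/internlm-7b",  # "local:" or "boto3:s3://..." per the docstring
    model_config=None,           # dict or str; None falls back to the default config
    del_model_prefix=True,       # strip the "model." prefix from state_dict keys
    param_dtype=torch.bfloat16,  # or the string "torch.tf32" for tf32 inference
    training=False,              # model.eval() for inference
    seed=1024,
)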
@@ -211,6 +210,7 @@ def initialize_internlm_model(

def get_model_device(model):
    for param in model.parameters():
        device = param.device
        break
    return device