mirror of https://github.com/InternLM/InternLM
fix some info
parent
18bd6429f5
commit
8f8fe84c03
|
@ -407,12 +407,10 @@ def _streaming_no_beam_search_generate(
|
||||||
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
||||||
to_atten_x = bos_pos[:, :, None]
|
to_atten_x = bos_pos[:, :, None]
|
||||||
to_atten_y = bos_pos[:, None, :]
|
to_atten_y = bos_pos[:, None, :]
|
||||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
|
||||||
else:
|
else:
|
||||||
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
||||||
to_atten_x = bos_pos[:, :, None]
|
to_atten_x = bos_pos[:, :, None]
|
||||||
to_atten_y = bos_pos[:, None, :]
|
to_atten_y = bos_pos[:, None, :]
|
||||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
|
||||||
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
||||||
inference_params.attention_mask = attention_mask
|
inference_params.attention_mask = attention_mask
|
||||||
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
|
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
|
||||||
|
|
|
@ -134,8 +134,7 @@ def initialize_internlm_model(
|
||||||
"""Initialize internlm model.
|
"""Initialize internlm model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type (str): The types of models supported by train_internlm, such as "LLAMA" or "INTERNLM". Note that
|
model_type (str): The types of models supported by train_internlm, such as "INTERNLM".
|
||||||
when loading these models, ``model_type`` can only be "LLAMA".
|
|
||||||
ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
|
ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
|
||||||
(a) local path, such as: "local:{your local path}";
|
(a) local path, such as: "local:{your local path}";
|
||||||
(b) boto3 path, such as: "boto3:s3://{bucket name}.{ip}/{your ceph path}".
|
(b) boto3 path, such as: "boto3:s3://{bucket name}.{ip}/{your ceph path}".
|
||||||
|
|
Loading…
Reference in New Issue