mirror of https://github.com/InternLM/InternLM
fix some info
parent
18bd6429f5
commit
8f8fe84c03
|
@ -407,12 +407,10 @@ def _streaming_no_beam_search_generate(
|
|||
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
else:
|
||||
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
||||
inference_params.attention_mask = attention_mask
|
||||
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
|
||||
|
|
|
@ -134,8 +134,7 @@ def initialize_internlm_model(
|
|||
"""Initialize internlm model.
|
||||
|
||||
Args:
|
||||
model_type (str): The types of models supported by train_internlm, such as "LLAMA" or "INTERNLM". Note that
|
||||
when loading these models, ``model_type`` can only be "LLAMA".
|
||||
model_type (str): The types of models supported by train_internlm, such as "INTERNLM".
|
||||
ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
|
||||
(a) local path, such as: "local:{your local path}";
|
||||
(b) boto3 path, such as: "boto3:s3://{bucket name}.{ip}/{your ceph path}".
|
||||
|
|
Loading…
Reference in New Issue