From 8f8fe84c0344a90eaf9f8c003131a92b20446f4b Mon Sep 17 00:00:00 2001
From: YWMditto <862779238@qq.com>
Date: Thu, 9 Nov 2023 16:58:27 +0800
Subject: [PATCH] fix some info

---
 internlm/apis/inference.py   | 2 --
 tools/load_internlm_model.py | 3 +--
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py
index 307d24b..7a51e34 100644
--- a/internlm/apis/inference.py
+++ b/internlm/apis/inference.py
@@ -407,12 +407,10 @@ def _streaming_no_beam_search_generate(
                 bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
                 to_atten_x = bos_pos[:, :, None]
                 to_atten_y = bos_pos[:, None, :]
-                # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
             else:
                 bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
                 to_atten_x = bos_pos[:, :, None]
                 to_atten_y = bos_pos[:, None, :]
-                # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
             attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
             inference_params.attention_mask = attention_mask
         scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
diff --git a/tools/load_internlm_model.py b/tools/load_internlm_model.py
index e943f2c..123cade 100644
--- a/tools/load_internlm_model.py
+++ b/tools/load_internlm_model.py
@@ -134,8 +134,7 @@ def initialize_internlm_model(
     """Initialize internlm model.
 
     Args:
-        model_type (str): The types of models supported by train_internlm, such as "LLAMA" or "INTERNLM". Note that
-            when loading these models, ``model_type`` can only be "LLAMA".
+        model_type (str): The types of models supported by train_internlm, such as "INTERNLM".
         ckpt_dir (str): Directory where model checkpoints are stored. Its format needs to be like this:
             (a) local path, such as: "local:{your local path}"; (b) boto3 path, such as:
             "boto3:s3://{bucket name}.{ip}/{your ceph path}".
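
Note: the mask construction that survives in inference.py builds a boolean attention mask by OR-ing a column view and a row view of the BOS positions, rather than multiplying them as the deleted einsum comments did. A minimal standalone sketch of the kept path; the token ids and bos_token_id below are illustrative values, not taken from the repository:

    import torch

    token_ids = torch.tensor([[1, 5, 6, 1, 7]])  # hypothetical batch of token ids
    bos_token_id = 1                             # assumed BOS id for this example

    bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)  # (batch, seq)
    to_atten_x = bos_pos[:, :, None]                         # (batch, seq, 1)
    to_atten_y = bos_pos[:, None, :]                         # (batch, 1, seq)

    # The retained line: a position pair is masked if either its row or its
    # column index corresponds to a BOS token.
    attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
    print(attention_mask.shape)  # torch.Size([1, 5, 5])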