From cbb54d3202c6935edf11e481fc43929c410fdf1a Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 13 Jul 2023 19:51:25 +0800 Subject: [PATCH] [shardformer] polish code --- .../model_zoo/transformers/chatglm2_6b/modeling_chatglm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 488f24c5f..f704715e1 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -80,6 +80,7 @@ def default_init(cls, *args, **kwargs): class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() @@ -219,6 +220,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) @@ -233,6 +235,7 @@ class RMSNorm(torch.nn.Module): class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() @@ -839,6 +842,7 @@ class Embedding(torch.nn.Module): class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: @@ -921,6 +925,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) + print(inputs_embeds) if self.pre_seq_len is not None: if past_key_values is None: @@ -982,6 +987,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): super().__init__(config)