[shardformer] polish code

pull/4445/head
klhhhhh 2023-07-13 19:51:25 +08:00 committed by Hongxin Liu
parent 1a29e8fc29
commit cbb54d3202
1 changed files with 6 additions and 0 deletions

View File

@ -80,6 +80,7 @@ def default_init(cls, *args, **kwargs):
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
@ -219,6 +220,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
class RMSNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
@ -233,6 +235,7 @@ class RMSNorm(torch.nn.Module):
class CoreAttention(torch.nn.Module):
def __init__(self, config: ChatGLMConfig, layer_number):
super(CoreAttention, self).__init__()
@ -839,6 +842,7 @@ class Embedding(torch.nn.Module):
class ChatGLMModel(ChatGLMPreTrainedModel):
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
super().__init__(config)
if empty_init:
@ -921,6 +925,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
print(inputs_embeds)
if self.pre_seq_len is not None:
if past_key_values is None:
@ -982,6 +987,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
super().__init__(config)