mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] polish code
parent 1a29e8fc29
commit cbb54d3202
@@ -80,6 +80,7 @@ def default_init(cls, *args, **kwargs):

class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
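For context, a processor like this is normally plugged into generation through transformers' LogitsProcessorList. A minimal, self-contained sketch of the same sanitizing idea; the fallback token id 5 and the score 5e4 are illustrative assumptions, not read from this diff:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class SanitizeScoresProcessor(LogitsProcessor):
    """Zero out NaN/Inf logits so sampling cannot crash, then force one safe token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4  # assumed fallback token id, purely illustrative
        return scores

# hypothetical usage with an already-loaded model and input_ids:
# outputs = model.generate(input_ids, logits_processor=LogitsProcessorList([SanitizeScoresProcessor()]))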
@@ -219,6 +220,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
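The hunk shows only the constructor; the normalization RMSNorm computes is the standard root-mean-square norm. A minimal sketch of a matching forward pass, assuming the fp32 upcast commonly used in ChatGLM-style implementations (the `self.eps = eps` line is added here for completeness; the weight is normally loaded from a checkpoint rather than left uninitialized):

import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # normalize by the root mean square over the last dimension
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states.to(input_dtype))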
@@ -233,6 +235,7 @@ class RMSNorm(torch.nn.Module):

class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()
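CoreAttention's body is elided by the hunk; the computation it wraps is ordinary scaled dot-product attention. A framework-agnostic sketch of that core step; the tensor layout and masking convention are assumptions, not read from this diff:

import math
import torch

def scaled_dot_product_attention(query, key, value, attention_mask=None):
    # query/key/value: [batch, heads, seq_len, head_dim] (assumed layout)
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        # attention_mask: boolean, True where attention is disallowed (assumed convention)
        scores = scores.masked_fill(attention_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value)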
@@ -839,6 +842,7 @@ class Embedding(torch.nn.Module):

class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
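The `if empty_init:` branch is cut off by the hunk. In ChatGLM-style modeling code this flag usually chooses between constructing submodules normally and skipping parameter initialization so weights can be loaded directly from a checkpoint. A hedged sketch of that pattern, assuming the `default_init` helper from the first hunk simply forwards to the class constructor:

import torch
from torch.nn.utils import skip_init

def default_init(cls, *args, **kwargs):
    # plain construction: allocate and initialize parameters as usual
    return cls(*args, **kwargs)

# inside __init__, pick the constructor wrapper once and reuse it for submodules
# (illustrative only; Embedding and config come from the surrounding modeling file):
# init_method = skip_init if empty_init else default_init
# self.embedding = init_method(Embedding, config, device=device)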
@@ -921,6 +925,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)
        print(inputs_embeds)

        if self.pre_seq_len is not None:
            if past_key_values is None:
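`pre_seq_len` is the P-tuning v2 style prefix length: when it is set and no cache is supplied, the model typically builds an initial `past_key_values` from learned prefix embeddings and widens the attention mask so every token can attend to the prefix. A rough sketch of the mask-widening part of that idea; the helper name and shapes are assumptions for illustration:

import torch

def extend_attention_mask(attention_mask: torch.Tensor, pre_seq_len: int) -> torch.Tensor:
    # hypothetical helper: prepend an all-ones block of length pre_seq_len so the
    # prefix entries in past_key_values are visible to the real tokens
    batch_size = attention_mask.size(0)
    prefix_mask = attention_mask.new_ones(batch_size, pre_seq_len)
    return torch.cat([prefix_mask, attention_mask], dim=-1)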
@@ -982,6 +987,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)
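For orientation, a conditional-generation wrapper like this is driven through the standard transformers generate loop. A hypothetical usage sketch; the checkpoint name is a placeholder and is not part of this commit:

import torch
from transformers import AutoTokenizer, AutoModel

# placeholder checkpoint; ChatGLM weights ship their own modeling code, hence trust_remote_code
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().eval()

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))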