mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] polish code
parent 1a29e8fc29
commit cbb54d3202

@@ -80,6 +80,7 @@ def default_init(cls, *args, **kwargs):
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
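For context on the hunk above: a self-contained sketch of the processor's behavior. The return statement and the demo tensor are not visible in the diff; they are assumptions added so the snippet runs on its own.

import torch
from transformers import LogitsProcessor


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Zero out the whole score row if any logit is NaN or Inf, so sampling cannot crash.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
        return scores


processor = InvalidScoreLogitsProcessor()
bad_scores = torch.tensor([[0.1, float("nan"), 2.0]])
print(processor(torch.zeros(1, 1, dtype=torch.long), bad_scores))  # tensor([[0., 0., 0.]])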

@@ -219,6 +220,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
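The hunk only shows the RMSNorm constructor. As a reference point, here is a sketch of the standard RMSNorm forward pass (normalize by the root mean square, then apply the learned scale). It is the generic formulation, not code copied from the file, and the weight is initialized to ones instead of torch.empty so the demo is deterministic.

import torch


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), computed in float32 for stability, then rescaled.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
        return self.weight * normed.to(hidden_states.dtype)


x = torch.randn(2, 4, 8)
print(RMSNorm(8)(x).shape)  # torch.Size([2, 4, 8])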

@@ -233,6 +235,7 @@ class RMSNorm(torch.nn.Module):
class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()
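CoreAttention is only named here; the hunk shows none of its math. For orientation, a generic scaled dot-product attention sketch, softmax(QK^T / sqrt(d)) V, which is the operation such a block computes; it is not the ChatGLM implementation.

import math
import torch


def core_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: [batch, heads, seq_len, head_dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


q = k = v = torch.randn(1, 2, 5, 8)
print(core_attention(q, k, v).shape)  # torch.Size([1, 2, 5, 8])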

@@ -839,6 +842,7 @@ class Embedding(torch.nn.Module):
class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
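The first hunk header references default_init, and ChatGLMModel.__init__ branches on empty_init. A hedged sketch of the pattern those two lines suggest: choose torch.nn.utils.skip_init (allocate parameters without initializing their values, useful when a checkpoint will overwrite them anyway) or the plain constructor. The init_method name and the Linear example are illustrative assumptions, not lines from this diff.

import torch
from torch.nn.utils import skip_init


def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


def build_linear(empty_init: bool) -> torch.nn.Linear:
    # Pick the constructor path: skip weight initialization, or run it normally.
    init_method = skip_init if empty_init else default_init
    return init_method(torch.nn.Linear, 4, 4)


fast = build_linear(True)     # parameters allocated but left uninitialized
normal = build_linear(False)  # parameters initialized by reset_parameters as usual
print(fast.weight.shape, normal.weight.shape)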

@@ -921,6 +925,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)
        print(inputs_embeds)
        if self.pre_seq_len is not None:
            if past_key_values is None:

@@ -982,6 +987,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)