mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/layer/parallel_sequence/layers.py code style (#1280)
Co-authored-by: JThh <jiatong.han@u.nus.edu>pull/1298/head
parent
b414eaa5db
commit
38e3ccd1e9
|
@ -44,8 +44,7 @@ class TransformerSelfAttentionRing(nn.Module):
|
||||||
attn_mask_type=AttnMaskType.padding,
|
attn_mask_type=AttnMaskType.padding,
|
||||||
masked_softmax_fusion=True,
|
masked_softmax_fusion=True,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
bf16=False
|
bf16=False):
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
|
self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
|
||||||
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
||||||
|
@ -80,21 +79,14 @@ class TransformerSelfAttentionRing(nn.Module):
|
||||||
self.coeff = layer_number
|
self.coeff = layer_number
|
||||||
self.norm_factor *= self.coeff
|
self.norm_factor *= self.coeff
|
||||||
|
|
||||||
self.scale_mask_softmax = FusedScaleMaskSoftmax(
|
self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion,
|
||||||
fp16, bf16,
|
self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax,
|
||||||
self.attn_mask_type,
|
self.coeff)
|
||||||
masked_softmax_fusion,
|
|
||||||
self.attention_mask_func,
|
|
||||||
self.convert_fp16_to_fp32_in_softmax,
|
|
||||||
self.coeff)
|
|
||||||
|
|
||||||
self.attention_dropout = nn.Dropout(attention_dropout)
|
self.attention_dropout = nn.Dropout(attention_dropout)
|
||||||
|
|
||||||
# Output.
|
# Output.
|
||||||
self.dense = _Linear(hidden_size,
|
self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True)
|
||||||
hidden_size,
|
|
||||||
bias=True,
|
|
||||||
skip_bias_add=True)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask):
|
def forward(self, hidden_states, attention_mask):
|
||||||
# hidden_states: [sub_seq_len, batch_size, hidden_size]
|
# hidden_states: [sub_seq_len, batch_size, hidden_size]
|
||||||
|
@ -120,30 +112,24 @@ class TransformerSelfAttentionRing(nn.Module):
|
||||||
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
|
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
|
||||||
'cannot be divided into query, key and value'
|
'cannot be divided into query, key and value'
|
||||||
partition_size = last_dim_value // 3
|
partition_size = last_dim_value // 3
|
||||||
(query_layer, key_layer, value_layer) = torch.split(
|
(query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim)
|
||||||
mixed_x_layer, partition_size, dim=last_dim)
|
|
||||||
|
|
||||||
# attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
|
# attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
|
||||||
output_size = (query_layer.size(1),
|
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0),
|
||||||
query_layer.size(2),
|
|
||||||
query_layer.size(0),
|
|
||||||
key_layer.size(0) * self.world_size)
|
key_layer.size(0) * self.world_size)
|
||||||
|
|
||||||
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
|
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
|
||||||
query_layer = query_layer.view(output_size[2],
|
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
||||||
output_size[0] * output_size[1], -1)
|
|
||||||
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
|
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
|
||||||
key_layer = key_layer.view(key_layer.size(0),
|
key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1)
|
||||||
output_size[0] * output_size[1], -1)
|
|
||||||
|
|
||||||
# attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
|
# attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
|
||||||
attention_scores = RingQK.apply(
|
attention_scores = RingQK.apply(
|
||||||
query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size]
|
query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size]
|
||||||
key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size],
|
key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size],
|
||||||
batch_size,
|
batch_size,
|
||||||
self.num_attention_heads,
|
self.num_attention_heads,
|
||||||
sub_seq_length
|
sub_seq_length)
|
||||||
)
|
|
||||||
|
|
||||||
attention_scores /= self.norm_factor
|
attention_scores /= self.norm_factor
|
||||||
|
|
||||||
|
@ -158,29 +144,19 @@ class TransformerSelfAttentionRing(nn.Module):
|
||||||
attention_probs = self.attention_dropout(attention_probs)
|
attention_probs = self.attention_dropout(attention_probs)
|
||||||
|
|
||||||
# context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
|
# context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
|
||||||
output_size = (value_layer.size(1),
|
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
||||||
value_layer.size(2),
|
|
||||||
query_layer.size(0),
|
|
||||||
value_layer.size(3))
|
|
||||||
|
|
||||||
# change view [sub_seq_len, batch_size * num_heads, head_size]
|
# change view [sub_seq_len, batch_size * num_heads, head_size]
|
||||||
value_layer = value_layer.contiguous().view(value_layer.size(0),
|
value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
||||||
output_size[0] * output_size[1], -1)
|
|
||||||
|
|
||||||
# # change view [b * num_heads, sub_seq_len, seq_len]
|
# # change view [b * num_heads, sub_seq_len, seq_len]
|
||||||
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
|
attention_probs = attention_probs.view(
|
||||||
attention_probs.size(2),
|
attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3))
|
||||||
attention_probs.size(3))
|
|
||||||
|
|
||||||
# matmul: [batch_size * num_heads, sub_seq_len, head_size]
|
# matmul: [batch_size * num_heads, sub_seq_len, head_size]
|
||||||
context_layer = RingAV.apply(
|
context_layer = RingAV.apply(attention_probs,
|
||||||
attention_probs,
|
value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads,
|
||||||
value_layer.transpose(0, 1).contiguous(),
|
self.hidden_size_per_attention_head, sub_seq_length)
|
||||||
batch_size,
|
|
||||||
self.num_attention_heads,
|
|
||||||
self.hidden_size_per_attention_head,
|
|
||||||
sub_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# change view [batch_size, num_heads, sub_seq_len, head_size]
|
# change view [batch_size, num_heads, sub_seq_len, head_size]
|
||||||
context_layer = context_layer.view(*output_size)
|
context_layer = context_layer.view(*output_size)
|
||||||
|
@ -189,8 +165,8 @@ class TransformerSelfAttentionRing(nn.Module):
|
||||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||||
|
|
||||||
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
|
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head *
|
||||||
self.hidden_size_per_attention_head * self.num_attention_heads,)
|
self.num_attention_heads,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
output, bias = self.dense(context_layer)
|
output, bias = self.dense(context_layer)
|
||||||
|
@ -224,11 +200,7 @@ class _Linear(nn.Module):
|
||||||
adding bias but instead return it.
|
adding bias but instead return it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
|
||||||
input_size,
|
|
||||||
output_size,
|
|
||||||
bias=True,
|
|
||||||
skip_bias_add=False):
|
|
||||||
super(_Linear, self).__init__()
|
super(_Linear, self).__init__()
|
||||||
|
|
||||||
# Keep input parameters
|
# Keep input parameters
|
||||||
|
@ -236,9 +208,10 @@ class _Linear(nn.Module):
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size,
|
self.weight = Parameter(torch.empty(
|
||||||
self.input_size,
|
self.output_size,
|
||||||
))
|
self.input_size,
|
||||||
|
))
|
||||||
nn.init.xavier_normal_(self.weight)
|
nn.init.xavier_normal_(self.weight)
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
|
|
Loading…
Reference in New Issue