diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py
index 55c400c18..d9486217b 100644
--- a/colossalai/nn/layer/parallel_sequence/layers.py
+++ b/colossalai/nn/layer/parallel_sequence/layers.py
@@ -44,8 +44,7 @@ class TransformerSelfAttentionRing(nn.Module):
                  attn_mask_type=AttnMaskType.padding,
                  masked_softmax_fusion=True,
                  fp16=False,
-                 bf16=False
-                 ):
+                 bf16=False):
         super().__init__()
         self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -80,21 +79,14 @@ class TransformerSelfAttentionRing(nn.Module):
             self.coeff = layer_number
             self.norm_factor *= self.coeff
 
-        self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            fp16, bf16,
-            self.attn_mask_type,
-            masked_softmax_fusion,
-            self.attention_mask_func,
-            self.convert_fp16_to_fp32_in_softmax,
-            self.coeff)
+        self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion,
+                                                        self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax,
+                                                        self.coeff)
 
         self.attention_dropout = nn.Dropout(attention_dropout)
 
         # Output.
-        self.dense = _Linear(hidden_size,
-                             hidden_size,
-                             bias=True,
-                             skip_bias_add=True)
+        self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True)
 
     def forward(self, hidden_states, attention_mask):
         # hidden_states: [sub_seq_len, batch_size, hidden_size]
@@ -120,30 +112,24 @@ class TransformerSelfAttentionRing(nn.Module):
         assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
                                         'cannot be divided into query, key and value'
         partition_size = last_dim_value // 3
-        (query_layer, key_layer, value_layer) = torch.split(
-            mixed_x_layer, partition_size, dim=last_dim)
+        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim)
 
         # attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
-        output_size = (query_layer.size(1),
-                       query_layer.size(2),
-                       query_layer.size(0),
+        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0),
                        key_layer.size(0) * self.world_size)
 
         # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
-        query_layer = query_layer.view(output_size[2],
-                                       output_size[0] * output_size[1], -1)
+        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
         # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
-        key_layer = key_layer.view(key_layer.size(0),
-                                   output_size[0] * output_size[1], -1)
+        key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1)
 
         # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
         attention_scores = RingQK.apply(
-            query_layer.transpose(0, 1).contiguous(),  # [batch_size * num_heads, sub_seq_len, head_size]
-            key_layer.transpose(0, 1).contiguous(),  # [batch_size * num_heads, sub_seq_len, head_size],
+            query_layer.transpose(0, 1).contiguous(),    # [batch_size * num_heads, sub_seq_len, head_size]
+            key_layer.transpose(0, 1).contiguous(),    # [batch_size * num_heads, sub_seq_len, head_size],
             batch_size,
             self.num_attention_heads,
-            sub_seq_length
-        )
+            sub_seq_length)
 
         attention_scores /= self.norm_factor
 
@@ -158,29 +144,19 @@ class TransformerSelfAttentionRing(nn.Module):
         attention_probs = self.attention_dropout(attention_probs)
 
         # context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
-        output_size = (value_layer.size(1),
-                       value_layer.size(2),
-                       query_layer.size(0),
-                       value_layer.size(3))
+        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
 
         # change view [sub_seq_len, batch_size * num_heads, head_size]
-        value_layer = value_layer.contiguous().view(value_layer.size(0),
-                                                    output_size[0] * output_size[1], -1)
+        value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1)
 
         # # change view [b * num_heads, sub_seq_len, seq_len]
-        attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
-                                               attention_probs.size(2),
-                                               attention_probs.size(3))
+        attention_probs = attention_probs.view(
+            attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3))
 
         # matmul: [batch_size * num_heads, sub_seq_len, head_size]
-        context_layer = RingAV.apply(
-            attention_probs,
-            value_layer.transpose(0, 1).contiguous(),
-            batch_size,
-            self.num_attention_heads,
-            self.hidden_size_per_attention_head,
-            sub_seq_length
-        )
+        context_layer = RingAV.apply(attention_probs,
+                                     value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads,
+                                     self.hidden_size_per_attention_head, sub_seq_length)
 
         # change view [batch_size, num_heads, sub_seq_len, head_size]
         context_layer = context_layer.view(*output_size)
@@ -189,8 +165,8 @@ class TransformerSelfAttentionRing(nn.Module):
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
 
         # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
-        new_context_layer_shape = context_layer.size()[:-2] + (
-            self.hidden_size_per_attention_head * self.num_attention_heads,)
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head *
+                                                               self.num_attention_heads,)
         context_layer = context_layer.view(*new_context_layer_shape)
 
         output, bias = self.dense(context_layer)
@@ -224,11 +200,7 @@ class _Linear(nn.Module):
                        adding bias but instead return it.
     """
 
-    def __init__(self,
-                 input_size,
-                 output_size,
-                 bias=True,
-                 skip_bias_add=False):
+    def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
         super(_Linear, self).__init__()
 
         # Keep input parameters
@@ -236,9 +208,10 @@ class _Linear(nn.Module):
         self.output_size = output_size
         self.skip_bias_add = skip_bias_add
 
-        self.weight = Parameter(torch.empty(self.output_size,
-                                            self.input_size,
-                                            ))
+        self.weight = Parameter(torch.empty(
+            self.output_size,
+            self.input_size,
+        ))
         nn.init.xavier_normal_(self.weight)
 
         if bias:
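
The hunks above only reflow existing statements; the tensor bookkeeping in forward() is unchanged. For reference, a minimal single-process sketch of that shape flow, where the concrete sizes are illustrative assumptions and a plain torch.bmm plus softmax stand in for the distributed RingQK / RingAV kernels and FusedScaleMaskSoftmax:

import torch

# Illustrative sizes; in the real layer seq_len = sub_seq_len * world_size.
sub_seq_len, batch_size, num_heads, head_size = 8, 2, 3, 16
world_size = 1    # single process, so seq_len == sub_seq_len here

# Mixed query/key/value projection output: [sub_seq_len, batch_size, num_heads, 3 * head_size]
mixed_x_layer = torch.randn(sub_seq_len, batch_size, num_heads, 3 * head_size)

# Split the last dimension into query, key and value, as in the reflowed torch.split call
last_dim = mixed_x_layer.dim() - 1
partition_size = mixed_x_layer.size(last_dim) // 3
query_layer, key_layer, value_layer = torch.split(mixed_x_layer, partition_size, dim=last_dim)

# [sub_seq_len, batch_size, num_heads, head_size] -> [batch_size * num_heads, sub_seq_len, head_size]
q = query_layer.view(sub_seq_len, batch_size * num_heads, -1).transpose(0, 1).contiguous()
k = key_layer.view(sub_seq_len, batch_size * num_heads, -1).transpose(0, 1).contiguous()
v = value_layer.contiguous().view(sub_seq_len, batch_size * num_heads, -1).transpose(0, 1).contiguous()

# Attention scores: [batch_size * num_heads, sub_seq_len, seq_len]; RingQK gathers keys
# across ranks, a local bmm stands in for it here.
attention_scores = torch.bmm(q, k.transpose(1, 2)) / (head_size ** 0.5)
attention_probs = torch.softmax(attention_scores, dim=-1)

# Context: [batch_size * num_heads, sub_seq_len, head_size]; RingAV does this across ranks.
context_layer = torch.bmm(attention_probs, v)

# [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, hidden_size]
context_layer = context_layer.view(batch_size, num_heads, sub_seq_len, head_size)
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
context_layer = context_layer.view(sub_seq_len, batch_size, num_heads * head_size)

print(attention_scores.shape)    # torch.Size([6, 8, 8])
print(context_layer.shape)       # torch.Size([8, 2, 48])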