From 0e06f62160bb52018a888f3837067042f163747d Mon Sep 17 00:00:00 2001
From: Geng Zhang <34452939+zxgx@users.noreply.github.com>
Date: Tue, 12 Jul 2022 18:14:21 +0800
Subject: [PATCH] [NFC] polish colossalai/nn/layer/parallel_sequence/_operation.py code style (#1266)

---
 .../nn/layer/parallel_sequence/_operation.py | 74 +++++++------------
 1 file changed, 25 insertions(+), 49 deletions(-)

diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/nn/layer/parallel_sequence/_operation.py
index 119302a09..fc8049422 100644
--- a/colossalai/nn/layer/parallel_sequence/_operation.py
+++ b/colossalai/nn/layer/parallel_sequence/_operation.py
@@ -19,24 +19,17 @@ class RingQK(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx,
-                sub_q,
-                sub_k,
-                batch_size,
-                num_attention_heads,
-                sub_seq_length):
+    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length):
         # save tensor for backward
         ctx.save_for_backward(sub_q, sub_k)
         ctx.sub_seq_length = sub_seq_length
 
         # create local segment of attention score
-        attention_score = torch.empty(
-            batch_size * num_attention_heads,
-            sub_seq_length,
-            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
-            dtype=sub_q.dtype,
-            device=get_current_device()
-        )
+        attention_score = torch.empty(batch_size * num_attention_heads,
+                                      sub_seq_length,
+                                      sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
+                                      dtype=sub_q.dtype,
+                                      device=get_current_device())
 
         # compute local QK^T
         part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
@@ -44,7 +37,7 @@ class RingQK(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         start_idx = local_rank * sub_seq_length
         end_idx = (local_rank + 1) * sub_seq_length
-        attention_score[:, :, start_idx: end_idx] = part_a
+        attention_score[:, :, start_idx:end_idx] = part_a
 
         # compute QK^T in ring-all-reduce style
         for i in range(local_world_size - 1):
@@ -63,19 +56,18 @@ class RingQK(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
 
         # calculate gradient of sub_k
-        grad_k = torch.matmul(
-            grad_output.transpose(2, 1),
-            sub_q
-        )
+        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
 
         dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
-        grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
+        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
         grad_k /= local_world_size
 
         # calculate gradient for sub_q
-        grad_q = torch.zeros_like(sub_q,
-                                  dtype=sub_q.dtype,
-                                  device=get_current_device(), )
+        grad_q = torch.zeros_like(
+            sub_q,
+            dtype=sub_q.dtype,
+            device=get_current_device(),
+        )
 
         # compute with local sub_k
         start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
@@ -85,7 +77,7 @@ class RingQK(torch.autograd.Function):
         for i in range(local_world_size - 1):
             sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
             start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
-            grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
+            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)
 
         grad_q /= local_world_size
 
@@ -99,23 +91,16 @@ class RingAV(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx,
-                attention_score,
-                sub_v,
-                batch_size,
-                num_attention_heads,
-                attention_head_size,
-                sub_seq_length):
+    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length):
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
 
-        sub_attention_result = torch.zeros(
-            batch_size * num_attention_heads,
-            sub_seq_length,
-            attention_head_size,
-            device=get_current_device(),
-            dtype=attention_score.dtype)
+        sub_attention_result = torch.zeros(batch_size * num_attention_heads,
+                                           sub_seq_length,
+                                           attention_head_size,
+                                           device=get_current_device(),
+                                           dtype=attention_score.dtype)
 
         # save tensors for backward
         ctx.save_for_backward(attention_score, sub_v)
@@ -144,23 +129,16 @@ class RingAV(torch.autograd.Function):
         attention_scores, sub_v = ctx.saved_tensors
 
         # calculate gradient of v
-        grad_v = torch.matmul(
-            attention_scores.transpose(2, 1),
-            grad_output
-        )
+        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
         dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
         grad_v = grad_v[:, local_start_idx:local_end_idx]
         grad_v /= local_world_size
 
         # calculate gradient for attention score
-        grad_attention_score = torch.zeros_like(attention_scores,
-                                                dtype=grad_output.dtype,
-                                                device=get_current_device())
+        grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
 
         # compute with local sub_k
-        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
-            grad_output,
-            sub_v.transpose(2, 1))
+        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
 
         # compute QK^T in ring-all-reduce style
         for i in range(local_world_size - 1):
@@ -168,8 +146,6 @@ class RingAV(torch.autograd.Function):
             start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
 
             # compute grad_q
-            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
-                grad_output,
-                sub_v.transpose(2, 1))
+            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
 
         return grad_attention_score, grad_v, None, None, None, None
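
Note (not part of the patch): a minimal usage sketch of how the two autograd functions touched above compose into a sequence-parallel attention step. The `apply` argument lists and tensor shapes follow the code in this diff; the wrapper name `ring_self_attention` and the placement of scaling and softmax are illustrative assumptions, and colossalai is assumed to have been launched with a SEQUENCE parallel group initialized.

    import torch
    from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV

    def ring_self_attention(sub_q, sub_k, sub_v, batch_size, num_heads, head_size, sub_seq_len):
        # sub_q / sub_k / sub_v: (batch_size * num_heads, sub_seq_len, head_size),
        # each rank holding only its own sub-sequence chunk.
        scores = RingQK.apply(sub_q, sub_k, batch_size, num_heads, sub_seq_len)
        # scores: (batch_size * num_heads, sub_seq_len, full_seq_len) after the ring exchange
        scores = scores / (head_size ** 0.5)    # scaling placement is an assumption here
        probs = torch.softmax(scores, dim=-1)
        # weighted sum over the full sequence, returned as this rank's sub-sequence chunk
        return RingAV.apply(probs, sub_v, batch_size, num_heads, head_size, sub_seq_len)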