[NFC] polish colossalai/nn/layer/parallel_sequence/_operation.py code style (#1266)

pull/1298/head
Geng Zhang 2022-07-12 18:14:21 +08:00 committed by Frank Lee
parent c95e18cdb9
commit 0e06f62160
1 changed file with 25 additions and 49 deletions


@@ -19,24 +19,17 @@ class RingQK(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx,
-                sub_q,
-                sub_k,
-                batch_size,
-                num_attention_heads,
-                sub_seq_length):
+    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length):
         # save tensor for backward
         ctx.save_for_backward(sub_q, sub_k)
         ctx.sub_seq_length = sub_seq_length
         # create local segment of attention score
-        attention_score = torch.empty(
-            batch_size * num_attention_heads,
-            sub_seq_length,
-            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
-            dtype=sub_q.dtype,
-            device=get_current_device()
-        )
+        attention_score = torch.empty(batch_size * num_attention_heads,
+                                      sub_seq_length,
+                                      sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
+                                      dtype=sub_q.dtype,
+                                      device=get_current_device())
         # compute local QK^T
         part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
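
For reference, the snippet below is a minimal single-process sketch of what this forward pass assembles: each sequence-parallel rank holds one chunk of Q and K, computes its rows of QK^T, and fills the column slice that belongs to each K chunk. The sizes and the plain Python list standing in for the ring exchange are illustrative assumptions, not ColossalAI's distributed code.

import torch

# Hypothetical sizes, for illustration only.
world_size, batch_heads, sub_seq_length, head_dim = 4, 2, 8, 16
full_seq = sub_seq_length * world_size

# One (sub_q, sub_k) pair per "rank", as each sequence-parallel rank would hold.
sub_qs = [torch.randn(batch_heads, sub_seq_length, head_dim) for _ in range(world_size)]
sub_ks = [torch.randn(batch_heads, sub_seq_length, head_dim) for _ in range(world_size)]

def ring_qk_reference(r):
    # What rank `r` ends up with: its rows of QK^T, with every column block filled.
    attention_score = torch.empty(batch_heads, sub_seq_length, full_seq)
    for src in range(world_size):
        start, end = src * sub_seq_length, (src + 1) * sub_seq_length
        # local QK^T against the sub_k that originated on rank `src`
        attention_score[:, :, start:end] = torch.matmul(sub_qs[r], sub_ks[src].transpose(2, 1))
    return attention_score

# Sanity check against the unpartitioned computation.
full_q = torch.cat(sub_qs, dim=1)
full_k = torch.cat(sub_ks, dim=1)
full_score = torch.matmul(full_q, full_k.transpose(2, 1))
for r in range(world_size):
    rows = slice(r * sub_seq_length, (r + 1) * sub_seq_length)
    assert torch.allclose(ring_qk_reference(r), full_score[:, rows, :], atol=1e-5)
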
@@ -44,7 +37,7 @@ class RingQK(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         start_idx = local_rank * sub_seq_length
         end_idx = (local_rank + 1) * sub_seq_length
-        attention_score[:, :, start_idx: end_idx] = part_a
+        attention_score[:, :, start_idx:end_idx] = part_a
         # compute QK^T in ring-all-reduce style
         for i in range(local_world_size - 1):
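
The slice arithmetic above (start_idx = local_rank * sub_seq_length, end_idx = (local_rank + 1) * sub_seq_length) is what the helpers _calc_current_device_range and _calc_incoming_device_range used later in the file presumably encapsulate. Their bodies are not part of this diff, so the following is only a plausible sketch of their behaviour, with the ring direction being an assumption:

# Assumed behaviour, for illustration only (not the library's actual source).

def _calc_current_device_range(local_rank, sub_seq_length):
    # the slice of the full sequence owned by this rank
    start_idx = local_rank * sub_seq_length
    end_idx = (local_rank + 1) * sub_seq_length
    return start_idx, end_idx


def _calc_incoming_device_range(i, local_rank, world_size, sub_seq_length):
    # after i + 1 ring steps, the chunk held locally originated on this rank
    device_of_incoming_k = (local_rank - i - 1) % world_size
    start_idx = device_of_incoming_k * sub_seq_length
    end_idx = (device_of_incoming_k + 1) * sub_seq_length
    return start_idx, end_idx
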
@@ -63,19 +56,18 @@ class RingQK(torch.autograd.Function):
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         # calculate gradient of sub_k
-        grad_k = torch.matmul(
-            grad_output.transpose(2, 1),
-            sub_q
-        )
+        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
         dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
-        grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
+        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
         grad_k /= local_world_size
         # calculate gradient for sub_q
-        grad_q = torch.zeros_like(sub_q,
-                                  dtype=sub_q.dtype,
-                                  device=get_current_device(), )
+        grad_q = torch.zeros_like(
+            sub_q,
+            dtype=sub_q.dtype,
+            device=get_current_device(),
+        )
         # compute with local sub_k
         start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
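
This backward relies on the chain rule for S = QK^T: dL/dK = (dL/dS)^T Q and dL/dQ = (dL/dS) K. Because each rank's grad_output only covers its own rows of S, the partial grad_k is all-reduced over the sequence group and then sliced down to the locally owned rows, while grad_q is accumulated chunk by chunk as sub_k travels around the ring. A single-device autograd check of those two identities (leaving aside the /= local_world_size rescaling done in the code) might look like:

import torch

# Tiny shapes, single device; checks the chain-rule identities used above:
# S = Q @ K^T  =>  dL/dK = dS^T @ Q  and  dL/dQ = dS @ K.
bh, seq, dim = 2, 6, 4
q = torch.randn(bh, seq, dim, requires_grad=True)
k = torch.randn(bh, seq, dim, requires_grad=True)
grad_s = torch.randn(bh, seq, seq)          # plays the role of grad_output

s = torch.matmul(q, k.transpose(2, 1))
s.backward(grad_s)

assert torch.allclose(k.grad, torch.matmul(grad_s.transpose(2, 1), q), atol=1e-5)
assert torch.allclose(q.grad, torch.matmul(grad_s, k), atol=1e-5)
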
@@ -85,7 +77,7 @@ class RingQK(torch.autograd.Function):
         for i in range(local_world_size - 1):
             sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
             start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
-            grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
+            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)
         grad_q /= local_world_size
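
The loop above is the ring schedule itself: at every step ring_forward hands the locally held sub_k to the neighbouring rank, so after local_world_size - 1 steps each rank has multiplied its grad_output slice against every K chunk exactly once. A toy stand-in for that rotation (pure Python, no distributed backend involved):

from collections import deque

# Conceptual stand-in for the ring exchange: each "rank" starts with its own
# chunk and, at every step, forwards what it holds to the next rank.  After
# world_size - 1 steps every rank has seen every chunk exactly once.
world_size = 4
held = deque(range(world_size))          # held[r] = chunk currently on rank r
seen = [{r} for r in range(world_size)]  # chunks each rank has processed

for _ in range(world_size - 1):
    held.rotate(1)                       # one ring step for all ranks at once
    for r in range(world_size):
        seen[r].add(held[r])

assert all(s == set(range(world_size)) for s in seen)
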
@@ -99,23 +91,16 @@ class RingAV(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx,
-                attention_score,
-                sub_v,
-                batch_size,
-                num_attention_heads,
-                attention_head_size,
-                sub_seq_length):
+    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length):
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
         local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
-        sub_attention_result = torch.zeros(
-            batch_size * num_attention_heads,
-            sub_seq_length,
-            attention_head_size,
-            device=get_current_device(),
-            dtype=attention_score.dtype)
+        sub_attention_result = torch.zeros(batch_size * num_attention_heads,
+                                           sub_seq_length,
+                                           attention_head_size,
+                                           device=get_current_device(),
+                                           dtype=attention_score.dtype)
         # save tensors for backward
         ctx.save_for_backward(attention_score, sub_v)
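
RingAV's forward mirrors RingQK: each rank owns its rows of the attention score and one chunk of V, and it accumulates score[:, :, chunk] @ V_chunk as the V chunks arrive over the ring, ending up with its rows of the full attention output. A single-process sketch with made-up sizes (again, no distributed code involved):

import torch

world_size, batch_heads, sub_seq_length, head_size = 4, 2, 8, 16
full_seq = sub_seq_length * world_size

score_rows = torch.randn(batch_heads, sub_seq_length, full_seq)   # this rank's rows of S
sub_vs = [torch.randn(batch_heads, sub_seq_length, head_size) for _ in range(world_size)]

# Accumulate one column block of S against each V chunk, as the ring would deliver them.
sub_attention_result = torch.zeros(batch_heads, sub_seq_length, head_size)
for src in range(world_size):
    start, end = src * sub_seq_length, (src + 1) * sub_seq_length
    sub_attention_result += torch.matmul(score_rows[:, :, start:end], sub_vs[src])

# Sanity check against multiplying by the fully assembled V.
full_v = torch.cat(sub_vs, dim=1)
assert torch.allclose(sub_attention_result, torch.matmul(score_rows, full_v), atol=1e-4)
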
@@ -144,23 +129,16 @@ class RingAV(torch.autograd.Function):
         attention_scores, sub_v = ctx.saved_tensors
         # calculate gradient of v
-        grad_v = torch.matmul(
-            attention_scores.transpose(2, 1),
-            grad_output
-        )
+        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
         dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
         grad_v = grad_v[:, local_start_idx:local_end_idx]
         grad_v /= local_world_size
         # calculate gradient for attention score
-        grad_attention_score = torch.zeros_like(attention_scores,
-                                                dtype=grad_output.dtype,
-                                                device=get_current_device())
+        grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
         # compute with local sub_k
-        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
-            grad_output,
-            sub_v.transpose(2, 1))
+        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
         # compute QK^T in ring-all-reduce style
         for i in range(local_world_size - 1):
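
As in RingQK, this backward follows the chain rule for O = SV: dL/dV = S^T (dL/dO), which is why grad_v is all-reduced and then sliced down to the locally owned value rows, and dL/dS = (dL/dO) V^T, which is filled in one column block at a time as sub_v circulates. A single-device autograd check of the two identities (without the partitioning):

import torch

# O = S @ V  =>  dL/dV = S^T @ dO  and  dL/dS = dO @ V^T.
bh, rows, cols, head = 2, 6, 10, 4
s = torch.randn(bh, rows, cols, requires_grad=True)   # attention scores
v = torch.randn(bh, cols, head, requires_grad=True)   # values
grad_o = torch.randn(bh, rows, head)                  # plays the role of grad_output

o = torch.matmul(s, v)
o.backward(grad_o)

assert torch.allclose(v.grad, torch.matmul(s.transpose(2, 1), grad_o), atol=1e-5)
assert torch.allclose(s.grad, torch.matmul(grad_o, v.transpose(2, 1)), atol=1e-5)
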
@@ -168,8 +146,6 @@ class RingAV(torch.autograd.Function):
             start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
             # compute grad_q
-            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
-                grad_output,
-                sub_v.transpose(2, 1))
+            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
         return grad_attention_score, grad_v, None, None, None, None