mirror of https://github.com/hpcaitech/ColossalAI
170 lines
6.4 KiB
Python
170 lines
6.4 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- encoding: utf-8 -*-
|
||
|
|
||
|
import torch
|
||
|
from torch import distributed as dist
|
||
|
|
||
|
from colossalai.communication import ring_forward
|
||
|
from colossalai.context.parallel_mode import ParallelMode
|
||
|
from colossalai.core import global_context as gpc
|
||
|
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
|
||
|
from colossalai.utils import get_current_device
|
||
|
|
||
|
|
||
|
class RingQK(torch.autograd.Function):
    """
    Calculate QK^T in a ring-exchange style across the sequence-parallel group.

    The local ``sub_q`` is multiplied with the local ``sub_k`` and then with each
    key block received from the other ranks of ``ParallelMode.SEQUENCE`` through
    ``ring_forward``. Every product is written into the column range of the
    score matrix given by ``_calc_current_device_range`` (local block) or
    ``_calc_incoming_device_range`` (received blocks).
    """

    @staticmethod
    def forward(ctx,
                sub_q,
                sub_k,
                batch_size,
                num_attention_heads,
                sub_seq_length):
        """Return this rank's rows of the attention-score matrix.

        Args:
            ctx: autograd context; keeps ``sub_q``, the *local* ``sub_k`` and
                ``sub_seq_length`` for :meth:`backward`.
            sub_q: local query block. Presumably shaped
                ``(batch_size * num_attention_heads, sub_seq_length, head_dim)``
                -- TODO confirm against callers.
            sub_k: local key block, presumably the same shape as ``sub_q``.
            batch_size: used (with ``num_attention_heads``) to size dim 0 of
                the output.
            num_attention_heads: see ``batch_size``.
            sub_seq_length: number of rows produced here, and the width of
                each rank's column block.

        Returns:
            Tensor of shape ``(batch_size * num_attention_heads, sub_seq_length,
            sub_seq_length * world_size)`` with ``sub_q.dtype``, allocated on
            ``get_current_device()``.
        """
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # Save the local key block now: the ring loop below rebinds ``sub_k``
        # to the blocks received from other ranks.
        ctx.save_for_backward(sub_q, sub_k)
        ctx.sub_seq_length = sub_seq_length

        # Left uninitialised; assumes the local range plus the
        # ``local_world_size - 1`` incoming ranges cover every column
        # (TODO confirm against the _calc_* helpers).
        attention_score = torch.empty(
            batch_size * num_attention_heads,
            sub_seq_length,
            sub_seq_length * local_world_size,
            dtype=sub_q.dtype,
            device=get_current_device(),
        )

        # Local block: Q_local @ K_local^T.
        local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
        attention_score[:, :, local_start_idx:local_end_idx] = torch.matmul(sub_q, sub_k.transpose(2, 1))

        # Remaining blocks: pass keys around the ring and fill the columns
        # belonging to the rank each block came from.
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
            attention_score[:, :, start_idx:end_idx] = torch.matmul(sub_q, sub_k.transpose(2, 1))

        return attention_score

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length)``.

        * ``grad_k``: ``grad_output^T @ sub_q`` over all columns. It is summed
          across the sequence group with ``dist.all_reduce``, then sliced to
          this rank's block.
        * ``grad_q``: accumulates ``grad_output[:, :, cols] @ k_block`` for the
          local key block and every key block received through the ring.

        Both are divided by the sequence-parallel world size.
        NOTE(review): this scaling is kept as-is; confirm it matches how the
        caller reduces losses/gradients across the sequence group.
        The three non-tensor arguments receive ``None``.
        """
        sub_q, sub_k = ctx.saved_tensors
        sub_seq_length = ctx.sub_seq_length
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)

        # Gradient of sub_k: every rank computes its contribution for all key
        # positions; the collective sums them. all_reduce must be reached by
        # every rank of the group.
        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
        dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_k = grad_k[:, local_start_idx:local_end_idx]
        grad_k /= local_world_size

        # Gradient of sub_q, starting with the local key block.
        grad_q = torch.zeros_like(sub_q,
                                  dtype=sub_q.dtype,
                                  device=get_current_device())
        grad_q += torch.matmul(grad_output[:, :, local_start_idx:local_end_idx], sub_k)

        # Replay the forward ring loop: same number of ring_forward steps and
        # the same _calc_incoming_device_range(i, ...) column ranges.
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

        grad_q /= local_world_size

        return grad_q, grad_k, None, None, None
|
||
|
|
||
|
|
||
|
class RingAV(torch.autograd.Function):
    """
    Calculate AV (attention_score @ V) in a ring-exchange style across the
    sequence-parallel group.

    The columns of the local ``attention_score`` rows that belong to this rank
    are multiplied with the local ``sub_v``. Each value block received from the
    other ranks of ``ParallelMode.SEQUENCE`` through ``ring_forward`` is then
    multiplied with its matching column range. All products are summed into
    this rank's output rows.
    """

    @staticmethod
    def forward(ctx,
                attention_score,
                sub_v,
                batch_size,
                num_attention_heads,
                attention_head_size,
                sub_seq_length):
        """Return this rank's rows of ``attention_score @ V``.

        Args:
            ctx: autograd context; keeps ``attention_score``, the *local*
                ``sub_v`` and ``sub_seq_length`` for :meth:`backward`.
            attention_score: this rank's score rows. Presumably shaped
                ``(batch_size * num_attention_heads, sub_seq_length,
                sub_seq_length * world_size)`` -- TODO confirm.
            sub_v: local value block. Presumably shaped
                ``(batch_size * num_attention_heads, sub_seq_length,
                attention_head_size)`` -- TODO confirm.
            batch_size: used (with ``num_attention_heads``) to size dim 0 of
                the output.
            num_attention_heads: see ``batch_size``.
            attention_head_size: size of the output's last dimension.
            sub_seq_length: number of output rows and width of each rank's
                column block in ``attention_score``.

        Returns:
            Tensor of shape ``(batch_size * num_attention_heads, sub_seq_length,
            attention_head_size)`` with ``attention_score.dtype``, on
            ``get_current_device()``.
        """
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)

        # Accumulator; zero-initialised because every block product is added
        # into it.
        sub_attention_result = torch.zeros(
            batch_size * num_attention_heads,
            sub_seq_length,
            attention_head_size,
            device=get_current_device(),
            dtype=attention_score.dtype)

        # Save the local value block now: the ring loop below rebinds ``sub_v``.
        ctx.save_for_backward(attention_score, sub_v)
        ctx.sub_seq_length = sub_seq_length

        # Local block: scores over this rank's columns @ V_local.
        sub_attention_result += torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v)

        # Remaining blocks: pass values around the ring and multiply each one
        # with the score columns of the rank it came from.
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
            sub_attention_result += torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v)

        return sub_attention_result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(attention_score, sub_v, batch_size, num_attention_heads,
        attention_head_size, sub_seq_length)``.

        * ``grad_v``: ``attention_score^T @ grad_output`` over all key
          positions. It is summed across the sequence group with
          ``dist.all_reduce``, sliced to this rank's block, and divided by the
          world size.
        * ``grad_attention_score``: for the local value block and every value
          block received through the ring, ``grad_output @ v_block^T`` is added
          into that block's column range.

        NOTE(review): unlike ``grad_v``, ``grad_attention_score`` is not divided
        by the world size. This is kept as-is; confirm it is the intended
        scaling convention. The four non-tensor arguments receive ``None``.
        """
        sub_seq_length = ctx.sub_seq_length
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
        attention_scores, sub_v = ctx.saved_tensors

        # Gradient of sub_v: every rank computes its contribution for all value
        # positions; the collective sums them. all_reduce must be reached by
        # every rank of the group.
        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
        dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_v = grad_v[:, local_start_idx:local_end_idx]
        grad_v /= local_world_size

        # Gradient of attention_score, one column block per value block.
        grad_attention_score = torch.zeros_like(attention_scores,
                                                dtype=grad_output.dtype,
                                                device=get_current_device())

        # Columns matching the local value block.
        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
            grad_output,
            sub_v.transpose(2, 1))

        # Replay the forward ring loop: same number of ring_forward steps and
        # the same _calc_incoming_device_range(i, ...) column ranges.
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
                grad_output,
                sub_v.transpose(2, 1))

        return grad_attention_score, grad_v, None, None, None, None
|