mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
6.2 KiB
179 lines
6.2 KiB
import copy
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import all_to_all_comm
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
class SequenceParallelAttention(torch.nn.Module):
    """Multi-head self-attention with optional all-to-all sequence parallelism.

    With sequence parallelism disabled this is ordinary multi-head attention
    over the input it is given. With it enabled, the input is treated as this
    rank's slice of the sequence: the q/k/v projections are exchanged through
    ``all_to_all_comm`` before attention, and the attention output is exchanged
    back afterwards.

    Arguments:
        heads_num (int): total number of attention heads
        hidden_dim (int): model width; must be divisible by ``heads_num``
        enable_sequence_parallellism (bool): whether to use the all-to-all path
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        heads_num: int,
        hidden_dim: int,
        enable_sequence_parallellism: bool = False,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super().__init__()
        assert hidden_dim % heads_num == 0, (
            f"hidden_dim ({hidden_dim}) must be divisible by heads_num ({heads_num})"
        )

        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.heads_num = heads_num
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // heads_num
        self.enable_sequence_parallellism = enable_sequence_parallellism

        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def attn(self, q, k, v):
        """Scaled dot-product attention over the last two dimensions of q, k, v."""
        scale = self.head_dim**0.5
        qk = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = F.softmax(qk, dim=-1)
        attention_score = torch.matmul(weights, v)
        return attention_score

    def forward(self, x) -> Tensor:
        """Apply attention to ``x`` and return a tensor of the same shape.

        ``x`` is unpacked as ``(batch, local_sequence_length, hidden)``.
        """
        bsz, q_len, _ = x.size()

        if self.enable_sequence_parallellism:
            # After the all-to-all each rank sees the whole sequence but only
            # its share of the heads.
            sp_size = dist.get_world_size(self.spg)
            seq_len = q_len * sp_size
            num_heads = self.heads_num // sp_size
        else:
            seq_len = q_len
            num_heads = self.heads_num

        # in shape : e.g., [s/p:h:]
        query_states = self.q(x)
        key_states = self.k(x)
        value_states = self.v(x)

        if self.enable_sequence_parallellism:
            query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx)
            key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx)
            value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx)

        # Split the hidden dimension into heads: (bsz, num_heads, seq_len, head_dim).
        query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)

        # out shape : e.g., [s:h/p:]
        attn_score = self.attn(query_states, key_states, value_states)
        attn_score = attn_score.transpose(1, 2).contiguous()
        attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim)

        if self.enable_sequence_parallellism:
            # Reverse exchange: note the swapped gather/scatter indices.
            attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx)

        # output e.g., [s/p::h]
        output = self.out(attn_score)
        return output
def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    """Check sequence-parallel attention against unsharded attention.

    Runs ``SequenceParallelAttention`` once on the full input (sequence
    parallelism off) and once on this rank's sequence chunk (sequence
    parallelism on), then compares forward outputs and gradients.

    Must be called on every rank of an initialised default process group, with
    CUDA available.

    Args:
        seq_len: full (unsharded) sequence length.
        hidden_dim: model width.
        head_num: number of attention heads.
        batch_size: batch size of the random input.
    """
    world_size = dist.get_world_size()

    # NOTE(review): the comparison only holds if every rank draws the same `x`
    # and the same initial weights -- presumably the launcher seeds all ranks
    # identically; confirm.
    x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
    x_unshard = x.clone()
    x_unshard.requires_grad_(True)  # needed so x_unshard.grad is populated
    x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()]
    x_input.requires_grad_(True)  # needed so x_input.grad is populated

    # Multi-head Attention
    mha = SequenceParallelAttention(head_num, hidden_dim).cuda()
    # Multi-head Attention forward
    mha_out = mha(x_unshard)

    # Sequence parallel Attention
    sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda()
    # Both modules must start from identical weights, otherwise neither the
    # outputs nor the gradients are comparable.
    sp_attn.load_state_dict(mha.state_dict())
    # Sequence parallel Attention forward
    dist_attn_out = sp_attn(x_input)

    # gather the output of sequence parallel attention
    out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)]
    dist.all_gather(out_list, dist_attn_out)
    seq_out = torch.cat(out_list, dim=1)

    # forward result check
    assert_close(seq_out, mha_out)

    # Multi-head Attention backward
    mha_out.sum().backward()
    q_grad = mha.q.weight.grad
    k_grad = mha.k.weight.grad
    v_grad = mha.v.weight.grad
    o_grad = mha.out.weight.grad
    x_grad = x_unshard.grad

    # Sequence parallel Attention backward
    dist_attn_out.sum().backward()
    q_grad_seq = sp_attn.q.weight.grad
    k_grad_seq = sp_attn.k.weight.grad
    v_grad_seq = sp_attn.v.weight.grad
    o_grad_seq = sp_attn.out.weight.grad
    x_grad_seq = x_input.grad

    # all_reduce the grad of sequence parallel attention weight: each rank only
    # accumulated the contribution of its own sequence chunk, so the sum over
    # ranks is what corresponds to the unsharded gradient.
    for grad in (q_grad_seq, k_grad_seq, v_grad_seq, o_grad_seq):
        dist.all_reduce(grad)

    # gather the grad of sequence parallel attention input
    x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)]
    dist.all_gather(x_grad_seq_list, x_grad_seq)
    x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1)

    # backward result check
    assert_close(q_grad_seq, q_grad)
    assert_close(k_grad_seq, k_grad)
    assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4)
    assert_close(o_grad_seq, o_grad)
    assert_close(x_grad_seq_gather, x_grad)
@parameterize("seq_len", [128])
@parameterize("hidden_dim", [64])
@parameterize("head_num", [4])
@parameterize("batch_size", [1])
def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    """Run the sequence-parallel attention check for one parameter combination."""
    seq_parallel_attn(
        seq_len=seq_len,
        hidden_dim=hidden_dim,
        head_num=head_num,
        batch_size=batch_size,
    )
def check_all2all_attn(rank, world_size, port):
    """Per-process entry point: initialise the distributed context, then run the check.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port passed to ``colossalai.launch`` for the localhost rendezvous.
    """
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # Without this call the worker only sets up the process group and exits,
    # so the test would pass without checking anything. The arguments are
    # supplied by the @parameterize decorators on run_seq_parallel_attn.
    run_seq_parallel_attn()
# Retry when the rendezvous port happens to be taken, instead of failing the test.
@rerun_if_address_is_in_use()
def test_all_to_all_attention():
    """Spawn 4 worker processes that each run ``check_all2all_attn``."""
    spawn(check_all2all_attn, nprocs=4)
if __name__ == "__main__":