import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import all_to_all_comm
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


class SequenceParallelAttention(torch.nn.Module):
    """Multi-head attention with optional all-to-all sequence parallelism.

    Arguments:
        heads_num (int): total number of attention heads
        hidden_dim (int): hidden dimension of the attention block
        enable_sequence_parallelism (bool): whether the input sequence is sharded across the process group
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        heads_num: int,
        hidden_dim: int,
        enable_sequence_parallelism: bool = False,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(SequenceParallelAttention, self).__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.heads_num = heads_num
        self.hidden_dim = hidden_dim
        assert hidden_dim % heads_num == 0
        self.head_dim = hidden_dim // heads_num
        self.enable_sequence_parallelism = enable_sequence_parallelism

        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def attn(self, q, k, v):
        # vanilla scaled dot-product attention; q, k, v: [bsz, num_heads, seq_len, head_dim]
        scale = self.head_dim**0.5
        qk = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = F.softmax(qk, dim=-1)

        attention_score = torch.matmul(weights, v)

        return attention_score
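
    # For reference, attn() above is mathematically equivalent (up to numerical precision) to
    #   torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # with its default arguments (no mask, no dropout) on PyTorch >= 2.0.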

    def forward(self, x) -> Tensor:
        bsz, q_len, _ = x.size()

        # with sequence parallelism enabled, each rank holds a sequence shard of length q_len = s / p,
        # so recover the full sequence length and the per-rank head count
        seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallelism else q_len
        num_heads = (
            self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallelism else self.heads_num
        )

        # input shape: [bsz, s/p, h] with sequence parallelism, else [bsz, s, h]
        query_states = self.q(x)
        key_states = self.k(x)
        value_states = self.v(x)

        if self.enable_sequence_parallelism:
            # all-to-all: scatter the hidden dim, gather the sequence dim, i.e. [bsz, s/p, h] -> [bsz, s, h/p]
            query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx)
            key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx)
            value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx)

        query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)

        # attention output shape: [bsz, s, h/p] with sequence parallelism, else [bsz, s, h]
        attn_score = self.attn(query_states, key_states, value_states)
        attn_score = attn_score.transpose(1, 2).contiguous()
        attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim)
        if self.enable_sequence_parallelism:
            # reverse all-to-all: [bsz, s, h/p] -> [bsz, s/p, h]
            attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx)

        # output shape: [bsz, s/p, h] with sequence parallelism, else [bsz, s, h]
        output = self.out(attn_score)

        return output
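

# Illustrative, single-process sketch (not used by the test) of the layout change that
# all_to_all_comm is used for above, assuming it splits its local input into world_size
# chunks along scatter_idx, exchanges one chunk with every rank, and concatenates the
# received chunks along gather_idx. The helper name and arguments are hypothetical.
def _emulate_all_to_all_layout(full_input: torch.Tensor, world_size: int, scatter_idx: int = 2, gather_idx: int = 1):
    # full_input is the concatenation of every rank's local shard along gather_idx,
    # e.g. [bsz, s, h] assembled from per-rank [bsz, s/p, h] shards
    shards = torch.chunk(full_input, world_size, dim=gather_idx)
    outputs = []
    for rank in range(world_size):
        # each output rank keeps its own slice of every sender's shard along scatter_idx ...
        pieces = [torch.chunk(shard, world_size, dim=scatter_idx)[rank] for shard in shards]
        # ... and concatenates them along gather_idx, giving a [bsz, s, h/p] tensor per rank
        outputs.append(torch.cat(pieces, dim=gather_idx))
    return outputs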


def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    world_size = dist.get_world_size()

    x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
    x_unshard = x.clone()
    x_unshard.requires_grad_(True)
    # each rank takes its own shard of the sequence dimension as input
    x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()]
    x_input.requires_grad_(True)

    # Multi-head Attention
    mha = SequenceParallelAttention(head_num, hidden_dim).cuda()
    # Multi-head Attention forward
    mha_out = mha(x_unshard)

    # Sequence parallel Attention
    sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda()
    sp_attn.load_state_dict(copy.deepcopy(mha.state_dict()))
    # Sequence parallel Attention forward
    dist_attn_out = sp_attn(x_input)

    # gather the output of sequence parallel attention
    out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)]
    dist.all_gather(out_list, dist_attn_out)
    seq_out = torch.cat(out_list, dim=1)

    # forward result check
    assert_close(seq_out, mha_out)

    # Multi-head Attention backward
    mha_out.sum().backward()
    q_grad = mha.q.weight.grad
    k_grad = mha.k.weight.grad
    v_grad = mha.v.weight.grad
    o_grad = mha.out.weight.grad
    x_grad = x_unshard.grad

    # Sequence parallel Attention backward
    dist_attn_out.sum().backward()
    q_grad_seq = sp_attn.q.weight.grad
    k_grad_seq = sp_attn.k.weight.grad
    v_grad_seq = sp_attn.v.weight.grad
    o_grad_seq = sp_attn.out.weight.grad
    x_grad_seq = x_input.grad
    # all_reduce the weight grads of sequence parallel attention: each rank only sees its own
    # sequence shard, so the full-sequence weight grads are the sum over the process group
    dist.all_reduce(q_grad_seq)
    dist.all_reduce(k_grad_seq)
    dist.all_reduce(v_grad_seq)
    dist.all_reduce(o_grad_seq)
    # gather the input grads of sequence parallel attention and concatenate along the sequence dim
    x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)]
    dist.all_gather(x_grad_seq_list, x_grad_seq)
    x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1)

    # backward result check
    assert_close(q_grad_seq, q_grad)
    assert_close(k_grad_seq, k_grad)
    assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4)
    assert_close(o_grad_seq, o_grad)
    assert_close(x_grad_seq_gather, x_grad)


@parameterize("seq_len", [128])
@parameterize("hidden_dim", [64])
@parameterize("head_num", [4])
@parameterize("batch_size", [1])
def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size)


def check_all2all_attn(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_seq_parallel_attn()


@rerun_if_address_is_in_use()
def test_all_to_all_attention():
    spawn(check_all2all_attn, nprocs=4)
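    # the test spawns 4 processes on the NCCL backend and moves all tensors/modules to CUDA,
    # so it expects 4 GPUs to be available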


if __name__ == "__main__":
    test_all_to_all_attention()