ColossalAI/tests/test_shardformer/test_layer/test_sequence_parallel.py

import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import all_to_all_comm
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
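

# The module below implements plain multi-head attention with an optional all-to-all
# sequence-parallel path: each rank holds a slice of the sequence, an all-to-all swaps the
# sequence sharding for a head sharding so every rank attends over the full sequence with a
# subset of heads, and a second all-to-all restores the sequence sharding before the output
# projection.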
class SequenceParallelAttention(torch.nn.Module):
    """Multi-head self-attention with optional all-to-all sequence parallelism.

    Arguments:
        heads_num (int): total number of attention heads
        hidden_dim (int): hidden dimension of the model
        enable_sequence_parallellism (bool): whether the input sequence is sharded across ranks
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """
    def __init__(
        self,
        heads_num: int,
        hidden_dim: int,
        enable_sequence_parallellism: bool = False,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(SequenceParallelAttention, self).__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.heads_num = heads_num
        self.hidden_dim = hidden_dim
        assert hidden_dim % heads_num == 0, "hidden_dim must be divisible by heads_num"
        self.head_dim = hidden_dim // heads_num
        self.enable_sequence_parallellism = enable_sequence_parallellism

        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def attn(self, q, k, v):
        # Naive scaled dot-product attention over [bsz, num_heads, seq_len, head_dim] tensors.
        scale = self.head_dim**0.5
        qk = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = F.softmax(qk, dim=-1)
        attention_score = torch.matmul(weights, v)
        return attention_score

    def forward(self, x) -> Tensor:
        bsz, q_len, _ = x.size()
        # q_len is the local (sharded) sequence length; seq_len is the full sequence length.
        seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len
        num_heads = (
            self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num
        )

        # input shape, e.g.: [bsz, seq_len/p, hidden_dim]
        query_states = self.q(x)
        key_states = self.k(x)
        value_states = self.v(x)

        if self.enable_sequence_parallellism:
            # all-to-all: [bsz, seq_len/p, hidden_dim] -> [bsz, seq_len, hidden_dim/p]
            query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx)
            key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx)
            value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx)

        query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)

        # attention output, e.g.: [bsz, seq_len, hidden_dim/p]
        attn_score = self.attn(query_states, key_states, value_states)
        attn_score = attn_score.transpose(1, 2).contiguous()
        attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim)

        if self.enable_sequence_parallellism:
            # reverse all-to-all: [bsz, seq_len, hidden_dim/p] -> [bsz, seq_len/p, hidden_dim]
            attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx)

        # output, e.g.: [bsz, seq_len/p, hidden_dim]
        output = self.out(attn_score)
        return output
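

# The test below runs the same weights through a single-device baseline and through the
# sequence-parallel path on sharded inputs, then checks that forward outputs and gradients match.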
def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    world_size = dist.get_world_size()
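
    # Full-sequence input for the single-device baseline; each rank also takes its own
    # sequence shard as input to the sequence-parallel path.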
    x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
    x_unshard = x.clone()
    x_unshard.requires_grad_(True)
    x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()]
    x_input.requires_grad_(True)

    # Multi-head Attention
    mha = SequenceParallelAttention(head_num, hidden_dim).cuda()
    # Multi-head Attention forward
    mha_out = mha(x_unshard)

    # Sequence parallel Attention
    sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda()
    sp_attn.load_state_dict(copy.deepcopy(mha.state_dict()))
    # Sequence parallel Attention forward
    dist_attn_out = sp_attn(x_input)

    # gather the output of sequence parallel attention
    out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)]
    dist.all_gather(out_list, dist_attn_out)
    seq_out = torch.cat(out_list, dim=1)

    # forward result check
    assert_close(seq_out, mha_out)
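
    # Each rank's loss only covers its own shard of the output, so the sequence-parallel weight
    # gradients are summed across ranks (all_reduce) and the input gradient is gathered along the
    # sequence dimension before comparing against the single-device baseline.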
    # Multi-head Attention backward
    mha_out.sum().backward()
    q_grad = mha.q.weight.grad
    k_grad = mha.k.weight.grad
    v_grad = mha.v.weight.grad
    o_grad = mha.out.weight.grad
    x_grad = x_unshard.grad

    # Sequence parallel Attention backward
    dist_attn_out.sum().backward()
    q_grad_seq = sp_attn.q.weight.grad
    k_grad_seq = sp_attn.k.weight.grad
    v_grad_seq = sp_attn.v.weight.grad
    o_grad_seq = sp_attn.out.weight.grad
    x_grad_seq = x_input.grad

    # all_reduce the grad of sequence parallel attention weight
    dist.all_reduce(q_grad_seq)
    dist.all_reduce(k_grad_seq)
    dist.all_reduce(v_grad_seq)
    dist.all_reduce(o_grad_seq)
    # gather the grad of sequence parallel attention input
    x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)]
    dist.all_gather(x_grad_seq_list, x_grad_seq)
    x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1)

    # backward result check
    assert_close(q_grad_seq, q_grad)
    assert_close(k_grad_seq, k_grad)
    assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4)
    assert_close(o_grad_seq, o_grad)
    assert_close(x_grad_seq_gather, x_grad)
@parameterize("seq_len", [128])
@parameterize("hidden_dim", [64])
@parameterize("head_num", [4])
@parameterize("batch_size", [1])
def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size)
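

# check_all2all_attn is the per-process entry point: spawn() launches 4 ranks, so with
# seq_len=128 each rank holds a 32-token slice of the sequence and, with head_num=4, one head.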
def check_all2all_attn(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_seq_parallel_attn()


@rerun_if_address_is_in_use()
def test_all_to_all_attention():
    spawn(check_all2all_attn, nprocs=4)


if __name__ == "__main__":
    test_all_to_all_attention()