You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/tests/test_layers/test_sequence/test_sequence.py

139 lines
4.6 KiB

import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
def check_ring_qk(rank, world_size):
# params
batch_size = 4
num_heads = 4
seq_length = 32
attention_head_size = 32
sub_seq_length = seq_length // world_size
# create master tensors
q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
# set autograd attributes
q.requires_grad = True
k.requires_grad = True
q.retain_grad()
k.retain_grad()
sub_q.requires_grad = True
sub_k.requires_grad = True
sub_q.retain_grad()
sub_k.retain_grad()
# compute master attention scores
a = torch.matmul(q, k.transpose(2, 1))
# compute distributed attention scores
ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
# check master and distributed attention scores
sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
# run master backward
a.retain_grad()
a.mean().backward()
# run distributed backward
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_a, partial_master_a_grad)
# check master and distributed grads
partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
'attention score cannot match'
def check_ring_av(rank, world_size):
# params
batch_size = 4
num_heads = 4
seq_length = 16
attention_head_size = 32
sub_seq_length = seq_length // world_size
# create master tensors
a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
# set autograd attributes
a.requires_grad = True
v.requires_grad = True
a.retain_grad()
v.retain_grad()
sub_a.requires_grad = True
sub_v.requires_grad = True
sub_a.retain_grad()
sub_v.retain_grad()
# compute master attention scores
out = torch.matmul(a, v)
# compute distributed attention scores
ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
# check master and distributed output
sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
# # run master backward
out.retain_grad()
out.mean().backward()
# # run distributed backward
partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_out, partial_master_out_grad)
# # check master and distributed grads
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
'attention output cannot match'
def run_test(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port)
# check_ring_qk(rank, world_size)
check_ring_av(rank, world_size)
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_sequence():
spawn(run_test, 4)
if __name__ == '__main__':
test_sequence()