ColossalAI/tests/test_elixir/test_kernels/test_attn.py

from copy import deepcopy

import pytest
from torch.testing import assert_close

from tests.test_elixir.utils import TEST_MODELS, to_cuda


def exam_one_model(model_fn, data_fn):
    # Imported inside the function so merely collecting this test module
    # does not require xformers to be installed.
    from colossalai.elixir.kernels.attn_wrapper import wrap_attention

    # Build a baseline model and a deep copy whose attention layers are
    # swapped for the wrapped attention kernel.
    torch_model = model_fn().cuda()
    test_model = deepcopy(torch_model)
    test_model = wrap_attention(test_model)

    data = to_cuda(data_fn())

    # Run forward and backward on both models with identical inputs.
    torch_out = torch_model(**data)
    torch_out.backward()
    test_out = test_model(**data)
    test_out.backward()

    # The wrapped model must reproduce both the output and every
    # parameter gradient of the baseline.
    assert_close(torch_out, test_out)
    for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()):
        assert_close(p_torch.grad, p_test.grad, msg=f'gradient mismatch in {name}')


@pytest.mark.skip(reason="Need to install xformers")
def test_gpt_atten_kernel():
    exam_one_model(*TEST_MODELS.get('gpt2_micro'))
    exam_one_model(*TEST_MODELS.get('opt_micro'))


if __name__ == '__main__':
    test_gpt_atten_kernel()
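

# Usage note (a sketch, assuming xformers is installed and a CUDA device is
# available): pytest honors the skip marker above, so the quickest way to run
# the comparison is to execute the file directly,
#
#   python tests/test_elixir/test_kernels/test_attn.py
#
# which invokes test_gpt_atten_kernel() via the __main__ block; marks attached
# by @pytest.mark.skip are inert outside a pytest session.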