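"""Regression test for Elixir's attention kernel wrapper.

`wrap_attention` is expected to swap a model's attention modules for
xformers-backed kernels; a wrapped model must reproduce the reference
model's forward output and parameter gradients.
"""
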
from copy import deepcopy

import pytest
from torch.testing import assert_close

from tests.test_elixir.utils import TEST_MODELS, to_cuda


def exam_one_model(model_fn, data_fn):
    # imported lazily: xformers is an optional dependency, so the import
    # only has to succeed when this test actually runs
    from colossalai.elixir.kernels.attn_wrapper import wrap_attention

    # build a reference model and an identical copy whose attention
    # modules are replaced by the wrapped kernels
    torch_model = model_fn().cuda()
    test_model = deepcopy(torch_model)
    test_model = wrap_attention(test_model)

    # run forward and backward on the same batch with both models
    data = to_cuda(data_fn())
    torch_out = torch_model(**data)
    torch_out.backward()

    test_out = test_model(**data)
    test_out.backward()

    # the wrapper must not change the forward output or the gradient
    # of any parameter
    assert_close(torch_out, test_out)
    for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()):
        assert_close(p_torch.grad, p_test.grad, msg=f'gradient mismatch for parameter {name}')


|
@pytest.mark.skip(reason="Need to install xformers")
def test_gpt_atten_kernel():
    exam_one_model(*TEST_MODELS.get('gpt2_micro'))
    exam_one_model(*TEST_MODELS.get('opt_micro'))


if __name__ == '__main__':
    test_gpt_atten_kernel()