# ColossalAI/tests/test_elixir/test_kernels/test_ln.py
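"""Check Elixir's fused LayerNorm kernel against a plain ``torch.nn.LayerNorm``:
the forward losses and the parameter gradients of the two models must match.

The test spawns one process per rank with torch.multiprocessing and is skipped
by default because the fused kernel requires apex to be installed.
"""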
import os
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed
from colossalai.elixir.wrapper import ElixirModule


def exam_fused_layernorm(nproc, group):
    # build a reference LayerNorm and a deep copy that will use the fused kernel
    torch_model = nn.LayerNorm(2048)
    fused_model = deepcopy(torch_model)
    torch_model = torch_model.cuda()

    # search for a chunk plan, then wrap the copy with fused kernels enabled
    sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True)
    fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True)

    data = torch.randn(2, 2048, device='cuda')

    # run forward and backward on both models with the same input
    torch_loss = torch_model(data).sum()
    torch_loss.backward()
    fused_loss = fused_model(data).sum()
    fused_model.backward(fused_loss)
    assert_close(torch_loss, fused_loss)

    # compare parameter gradients, which the fused model exposes
    # via state_dict(from_param=True)
    grad_state = fused_model.state_dict(from_param=True)
    for name, param in torch_model.named_parameters():
        assert_close(param.grad.cpu(), grad_state[name])


def run_dist(rank, world_size):
    # set up the environment variables expected by init_distributed
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@pytest.mark.skip(reason='apex needs to be installed')
def test_fused_layernorm(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # note: calling the test function directly bypasses the pytest skip marker,
    # so apex must be available when running this file as a script
    test_fused_layernorm(world_size=1)
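
# Under pytest, the test can be selected via its marker (assuming the 'dist'
# marker is registered in the project's pytest configuration), e.g.:
#   pytest -m dist tests/test_elixir/test_kernels/test_ln.py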