# mirror of https://github.com/hpcaitech/ColossalAI
import os
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed
from colossalai.elixir.wrapper import ElixirModule
def exam_fused_layernorm(nproc, group):
    """Check an Elixir-wrapped LayerNorm with fused kernels against plain torch.

    Builds a reference ``nn.LayerNorm`` and a deep copy wrapped in
    ``ElixirModule`` (with ``use_fused_kernels=True``), runs one
    forward/backward pass on the same random CUDA input, and asserts that
    the losses and the per-parameter gradients agree.

    Args:
        nproc: number of processes assumed by the chunk search plan.
        group: the ``torch.distributed`` process group the wrapper shards over.
    """
    torch_model = nn.LayerNorm(2048)
    # Copy BEFORE moving/wrapping so both models start from identical weights.
    fused_model = deepcopy(torch_model)

    torch_model = torch_model.cuda()
    # Build a chunk/sharding plan for the model; verbose aids debugging.
    sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True)
    fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True)

    data = torch.randn(2, 2048, device='cuda')

    torch_loss = torch_model(data).sum()
    torch_loss.backward()

    fused_loss = fused_model(data).sum()
    # ElixirModule drives its own backward so gradients land in its chunks.
    fused_model.backward(fused_loss)

    assert_close(torch_loss, fused_loss)

    # state_dict(from_param=True) exposes gradients keyed by parameter name.
    grad_state = fused_model.state_dict(from_param=True)
    for name, param in torch_model.named_parameters():
        assert_close(param.grad.cpu(), grad_state[name])
|
def run_dist(rank, world_size):
    """Per-process entry point for ``torch.multiprocessing.spawn``.

    Populates the environment variables expected by the distributed
    launcher, initializes the process group, then runs the fused
    LayerNorm check on the WORLD group.

    Args:
        rank: this process's rank (supplied by ``spawn``).
        world_size: total number of spawned processes.
    """
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD)
|
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@pytest.mark.skip(reason='need to install apex')
def test_fused_layernorm(world_size):
    """Spawn ``world_size`` processes, each running the fused LayerNorm check.

    Skipped by default because the fused kernels require apex to be installed.
    """
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
# Allow running the distributed test directly without pytest.
if __name__ == '__main__':
    test_fused_layernorm(world_size=1)