from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env

BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))


def check_equal(A, B, atol=1e-06):
    """Assert that ``A`` and ``B`` agree elementwise within absolute tolerance ``atol``.

    ``rtol`` is fixed at 0, so only the absolute difference is checked. On failure
    the assertion message reports the largest absolute difference.
    """
    assert torch.allclose(A, B, rtol=0, atol=atol), (
        f"max abs diff {torch.max(torch.abs(A - B)).item()} exceeds atol={atol}"
    )


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    """Compare MoeLayer's ``cuda_mode=False`` and ``cuda_mode=True`` paths on one worker.

    The same tokens and upstream gradient are pushed through the layer twice, once
    per mode. The forward outputs, the token gradients and the gate weight
    gradients are then checked for agreement within dtype-dependent tolerances.

    Args:
        rank: Global rank of this worker (passed by ``mp.spawn``).
        world_size: Total number of worker processes.
        port: Port used for the localhost rendezvous.
        rs: Base random seed; each worker uses ``rs + local_rank``.
        hidden_size: Token feature dimension.
        data_type: Token dtype; ``torch.float16`` also converts the layer to half.
    """
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    # Keep fp32 matmuls out of reduced-precision TF32 so the tight fp32
    # tolerances below compare like with like.
    torch.backends.cuda.matmul.allow_tf32 = False
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    # Per-rank seed so each worker draws different tokens.
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()

    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type,
                         device=get_current_device(), requires_grad=True)
    router = Top2Router(1)  # NOTE(review): meaning of the positional arg not visible here
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
    if data_type == torch.float16:
        layer = layer.half()

    # Reference pass with the non-CUDA code path.
    layer.cuda_mode = False
    old_out = layer(tokens)
    grad = torch.randn(old_out.shape, device=get_current_device())
    old_out.backward(grad)
    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    # Clear gradients accumulated by the reference pass before the second pass.
    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    is_fp32 = data_type == torch.float32

    # Pass under test with the CUDA code path, same inputs and upstream gradient.
    layer.cuda_mode = True
    new_out = layer(tokens)
    check_equal(old_out, new_out, 1e-06 if is_fp32 else 1e-2)

    new_out.backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

    # Both dtype branches must compare the reference gradient with the new one.
    check_equal(o_tk_grad, n_tk_grad, 1e-06 if is_fp32 else 1e-2)
    check_equal(o_gt_grad, n_gt_grad, 5e-05 if is_fp32 else 2e-01)


@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
    """Spawn four workers that each run ``run_routing`` with the given parameters."""
    world_size = 4
    run_func = partial(run_routing, world_size=world_size, port=free_port(),
                       rs=rs, hidden_size=hidden_size, data_type=data_type)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(2, 256, torch.float16)