ColossalAI/tests/test_moe/test_kernel.py

import os

import pytest
import torch

from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum

NUM_EXPERTS = 4
BATCH_SIZE = 4
SEQ_LEN = 4

MOE_TENSOR_PATH = os.getenv("MOE_TENSOR_PATH")
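# MOE_TENSOR_PATH should point at a directory of pre-dumped routing tensors;
# the torch.load calls below build filenames like f"True_4_{data_type}.pt"
# (the leading fields appear to encode the use_kernel flag plus a routing
# variant -- inferred from usage, not documented here). If the variable is
# unset, the f-strings resolve to "None/..." and the loads raise FileNotFoundError.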


def check_equal(tensor_a, tensor_b, atol=1e-06):
    assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol)


def run_moe_cumsum():
    test_mask = torch.tensor(
        [
            [0, 1, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [1, 0, 0, 0],
        ],
        dtype=torch.int32,
    ).to("cuda")
    out_no_kernel = moe_cumsum(test_mask, use_kernel=False)
    out_kernel = moe_cumsum(test_mask, use_kernel=True)
    print(out_no_kernel.dtype, out_kernel.dtype)
    check_equal(out_no_kernel.to(torch.int32), out_kernel)


def run_moe_dispatch_combine_fwd_bwd(data_type=torch.float32, hidden_size=128, num_experts=4):
    tokens = torch.randn(
        BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
    )
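    # tokens: (BATCH_SIZE, hidden_size), created with requires_grad=True so the
    # backward check at the end can compare input gradients of both paths.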

    # use kernel
    route_result_list_kernel = torch.load(f"{MOE_TENSOR_PATH}/True_4_{data_type}.pt")
    # dispatch
    dispatch_data_kernel = MoeDispatch.apply(tokens, *route_result_list_kernel[1:])
    dispatch_data_kernel = dispatch_data_kernel.reshape(num_experts, -1, hidden_size)
    # combine
    expert_output = dispatch_data_kernel.reshape(-1, hidden_size)
    ans_kernel = MoeCombine.apply(expert_output, *route_result_list_kernel)

    # no kernel
    route_result_list_no_kernel = torch.load(f"{MOE_TENSOR_PATH}/False_2_{data_type}.pt")
    # dispatch
    sec_mask_f = route_result_list_no_kernel[1].type_as(tokens)
    dispatch_data_no_kernel = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
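    # Assuming sec_mask_f is (tokens, num_experts, capacity), the permute gives
    # (num_experts, capacity, tokens), so this matmul scatters each token row of
    # `tokens` into its (expert, slot) position, producing the same
    # (num_experts, capacity, hidden_size) layout as the kernel dispatch above.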
    # combine
    combine_weights = route_result_list_no_kernel[0].type_as(tokens)
    combine_weights = combine_weights.view(combine_weights.shape[0], -1)
    expert_output = expert_output.view(-1, expert_output.shape[-1])
    ans_no_kernel = torch.matmul(combine_weights, expert_output)
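    # The reference combine is a dense weighted sum: combine_weights flattened
    # to (tokens, num_experts * capacity) @ expert_output viewed as
    # (num_experts * capacity, hidden_size) -> (tokens, hidden_size), which the
    # fused MoeCombine kernel should reproduce without materializing the mask.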

    # check fwd (looser tolerance for half precision)
    fwd_atol = 1e-06 if data_type == torch.float32 else 1e-2
    check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel, fwd_atol)
    check_equal(ans_kernel, ans_no_kernel, fwd_atol)

    # check bwd
    out_shape = ans_kernel.shape
    # the upstream gradient must match the output dtype for autograd to accept it
    grad = torch.randn(out_shape, dtype=data_type, device=get_accelerator().get_current_device())
    ans_kernel.backward(grad, retain_graph=True)
    grad_kernel = tokens.grad.clone()
    tokens.grad.zero_()

    ans_no_kernel.backward(grad)  # get gradient
    grad_no_kernel = tokens.grad.clone()
    tokens.grad.zero_()

    bwd_atol = 1e-06 if data_type == torch.float32 else 1e-2
    check_equal(grad_no_kernel, grad_kernel, bwd_atol)


@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_kernel(data_type):
    torch.manual_seed(1024)
    run_moe_cumsum()
    run_moe_dispatch_combine_fwd_bwd(data_type=data_type)
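

if __name__ == "__main__":
    # Convenience entry point (an assumption for local runs, not shown in the
    # snippet above): requires a CUDA device and a valid MOE_TENSOR_PATH.
    test_moe_kernel(torch.float32)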