mirror of https://github.com/hpcaitech/ColossalAI

import os

import pytest
import torch

from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum

NUM_EXPERTS = 4
BATCH_SIZE = 4
SEQ_LEN = 4

MOE_TENSOR_PATH = os.getenv("MOE_TENSOR_PATH")
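
# The tests below load pre-generated routing tensors from MOE_TENSOR_PATH.
# This guard is an addition, not part of the original file: it skips the
# module cleanly when the variable is unset instead of failing in torch.load.
pytestmark = pytest.mark.skipif(MOE_TENSOR_PATH is None, reason="MOE_TENSOR_PATH is not set")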


def check_equal(tensor_a, tensor_b, atol=1e-06):
    # Absolute-tolerance comparison (rtol=0); torch.allclose already returns
    # a bool, so assert it directly.
    assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol)


def run_moe_cumsum():
    # One-hot routing mask: each row is a token, each column an expert.
    test_mask = torch.tensor(
        [
            [0, 1, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [1, 0, 0, 0],
        ],
        dtype=torch.int32,
    ).to("cuda")
    # The fused kernel path must agree with the plain PyTorch path.
    out_no_kernel = moe_cumsum(test_mask, use_kernel=False)
    out_kernel = moe_cumsum(test_mask, use_kernel=True)
    print(out_no_kernel.dtype, out_kernel.dtype)
    check_equal(out_no_kernel.to(torch.int32), out_kernel)


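# Check that the fused dispatch/combine kernels (MoeDispatch / MoeCombine)
# produce the same forward outputs and token gradients as a plain matmul
# formulation, using routing tensors pre-generated under MOE_TENSOR_PATH.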
def run_moe_dispatch_combine_fwd_bwd(data_type=torch.float32, hidden_size=128, num_experts=4):
    # A batch of token embeddings to route, with gradients tracked.
    tokens = torch.randn(
        BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
    )

    # use kernel: load pre-generated routing results for the fused kernels
    route_result_list_kernel = torch.load(f"{MOE_TENSOR_PATH}/True_4_{data_type}.pt")
    # dispatch
    dispatch_data_kernel = MoeDispatch.apply(tokens, *route_result_list_kernel[1:])
    dispatch_data_kernel = dispatch_data_kernel.reshape(num_experts, -1, hidden_size)
    # combine
    expert_output = dispatch_data_kernel.reshape(-1, hidden_size)
    ans_kernel = MoeCombine.apply(expert_output, *route_result_list_kernel)

    # no kernel: the same routing expressed with dense matmuls as a reference
    route_result_list_no_kernel = torch.load(f"{MOE_TENSOR_PATH}/False_2_{data_type}.pt")
    # dispatch: gather tokens per expert via the permuted section mask
    sec_mask_f = route_result_list_no_kernel[1].type_as(tokens)
    dispatch_data_no_kernel = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
    # combine: per-token weighted sum of expert outputs
    combine_weights = route_result_list_no_kernel[0].type_as(tokens)
    combine_weights = combine_weights.view(combine_weights.shape[0], -1)
    expert_output = expert_output.view(-1, expert_output.shape[-1])
    ans_no_kernel = torch.matmul(combine_weights, expert_output)

    # check fwd: fp16 needs a looser tolerance than fp32
    atol = 1e-06 if data_type == torch.float32 else 1e-2
    check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel, atol)
    check_equal(ans_kernel, ans_no_kernel, atol)

    # check bwd: gradients w.r.t. the tokens must match as well
    out_shape = ans_kernel.shape
    grad = torch.randn(out_shape, device=get_accelerator().get_current_device())

    ans_kernel.backward(grad, retain_graph=True)
    grad_kernel = tokens.grad.data.clone()
    tokens.grad.zero_()

    ans_no_kernel.backward(grad)  # get gradient
    grad_no_kernel = tokens.grad.data.clone()
    tokens.grad.zero_()

    check_equal(grad_no_kernel, grad_kernel, atol)


@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_kernel(data_type):
    torch.manual_seed(1024)
    run_moe_cumsum()
    run_moe_dispatch_combine_fwd_bwd(data_type=data_type)
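

if __name__ == "__main__":
    # Convenience entry point for direct runs; an addition, not part of the
    # original test. Assumes a CUDA device and a valid MOE_TENSOR_PATH.
    test_moe_kernel(torch.float32)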