diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index a891ee7df..eba156a05 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -149,7 +149,14 @@ class Top1Router(MoeRouter): low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) ).rsample - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_loss: bool = False, + use_norm: bool = False, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 721a4796a..1d8e794e7 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -2,12 +2,21 @@ import torch import torch.distributed as dist import torch.nn as nn +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size +from tests.test_moe.moe_utils import MoeModel + + +def delete_moe_info(model): + for _, param in model.named_parameters(): + if hasattr(param, "moe_info"): + delattr(param, "moe_info") class MoeModel(nn.Module): @@ -85,6 +94,58 @@ def assert_not_equal_in_group(tensor, process_group=None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert not torch.allclose(a, b), \ - (f"expected tensors on rank {i} and {i + 1} not to be equal " - f"but they are, {a} vs {b}") + assert not torch.allclose(a, b), ( + f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" + ) + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + assert local_name == ep_name, print(f"{local_name} != {ep_name}") + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 5d9ec1cff..5a7fb5f2c 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -1,85 +1,13 @@ import pytest import torch -import torch.distributed as dist import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeModel - - -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def delete_moe_info(model): - for name, param in model.named_parameters(): - if hasattr(param, "moe_info"): - delattr(param, "moe_info") - - -def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - local_model (MoeModule) - ep_model (MoeModule) - """ - for (local_name, local_param), (ep_name, ep_param) in zip( - local_model.named_parameters(), ep_model.named_parameters() - ): - assert local_name == ep_name, print(f"{local_name} != {ep_name}") - if "experts" not in local_name: - if assert_grad_flag: - assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" - assert torch.allclose(local_param.grad, ep_param.grad) - else: - local_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - if assert_grad_flag: - assert torch.allclose(local_param, all_param) - assert torch.allclose(local_param.grad, all_grad) - else: - local_param.data.copy_(all_param.data) +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep def run_zero_test(local_rank, stage=1):