diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 0b0d50e28..56b731d13 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -30,6 +30,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
+        force_overlap_comm: bool,  # force overlap comm
         dp_process_group: ProcessGroup,  # dp pg for comm
         moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
@@ -48,7 +49,16 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
-    ):
+    ):
+
+        WARN_STR = "Note that you need to make sure that every expert is routed (i.e., every expert receives a gradient in backward); otherwise this may lead to a program hang or inconsistent results."
+        if not force_overlap_comm and (overlap_communication or partition_grad):
+            raise RuntimeError(WARN_STR + " If you are not sure about this, set overlap_communication=False and partition_grad=False, or set force_overlap_comm=True.")
+
+        if force_overlap_comm:
+            overlap_communication = True
+            warnings.warn(WARN_STR + " Please make sure this holds for your model.")
+
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
@@ -88,7 +98,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     TODO: add docstring
     """

-    def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None:
+    def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0

@@ -120,6 +130,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         # TODO do it in a better way
         self.shard_config.ep_group = self.ep_group

+        self.force_overlap_comm = force_overlap_comm
+
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
@@ -168,11 +180,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
             )
         else:
-            assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+            if not (self.dp_size > 1 or self.moe_dp_size > 1):
+                warnings.warn(
+                    "Using the ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
+                    "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
+ ) optimizer = MoeHybridParallelZeroOptimizer( optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + force_overlap_comm=self.force_overlap_comm, param_info=param_info, dp_process_group=self.dp_group, moe_dp_group=self.moe_dp_group, diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 0d0a606c0..78c34046a 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -110,12 +110,8 @@ class BucketStore(BaseStore): flat_grad = [] for grad_list in self._grad_in_bucket.values(): - if len(grad_list) > 0: - flat_grad.append(_flatten_dense_tensors(grad_list)) - if len(flat_grad) > 0: - flat_grad = _flatten_dense_tensors(flat_grad) - else: - flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype) + flat_grad.append(_flatten_dense_tensors(grad_list)) + flat_grad = _flatten_dense_tensors(flat_grad) return flat_grad def get_param_id_of_grad(self, grad: Tensor) -> int: diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index a13fa120a..b84be034a 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -19,7 +19,6 @@ class GradientStore(BaseStore): """ self._grads_of_params = dict() # stage 2 - self._partition_grads = partition_grad self._working_index = 0 if partition_grad else self._local_rank # for zero2, it's `param_id: [grad_local_rank]` self.grad_to_param_mapping = dict() diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 318a50abe..8bad6ebec 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -648,7 +648,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for group_id in range(self.num_param_groups): param_group = self._working_param_groups[group_id] for param in param_group: - if param.requires_grad and param.grad is not None: + if param.requires_grad: + if param.grad is None: + # for moe params, all experts should have gradient + # TODO better way of doing this + param.grad = torch.zeros_like(param) self._add_to_bucket(param, group_id) self._run_reduction() diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index ba6a0e8a9..b7332a937 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -137,7 +137,7 @@ def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> local_param.data.copy_(all_param.data) -def loose_close(a, b, dtype: torch.dtype = torch.float32): +def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""): rtol = None atol = None if dtype is torch.float16: @@ -150,4 +150,4 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): a = a.detach().to(dtype) b = b.detach().to(dtype).to(a.device) - assert_close(a, b, rtol=rtol, atol=atol) + assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}" diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 9bc11033a..24fc0a0eb 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -1,238 +1,134 @@ -import os -import warnings -from typing import Dict +from copy import deepcopy import pytest import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from 
transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import loose_close -# from colossalai.shardformer.layer import SparseMLP -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from tests.test_moe.moe_utils import MoeGradientHandler +NUM_BATCH=4 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS=2 +TOP_K = 2 -def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from local model - - Args: - tp_model (MoeModule) - local_model (MoeModule) - """ - for (tp_name, tp_param), (local_name, local_param) in zip( - tp_model.named_parameters(), local_model.named_parameters() - ): - assert tp_name == local_name - if not is_moe_tensor(tp_param): - if assert_grad_flag: - assert torch.allclose(tp_param, local_param) - assert torch.allclose(tp_param.grad, local_param.grad) - else: - tp_param.data.copy_(local_param.data) - continue - - tp_rank = get_ep_rank(tp_param) - tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0] - tp_slice = [slice(None)] * tp_dim + [ - slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) - ] - - if assert_grad_flag: - assert torch.allclose(tp_param, local_param[tuple(tp_slice)]) - assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)]) - else: - tp_param.data.copy_(local_param[tuple(tp_slice)].data) +def split_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad -def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model +@parameterize("stage", [1]) +@parameterize("ep_size", [1, 2, 4]) +@parameterize("tp_size", [1, 2, 4]) +def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1): + dtype = torch.bfloat16 - Args: - tp_model (MoeModule) - ep_model (MoeModule) - """ - for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): - assert tp_name == ep_name - if not is_moe_tensor(tp_param): - if assert_grad_flag: - assert torch.allclose(tp_param, ep_param) - assert torch.allclose(tp_param.grad, ep_param.grad) - else: - tp_param.data.copy_(ep_param.data) - continue + rank = torch.distributed.get_rank() + torch.cuda.set_device(dist.get_rank()) - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - 
if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) + seed_all(10086) - # get tp param - tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1 - tp_rank = get_ep_rank(tp_param) - tp_slice = [slice(None)] * tp_dim + [ - slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) - ] - new_tp_param = all_param[tuple(tp_slice)] - if assert_grad_flag: - new_grad = all_grad[tuple(tp_slice)] - if assert_grad_flag: - assert torch.allclose(tp_param, new_tp_param) - assert torch.allclose(tp_param.grad, new_grad) - else: - tp_param.data.copy_(new_tp_param.data) - - -def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - local_model (MoeModule) - ep_model (MoeModule) - """ - for (local_name, local_param), (ep_name, ep_param) in zip( - local_model.named_parameters(), ep_model.named_parameters() - ): - assert local_name == ep_name - if "experts" not in local_name: - if assert_grad_flag: - assert torch.allclose(local_param, ep_param) - assert torch.allclose(local_param.grad, ep_param.grad) - else: - local_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - if assert_grad_flag: - assert torch.allclose(local_param, all_param) - assert torch.allclose(local_param.grad, all_grad) - else: - local_param.data.copy_(all_param.data) - - -def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict): - assert batch_size % world_size == 0 - - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel=None) - local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="EP") - enable_hierarchical_comm = config.get("enable_hierarchical_comm", False) - if enable_hierarchical_comm: - os.environ["LOCAL_WORLD_SIZE"] = str(world_size) - ep_model = SparseMLP( - num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2, - enable_hierarchical_comm=enable_hierarchical_comm, + config = MixtralConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=2, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, ) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="TP") - tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - ep_model = ep_model.to(get_accelerator().get_current_device()) - tp_model = tp_model.to(get_accelerator().get_current_device()) - local_model = local_model.to(get_accelerator().get_current_device()) + torch_model = MixtralModel(config).to(dtype).cuda() - # sync ep param - sync_moe_model_param(ep_model) - dist_dict = 
MOE_MANAGER.parallel_info_dict - assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) - assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) - ep_grad_handler = MoeGradientHandler(ep_model) - # sync local param - sync_local_from_ep(local_model, ep_model) - # sync tp param - sync_tp_from_ep(tp_model, ep_model) - tp_grad_handler = MoeGradientHandler(tp_model) - - rank = dist.get_rank() - input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device()) - micro_batch_size = batch_size // world_size - index = rank * micro_batch_size - # NOTE: ep & tp takes in sharded data for each process - shard_data = input_data.detach()[index : index + micro_batch_size] - - out_local = local_model(input_data) - MOE_MANAGER.reset_loss() - out_tp = tp_model(shard_data) - MOE_MANAGER.reset_loss() - out_ep = ep_model(shard_data) - MOE_MANAGER.reset_loss() - - assert torch.allclose( - out_tp, out_ep, atol=1e-6 - ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" - try: - out_local_slice = out_local[index : index + micro_batch_size] - assert torch.allclose( - out_ep, out_local_slice, atol=1e-6 - ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" - except AssertionError: - """ - e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 - router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 - However, in ep mode, there are 2 separate routers dealing with sharded data. - Assume router 0 handles token [01] and router 1 handles token [23]. - Note that for each router the capacity is only 1 !!! - Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both. - The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. - """ - warnings.warn( - "EP & TP may result in different behavior from local model. " "Please check the comments for details." + zero_model = deepcopy(torch_model).to(dtype) + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + booster = Booster( + plugin=MoeHybridParallelPlugin( + tp_size=tp_size, + pp_size=1, + ep_size=ep_size, + zero_stage=stage, + overlap_communication=False, + initial_scale=1 ) + ) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) - out_local.mean().backward() - out_tp.mean().backward() - tp_grad_handler.handle_gradient() - out_ep.mean().backward() - ep_grad_handler.handle_gradient() - - assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) - assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) - sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) - try: - sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - except AssertionError: - warnings.warn( - "EP & TP may result in different behavior from local model. " "Please check the comments for details." 
+ booster = Booster( + plugin=HybridParallelPlugin( + tp_size=tp_size, + pp_size=1, + zero_stage=stage, + overlap_communication=False, + initial_scale=1, ) + ) + hybrid_model, hybrid_optimizer, _, _, _ = booster.boost(torch_model, torch.optim.SGD(torch_model.parameters(), lr=1)) + + # create different input + seed_all(1453 + rank) + + hybrid_model.train() + zero_model.train() + for _ in range(2): + # zero-dp forward + input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda() + zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + # zero-dp backward + zero_optimizer.backward(zero_output) + # torch-ddp forward + hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + loose_close(zero_output, hybrid_output, dtype=dtype) + # torch-ddp backward + hybrid_optimizer.backward(hybrid_output) + + # check grad + name_to_p = {n: p for n, p in hybrid_model.named_parameters()} + for n, p in zero_model.named_parameters(): + zero_grad = zero_optimizer.get_param_grad(p) + if name_to_p[n].grad is None: + name_to_p[n].grad = torch.zeros_like(name_to_p[n]) + continue + loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n) + + # zero-dp step + zero_optimizer.step() + + # original model step + hybrid_optimizer.step() + + # check updated param + for n, p in zero_model.named_parameters(): + loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n) + + print(f"{dist.get_rank()} test passed") + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_zero_with_original_model() -@pytest.mark.skip(reason="moe need to be refactored") @pytest.mark.dist -@pytest.mark.parametrize("num_experts", [4, 64]) -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("dim", [64]) -@pytest.mark.parametrize( - "config", - [ - {"enable_hierarchical_comm": False}, - {"enable_hierarchical_comm": True}, - ], -) +@pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict): - spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config) +def test_moe_ep_tp(world_size): + spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_ep_tp(num_experts=8, batch_size=32, dim=32) + test_moe_ep_tp(world_size=4) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py index 2e6d0d786..3d6af2b1a 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py @@ -5,20 +5,20 @@ import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from transformers.models.mixtral.configuration_mixtral import MixtralConfig -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai +from colossalai.booster.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock -from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from colossalai.zero import LowLevelZeroOptimizer from 
tests.test_moe.moe_utils import loose_close -tokens, n_experts = 7, 4 -hidden_size = 8 -top_k = 2 +NUM_BATCH=4 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS=2 +TOP_K = 2 def split_grad(grad, world_size): @@ -31,94 +31,87 @@ def split_grad(grad, world_size): return splited_grad -@parameterize("stage", [1, 2]) +@parameterize("stage", [1]) @parameterize("ep_size", [1, 2, 4]) def run_zero_with_original_model(stage: int, ep_size: int): - dtype = torch.float16 + dtype = torch.bfloat16 rank = torch.distributed.get_rank() torch.cuda.set_device(dist.get_rank()) + plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, + tp_size=1, ep_size=ep_size, + zero_stage=stage, + overlap_communication=False, + initial_scale=1 ) + booster = Booster(plugin=plugin) seed_all(10086) + config = MixtralConfig( - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_local_experts=n_experts, - num_experts_per_tok=top_k, + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=2, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, ) - orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda() + torch_model = MixtralModel(config).to(dtype).cuda() - ori_model = DDP( - orig_model.cuda(), + zero_model = deepcopy(torch_model).to(dtype) + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + ddp_model = DDP( + torch_model.cuda(), process_group=plugin.dp_group, find_unused_parameters=True, # important for torch ddp, not all experts are routed ).cuda() + ddp_optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1) - zero_model = deepcopy(orig_model).to(dtype) - zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group) - - zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) - pg_param_list = {plugin.dp_group: [], plugin.moe_dp_group: []} - for p in zero_model.parameters(): - if is_moe_tensor(p): - pg_param_list[plugin.moe_dp_group].append(p) - else: - pg_param_list[plugin.dp_group].append(p) - - zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, - pg_to_param_list=pg_param_list, - master_weights=False, - initial_scale=1, - overlap_communication=True, - partition_grad=stage == 2, - ) - - ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1) - - # create + # create different input seed_all(1453 + rank) + ddp_model.train() + zero_model.train() for _ in range(2): # zero-dp forward - input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() - zero_output, _ = zero_model(input_data.to(dtype)) + input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda() + zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + # zero-dp backward + zero_optimizer.backward(zero_output) # torch-ddp forward - ori_output, _ = ori_model(input_data.to(dtype)) - loose_close(zero_output, ori_output, dtype=dtype) - - # zero-dp backward - zero_optimizer.backward(zero_output.mean().float()) - + ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + loose_close(zero_output, ddp_output, dtype=dtype) # torch-ddp backward - ori_output.mean().backward() + ddp_output.backward() # check grad - name_to_p = {n: p for n, p in ori_model.module.named_parameters()} + name_to_p = {n: p for n, p 
in ddp_model.named_parameters()} for n, p in zero_model.named_parameters(): + print(f"rank {dist.get_rank()} {n}") zero_grad = zero_optimizer.get_param_grad(p) if name_to_p[n].grad is None: - assert zero_grad is None + name_to_p[n].grad = torch.zeros_like(name_to_p[n].data) continue - - loose_close(zero_grad, name_to_p[n].grad, dtype=dtype) + loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n) # zero-dp step zero_optimizer.step() # original model step - ori_optimizer.step() + ddp_optimizer.step() # check updated param for n, p in zero_model.named_parameters(): - loose_close(p.data, name_to_p[n].data, dtype=dtype) + loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n) print(f"{dist.get_rank()} test passed") @@ -131,9 +124,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): +def test_moe_ep_tp(world_size): spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_zero_model(world_size=4) + test_moe_ep_tp(world_size=4) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 70b576908..4e9d3878b 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -113,65 +113,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 2, - "pp_size": 1, - "ep_size": 1, - "zero_stage": 2, - "precision": "fp32", - }, # [dp(2) + tp(2)] + [moe_dp(4)] - { - "tp_size": 2, - "pp_size": 1, - "ep_size": 2, - "zero_stage": 2, - "precision": "fp32", - }, # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)] { "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, + "pp_size": 1, "ep_size": 1, "zero_stage": 2, "precision": "fp32", }, # [dp(2) + pp(2)] + [moe_dp(4)] - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "ep_size": 1, - "zero_stage": 2, - "precision": "fp32", - }, # [dp(2) + pp(2)] + [moe_dp(4)] - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "ep_size": 4, - "zero_stage": 2, - "precision": "fp32", - }, # [dp(2) + pp(2)] + [ep(4))] - { - "tp_size": 1, - "pp_size": 1, - "ep_size": 2, - "zero_stage": 2, - "precision": "fp32", - }, # [dp(4)] + [ep(2) + moe_tp(2)] - { - "tp_size": 1, - "pp_size": 1, - "ep_size": 4, - "zero_stage": 2, - "precision": "fp32" - }, # full dp for non-moe and full ep for moe - { - "tp_size": 1, - "pp_size": 1, - "ep_size": 1, - "zero_stage": 2, - "precision": "fp32" - }, # full dp for moe and non-moe + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "ep_size": 1, + # "zero_stage": 1, + # "precision": "fp32", + # }, # [dp(2) + pp(2)] + [moe_dp(4)] + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "ep_size": 4, + # "zero_stage": 1, + # "precision": "fp32", + # }, # [dp(2) + pp(2)] + [ep(4))] + # { + # "tp_size": 1, + # "pp_size": 1, + # "ep_size": 2, + # "zero_stage": 0, + # "precision": "fp32", + # }, # [dp(4)] + [ep(2) + moe_tp(2)] + # { + # "tp_size": 1, + # "pp_size": 1, + # "ep_size": 4, + # "zero_stage": 0, + # "precision": "fp32" + # }, # full dp for non-moe and full ep for moe ], ) def run_mixtral_test(test_config):
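
Note (not part of the patch): a minimal usage sketch of the new `force_overlap_comm` flag introduced in `MoeHybridParallelPlugin`. It assumes a `torchrun` launch on 4 GPUs and a tiny Mixtral config mirroring the tests above; the sizes, learning rate, and `ep_size` are illustrative assumptions, not prescriptions.

```python
# Hedged sketch: how the new force_overlap_comm flag is expected to be used.
import torch
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# Assumes `torchrun --nproc_per_node=4 this_script.py`; older ColossalAI
# versions require a config dict argument here.
colossalai.launch_from_torch()

config = MixtralConfig(
    hidden_size=8,
    intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
)
model = MixtralModel(config).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    precision="fp16",
    initial_scale=1,
    # force_overlap_comm=True keeps communication overlap enabled for the MoE
    # ZeRO optimizer; per the warning added in this patch, this is only safe
    # when every expert is routed (receives gradients) in each backward pass.
    force_overlap_comm=True,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# One tiny training step with the boosted objects (random embeddings as input).
inputs = torch.rand(4, 7, config.hidden_size).cuda()
loss = model(inputs_embeds=inputs.to(torch.float16)).last_hidden_state.mean()
optimizer.backward(loss)
optimizer.step()
```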