diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 74d35f5c5..2324a5239 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -64,13 +64,18 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer): forced_dtype: Optional[torch.dtype] = None, overlap_allgather: bool = False, ): - pg_param_list = { - dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), - moe_dp_group: list(filter(is_moe_tensor, model.parameters())), - } + if dp_process_group is moe_dp_group: + pg_param_list = { + dp_process_group: list(model.parameters()), + } + else: + pg_param_list = { + dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), + moe_dp_group: list(filter(is_moe_tensor, model.parameters())), + } - if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0: - raise ValueError("No parameters found in dp_process_group or moe_dp_group") + if len(pg_param_list[moe_dp_group]) == 0: + raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead") super().__init__( model=model, @@ -407,6 +412,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all" ) + + # sync gradients across DP * SP ranks + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) + else: + dp_group = self.dp_group + if use_ddp: self.logger.warning( f"Will have to check all params are used in pytorch DDP since not all experts are always activated", @@ -414,17 +426,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) self.ddp_config["find_unused_parameters"] = True - if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): + if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group): raise ValueError( - f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0" + f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this" ) - # sync gradients across DP * SP ranks - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) - else: - dp_group = self.dp_group - model = HybridParallelModule( module=model, precision=self.precision, @@ -466,6 +472,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): tp_process_group=self.tp_group, ) else: + is_zero = True if self.dp_size <= 1: self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. 
" diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index ba087a03b..62904d90e 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -308,7 +308,7 @@ class EPGradScalerIn(torch.autograd.Function): assert len(grad_outputs) == 1 grad = grad_outputs[0] if ctx.ep_size != 1: - grad = grad * ctx.ep_size + grad.mul_(ctx.ep_size) return grad, None @@ -328,7 +328,7 @@ class EPGradScalerOut(torch.autograd.Function): assert len(grad_outputs) == 1 grad = grad_outputs[0] if ctx.ep_size != 1: - grad = grad / ctx.ep_size + grad.div_(ctx.ep_size) return grad, None @@ -449,7 +449,4 @@ def all_to_all_uneven( overlap: bool = False, fp8_communication: bool = False, ): - assert ( - inputs.requires_grad - ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index 7ec390d6a..4b1b82b7c 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist -import torch.nn as nn +import torch.functional as F from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache @@ -28,11 +28,13 @@ from colossalai.quantization.fp8 import all_reduce_fp8 from colossalai.shardformer.layer._operation import ( all_to_all_comm, gather_forward_split_backward, + linear_with_async_comm, split_forward_gather_backward, ) -from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group @@ -58,7 +60,7 @@ class AddAuxiliaryLoss(torch.autograd.Function): return grad_output, grad_loss -class EPDeepseekMoE(nn.Module): +class EPDeepseekMoE(ParallelModule): def __init__(self): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") @@ -214,6 +216,79 @@ class EPDeepseekMoE(nn.Module): return output_hidden_states +class DeepseekMoEGate_Col(ParallelModule): + def parallel_linear(self, hidden_states): + assert ( + hidden_states.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + hidden_states.shape, self.weight.shape, self.weight.shape[-1] + ) + + output = linear_with_async_comm( + hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication + ) + + # All-gather across the partitions. 
+ output = gather_forward_split_backward( + output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) + return output + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = self.parallel_linear(hidden_states) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) + ce.scatter_add_( + 1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device) + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + + return topk_idx, topk_weight, aux_loss + + @staticmethod + def from_native_module( + module, process_group: ProcessGroup, config, gather_output, fp8_communication + ) -> "DeepseekMoEGate_Col": + LazyInitContext.materialize(module) + module.process_group = process_group + module.fp8_communication = fp8_communication + sharded_weight = shard_rowwise(module.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, module.weight) + module.__class__ = DeepseekMoEGate_Col + return module + + class DeepseekPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 4850ef1b6..0103808dc 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -36,7 +36,7 @@ from colossalai.shardformer.layer._operation import ( gather_forward_split_backward, split_forward_gather_backward, ) -from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard.utils import set_tensors_to_none from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group @@ -49,7 +49,7 @@ if is_flash_attn_2_available(): _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) -class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): +class EPMixtralSparseMoeBlock(ParallelModule): def __init__(self, *args, **kwargs): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") diff --git a/colossalai/shardformer/policies/deepseek.py 
b/colossalai/shardformer/policies/deepseek.py index 0b8a602d1..bd54e6f2d 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.deepseek import ( + DeepseekMoEGate_Col, DeepseekPipelineForwards, EPDeepseekMoE, get_deepseek_flash_attention_forward, @@ -56,16 +57,24 @@ class DeepseekPolicy(Policy): sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] + tp_size = self.shard_config.tensor_parallel_size + + # modified for both SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) if sp_mode == "all_to_all": + num_q_heads //= sp_size decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, + "num_heads": num_q_heads, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) + if self.shard_config.enable_sequence_parallelism: if self.pipeline_stage_manager is not None: # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism @@ -97,6 +106,7 @@ class DeepseekPolicy(Policy): else: if self.tie_weight: embedding_cls = PaddingEmbedding + if self.shard_config.enable_tensor_parallelism: # tensor parallelism for non-moe params assert ( @@ -107,10 +117,15 @@ class DeepseekPolicy(Policy): ), f"The number of key_value heads must be divisible by tensor parallel size." 
decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.num_key_value_heads": self.model.config.num_key_value_heads - // self.shard_config.tensor_parallel_size, } + num_q_heads //= tp_size + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": num_q_heads, + } + if num_kv_heads: + num_kv_heads //= tp_size + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads policy["DeepseekDecoderLayer"] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -135,8 +150,19 @@ class DeepseekPolicy(Policy): target_module=Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), + SubModuleReplacementDescription( + suffix="mlp.gate", + target_module=DeepseekMoEGate_Col, + kwargs={ + "gather_output": True, + "fp8_communication": self.shard_config.fp8_communication, + "config": self.model.config, + }, + ignore_if_not_exist=True, + ), ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 3a373889c..9f03319e7 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -51,12 +51,20 @@ class MixtralPolicy(Policy): sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] + tp_size = self.shard_config.tensor_parallel_size + + # modified for both SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) + if sp_mode == "all_to_all": + num_q_heads //= sp_size decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, + "num_heads": num_q_heads, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -101,12 +109,14 @@ class MixtralPolicy(Policy): assert ( self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 ), f"The number of key_value heads must be divisible by tensor parallel size." 
+ num_q_heads //= tp_size decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.num_key_value_heads": self.model.config.num_key_value_heads - // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": num_q_heads, } + if num_kv_heads: + num_kv_heads //= tp_size + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads policy[MixtralDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -131,7 +141,7 @@ class MixtralPolicy(Policy): target_module=Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), - SubModuleReplacementDescription( # or replicate? + SubModuleReplacementDescription( suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, diff --git a/examples/language/deepseek/benchmark.py b/examples/language/deepseek/benchmark.py new file mode 100644 index 000000000..fef181e71 --- /dev/null +++ b/examples/language/deepseek/benchmark.py @@ -0,0 +1,271 @@ +# modified from mixtral benchmark +import argparse +import resource +import time +import warnings +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from data_utils import RandomDataset +from model_utils import format_numel_str, get_model_numel +from performance_evaluator import PerformanceEvaluator, get_profile_context +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.booster import Booster +from colossalai.booster.plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import PipelineGradientCheckpointConfig + +warnings.filterwarnings("ignore") +# ============================== +# Constants +# ============================== + +# We have lots of llamas for your choice! 
+MODEL_CONFIGS = { + "100m": lambda: AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + max_position_embeddings=4096, + num_hidden_layers=1, + num_attention_heads=32, + intermediate_size=512, + moe_intermediate_size=128, + hidden_size=512, + n_routed_experts=8, + n_shared_experts=4, + num_experts_per_tok=2, + first_k_dense_replace=0, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ), + "7b": lambda: AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + max_position_embeddings=4096, + num_hidden_layers=13, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ), + "14b": lambda: AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + max_position_embeddings=4096, + num_hidden_layers=26, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["3d"], + default="3d", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--ep", type=int, default=1, help="Expert parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") + parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) + + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) + parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. 
\ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) + parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") + parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument( + "--sp_mode", + default="all_to_all", + choices=["all_to_all"], + help="Sequence parallelism mode", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + args = parser.parse_args() + + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ckpt config for LLaMA3-70B on 64 H100 GPUs + hybrid_kwargs = ( + { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ), + "num_layers_per_stage": [19, 20, 20, 21], + "pp_style": "interleaved", + } + if args.custom_ckpt + else {} + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "3d": + plugin = MoeHybridParallelPlugin( + ep_size=args.ep, + tp_size=args.tp, + pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, + zero_stage=args.zero, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=args.sp > 1, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, + microbatch_size=args.mbs, + precision="bf16", + enable_metadata_cache=not args.no_cache, + overlap_allgather=args.overlap_allgather, + use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, + **hybrid_kwargs, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = getattr(plugin, "dp_size", coordinator.world_size) + + config = MODEL_CONFIGS[args.config]() + + torch.cuda.manual_seed(42) + + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, MoeHybridParallelPlugin) + else nullcontext() + ) + + with init_ctx: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, + model.config.num_hidden_layers, + model.config.hidden_size, + model.config.vocab_size, + args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + optimizer = 
HybridAdam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + torch.set_default_dtype(torch.float) + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + with get_profile_context( + args.profile, + args.ignore_steps, + 1, # avoid creating massive log files + save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + nsys=args.nsys, + ) as prof: # , distributed_debug_mode(10, enable=True): + if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + prof.step() + print(f"rank {dist.get_rank()} step {step} passed") + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + del outputs # free memory + + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(**batch) + prof.step() + + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/deepseek/data_utils.py b/examples/language/deepseek/data_utils.py new file mode 120000 index 000000000..2da9822df --- /dev/null +++ b/examples/language/deepseek/data_utils.py @@ -0,0 +1 @@ +../data_utils.py \ No newline at end of file diff --git a/examples/language/deepseek/model_utils.py b/examples/language/deepseek/model_utils.py new file mode 120000 index 000000000..73c6818a8 --- /dev/null +++ b/examples/language/deepseek/model_utils.py @@ -0,0 +1 @@ +../model_utils.py \ No newline at end of file diff --git a/examples/language/deepseek/performance_evaluator.py b/examples/language/deepseek/performance_evaluator.py new file mode 120000 index 000000000..f4736354b --- /dev/null +++ b/examples/language/deepseek/performance_evaluator.py @@ -0,0 +1 @@ +../performance_evaluator.py \ No newline at end of file diff --git a/examples/language/deepseek/test_ci.sh b/examples/language/deepseek/test_ci.sh new file mode 100755 index 000000000..e69de29bb diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index bb14378ad..0e88fabf1 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -105,7 +105,7 @@ def main(): parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") parser.add_argument("--use_fp8_comm", action="store_true", 
default=False, help="for using fp8 during communication") - parser.add_argument("--use_fp8", action="store_true") + parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument( "--sp_mode", @@ -151,6 +151,7 @@ def main(): max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -164,6 +165,7 @@ def main(): enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp": if use_empty_init: @@ -224,6 +226,7 @@ def main(): enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": @@ -241,6 +244,7 @@ def main(): precision="bf16", overlap_p2p=args.overlap, use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, ) else: raise ValueError(f"Unknown plugin {args.plugin}") diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py new file mode 100644 index 000000000..bb2a32d01 --- /dev/null +++ b/examples/language/mixtral/benchmark.py @@ -0,0 +1,259 @@ +# modified from llama benchmark +import argparse +import resource +import time +import warnings +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from data_utils import RandomDataset +from model_utils import format_numel_str, get_model_numel +from performance_evaluator import PerformanceEvaluator, get_profile_context +from tqdm import tqdm +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.booster import Booster +from colossalai.booster.plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import PipelineGradientCheckpointConfig + +warnings.filterwarnings("ignore") +# ============================== +# Constants +# ============================== + +# We have lots of llamas for your choice! 
+MODEL_CONFIGS = { + "100m": MixtralConfig( + max_position_embeddings=4096, + num_hidden_layers=4, + num_attention_heads=32, + intermediate_size=768, + hidden_size=768, + attn_implementation="flash_attention_2", + ), + "7b": MixtralConfig( + max_position_embeddings=4096, + num_hidden_layers=5, + attn_implementation="flash_attention_2", + ), + "14b": MixtralConfig( + max_position_embeddings=4096, + num_hidden_layers=10, + attn_implementation="flash_attention_2", + ), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["3d"], + default="3d", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--ep", type=int, default=1, help="Expert parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") + parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) + + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) + parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. 
\ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) + parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") + parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument( + "--sp_mode", + default="all_to_all", + choices=["all_to_all"], + help="Sequence parallelism mode", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + args = parser.parse_args() + + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ckpt config for LLaMA3-70B on 64 H100 GPUs + hybrid_kwargs = ( + { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ), + "num_layers_per_stage": [19, 20, 20, 21], + "pp_style": "interleaved", + } + if args.custom_ckpt + else {} + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "3d": + plugin = MoeHybridParallelPlugin( + ep_size=args.ep, + tp_size=args.tp, + pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, + zero_stage=args.zero, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=args.sp > 1, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, + microbatch_size=args.mbs, + precision="bf16", + enable_metadata_cache=not args.no_cache, + overlap_allgather=args.overlap_allgather, + use_fp8=args.use_fp8, + fp8_communication=args.use_fp8_comm, + **hybrid_kwargs, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = getattr(plugin, "dp_size", coordinator.world_size) + + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = MixtralConfig.from_pretrained(args.config, trust_remote_code=True) + torch.cuda.manual_seed(42) + + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, MoeHybridParallelPlugin) + else nullcontext() + ) + + with init_ctx: + model = MixtralForCausalLM(config=config).to(torch.bfloat16) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, + model.config.num_hidden_layers, + model.config.hidden_size, + model.config.vocab_size, + 
args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + optimizer = HybridAdam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + torch.set_default_dtype(torch.float) + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + with get_profile_context( + args.profile, + args.ignore_steps, + 1, # avoid creating massive log files + save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + nsys=args.nsys, + ) as prof: + if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + prof.step() + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + del outputs # free memory + + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + performance_evaluator.on_step_end(**batch) + prof.step() + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/mixtral/data_utils.py b/examples/language/mixtral/data_utils.py new file mode 120000 index 000000000..2da9822df --- /dev/null +++ b/examples/language/mixtral/data_utils.py @@ -0,0 +1 @@ +../data_utils.py \ No newline at end of file diff --git a/examples/language/mixtral/model_utils.py b/examples/language/mixtral/model_utils.py new file mode 120000 index 000000000..73c6818a8 --- /dev/null +++ b/examples/language/mixtral/model_utils.py @@ -0,0 +1 @@ +../model_utils.py \ No newline at end of file diff --git a/examples/language/mixtral/performance_evaluator.py b/examples/language/mixtral/performance_evaluator.py new file mode 120000 index 000000000..f4736354b --- /dev/null +++ b/examples/language/mixtral/performance_evaluator.py @@ -0,0 +1 @@ +../performance_evaluator.py \ No newline at end of file diff --git a/examples/language/mixtral/test_ci.sh b/examples/language/mixtral/test_ci.sh new file mode 100755 index 000000000..e69de29bb diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 8c411a33f..dbcd28ab5 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,4 +1,12 @@ +import os +import traceback +from contextlib import contextmanager +from time import sleep +from typing import Callable, List, Optional + import torch +import torch.distributed as dist +from torch.utils._pytree import tree_map def assert_loose_close(a, b, dtype: torch.dtype = 
torch.float32, name=""): @@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): return torch.allclose(a, b, rtol=rtol, atol=atol) -def check_model_equal(model1, model2): +def check_model_equal(model1, model2, dtype): assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())): - assert_loose_close(p1, p2, p1.dtype) + assert_loose_close(p1, p2, dtype, name=name) + + +@contextmanager +def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True): + if enable: + assert ( + os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1" + ), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}" + if funcs_to_patch is None: + funcs_to_patch = [ + dist.all_reduce, + dist.all_reduce_coalesced, + dist.all_gather, + dist.all_gather_coalesced, + dist.all_gather_into_tensor, + dist.all_to_all, + dist.all_to_all_single, + dist.reduce_scatter, + ] + + original_funcs = {} + patched_funcs = {} + + def make_patched(func): + def patched_func(*args, **kwargs): + stack = traceback.format_stack() + + def format_node(node): + if isinstance(node, torch.Tensor): + return f"{node.shape}" + elif isinstance(node, list): + return f"[{', '.join([format_node(n) for n in node])}]" + + return str(node) + + args_str, kwargs_str = tree_map(format_node, (args, kwargs)) + en = len(stack) - 1 + st = max(0, en - num_stacks) + dist.barrier() + sleep(0.001 * dist.get_rank()) + print( + f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n" + ) + dist.barrier() + return func(*args, **kwargs) + + return patched_func + + if enable: + for func in funcs_to_patch: + original_funcs[func.__name__] = getattr(dist, func.__name__) + patched_funcs[func.__name__] = make_patched(func) + setattr(dist, func.__name__, patched_funcs[func.__name__]) + + try: + yield + finally: + for func_name, original_func in original_funcs.items(): + setattr(dist, func_name, original_func) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 89f5d1c64..f3f109192 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -130,7 +130,7 @@ def check_moe_checkpoint(test_config): dist.barrier() if dist.get_rank() == 0: saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype) - check_model_equal(orig_model, saved_model) + check_model_equal(orig_model, saved_model, dtype=dtype) saved_model.save_pretrained(hf_model_dir) dist.barrier() # check load model @@ -138,7 +138,7 @@ def check_moe_checkpoint(test_config): new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) booster.load_model(new_model, hf_model_dir) - check_model_equal(model, new_model) + check_model_equal(model, new_model, dtype=dtype) # check save optimizer optimizer.step() diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py index 46da4522f..d782a2a09 100644 --- a/tests/test_shardformer/test_model/test_shard_deepseek.py +++ b/tests/test_shardformer/test_model/test_shard_deepseek.py @@ -12,43 +12,25 @@ from transformers import AutoConfig, AutoModel import colossalai from colossalai.booster.booster import Booster from 
colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from tests.test_moe.moe_utils import assert_loose_close, check_model_equal NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4 NUM_LAYERS = 4 HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 +NUM_HEADS = 8 TOP_K = 2 -CHECKED_CONFIG = [ # FOR_WORLD=4 - (1, 4, 1, 1, 1), - (1, 1, 4, 1, 1), - (1, 1, 1, 4, 1), - (1, 1, 1, 1, 4), - (0, 1, 4, 1, 1), - (0, 1, 1, 4, 1), - (0, 1, 1, 1, 4), - (1, 2, 1, 1, 1), -] - - -@parameterize( - "config", - [ - (1, 2, 2, 1, 1), - (1, 2, 1, 2, 1), - (1, 2, 1, 1, 2), - ], -) -def run_zero_with_original_model(config: Tuple[int, ...]): +def run_deepseek_commom(config: Tuple[int, ...]): + Randomizer.reset_index() stage, ep_size, pp_size, tp_size, sp_size = config world_size = dist.get_world_size() rank = dist.get_rank() - dtype, precision = torch.float16, "fp16" + dtype, precision = torch.bfloat16, "bf16" torch.cuda.set_device(dist.get_rank()) plugin = MoeHybridParallelPlugin( @@ -60,11 +42,11 @@ def run_zero_with_original_model(config: Tuple[int, ...]): zero_stage=stage, enable_sequence_parallelism=sp_size > 1, sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, - enable_flash_attention=sp_size > 1, overlap_communication=False, initial_scale=1, precision=precision, find_unused_parameters=True, + enable_flash_attention=True, ) dp_size = plugin.dp_size @@ -171,7 +153,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]): dist.barrier() saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda() - check_model_equal(torch_model, saved_model) + check_model_equal(torch_model, saved_model, dtype=dtype) dist.barrier() if rank == world_size - 1: @@ -180,17 +162,77 @@ def run_zero_with_original_model(config: Tuple[int, ...]): print(f"rank {dist.get_rank()} test passed") -def run_dist(rank, world_size, port): +@parameterize( + "config", + [ + # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp + (0, 1, 4, 1, 1), + (0, 1, 1, 4, 1), + (0, 1, 2, 2, 1), + # zero 1 + (1, 4, 1, 1, 1), + (1, 1, 4, 1, 1), + (1, 1, 1, 4, 1), + (1, 2, 1, 1, 2), + # zero 2 + (2, 4, 1, 1, 1), + (2, 1, 4, 1, 1), + (2, 1, 1, 4, 1), + (2, 2, 1, 1, 2), + ], +) +def run_deepseek_test(config: Tuple[int, ...]): + run_deepseek_commom(config) + + +@parameterize( + "config", + [ + # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp + (0, 1, 2, 4, 1), + (0, 1, 4, 2, 1), + (0, 1, 1, 4, 1), + (0, 1, 4, 1, 1), + # zero 1: + (1, 2, 1, 1, 2), + (1, 2, 1, 4, 1), + (1, 1, 1, 2, 2), + (1, 2, 2, 2, 1), + # zero 2 + (2, 2, 1, 1, 2), + (2, 2, 1, 4, 1), + (2, 1, 1, 2, 2), + (2, 2, 2, 2, 1), + ], +) +def run_deepseek_3d_test(config: Tuple[int, ...]): + run_deepseek_commom(config) + + +def check_deepseek(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_deepseek_test() + + +def check_deepseek_3d(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_zero_with_original_model() + run_deepseek_3d_test() @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_deepseek(world_size): - spawn(run_dist, world_size) + 
spawn(check_deepseek, world_size) + + +@pytest.mark.largedist +@pytest.mark.parametrize("world_size", [8]) +@rerun_if_address_is_in_use() +def test_deepseek_3d(world_size): + spawn(check_deepseek_3d, world_size) if __name__ == "__main__": - test_deepseek(world_size=4) + test_deepseek(world_size=8) + test_deepseek_3d(world_size=8) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index de09eedcb..940c66cf6 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai from colossalai.booster.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from tests.test_moe.moe_utils import assert_loose_close, check_model_equal NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4 NUM_LAYERS = 4 HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 +NUM_HEADS = 8 +TOP_K = 2 -CHECKED_CONFIG = [ # FOR WORLD=4 - (0, 1, 4, 1, 1), - (0, 1, 1, 4, 1), - (0, 1, 1, 1, 4), - (1, 4, 1, 1, 1), - (1, 1, 4, 1, 1), - (1, 1, 1, 4, 1), - (1, 1, 1, 1, 4), - (1, 2, 1, 1, 1), -] - -@parameterize( - "config", - [ - (1, 2, 2, 1, 1), - (1, 2, 1, 2, 1), - (1, 2, 1, 1, 2), - ], -) -def run_zero_with_original_model(config: Tuple[int, ...]): +def run_mixtral_commom(config: Tuple[int, ...]): + Randomizer.reset_index() stage, ep_size, pp_size, tp_size, sp_size = config world_size = dist.get_world_size() rank = dist.get_rank() - dtype, precision = torch.float16, "fp16" + dtype, precision = torch.bfloat16, "bf16" torch.cuda.set_device(dist.get_rank()) plugin = MoeHybridParallelPlugin( @@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]): dist.barrier() saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype) - check_model_equal(torch_model, saved_model) + check_model_equal(torch_model, saved_model, dtype=dtype) dist.barrier() if rank == world_size - 1: @@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]): print(f"rank {dist.get_rank()} test passed") -def run_dist(rank, world_size, port): +@parameterize( + "config", + [ + # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp + (0, 1, 4, 1, 1), + (0, 1, 1, 4, 1), + (0, 1, 2, 2, 1), + # zero 1 + (1, 4, 1, 1, 1), + (1, 1, 4, 1, 1), + (1, 1, 1, 4, 1), + (1, 2, 1, 1, 2), + # zero 2 + (2, 4, 1, 1, 1), + (2, 1, 4, 1, 1), + (2, 1, 1, 4, 1), + (2, 2, 1, 1, 2), + ], +) +def run_mixtral_test(config: Tuple[int, ...]): + run_mixtral_commom(config) + + +@parameterize( + "config", + [ + # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp + (0, 1, 2, 4, 1), + (0, 1, 4, 2, 1), + (0, 1, 1, 4, 1), + (0, 1, 4, 1, 1), + # zero 1: + (1, 2, 1, 1, 2), + (1, 2, 1, 4, 1), + (1, 1, 1, 2, 2), + (1, 2, 2, 2, 1), + # zero 2 + (2, 2, 1, 1, 2), + (2, 2, 1, 4, 1), + (2, 1, 1, 2, 2), + (2, 2, 2, 2, 1), + ], +) +def run_mixtral_3d_test(config: Tuple[int, ...]): + print(f"{config=}") + run_mixtral_commom(config) + + +def check_mixtral(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + 
run_mixtral_test() + + +def check_mixtral_3d(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_zero_with_original_model() + run_mixtral_3d_test() @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_mixtral(world_size): - spawn(run_dist, world_size) + spawn(check_mixtral, world_size) + + +@pytest.mark.largedist +@pytest.mark.parametrize("world_size", [8]) +@rerun_if_address_is_in_use() +def test_mixtral_3d(world_size): + spawn(check_mixtral_3d, world_size) if __name__ == "__main__": - test_mixtral(world_size=4) + test_mixtral(world_size=8) + test_mixtral_3d(world_size=8)
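
The optimizer change in moe_hybrid_parallel_plugin.py reduces to a parameter-grouping rule: when the dense data-parallel group and the MoE data-parallel group are the same object, every parameter goes into one bucket; otherwise expert parameters get their own bucket so ZeRO reduces them over the smaller moe_dp_group. Below is a minimal standalone sketch of that rule; `is_moe_tensor` stands in for colossalai.tensor.moe_tensor.api.is_moe_tensor and the group handles are treated as opaque objects.

from typing import Callable, Dict, List

import torch.nn as nn


def build_pg_param_list(
    model: nn.Module,
    dp_process_group: object,
    moe_dp_group: object,
    is_moe_tensor: Callable[[nn.Parameter], bool],
) -> Dict[object, List[nn.Parameter]]:
    if dp_process_group is moe_dp_group:
        # Dense and MoE DP groups coincide: a single bucket covers every parameter.
        return {dp_process_group: list(model.parameters())}
    # Otherwise expert parameters are reduced over moe_dp_group only.
    return {
        dp_process_group: [p for p in model.parameters() if not is_moe_tensor(p)],
        moe_dp_group: [p for p in model.parameters() if is_moe_tensor(p)],
    }

The check that follows in the diff then raises a ValueError suggesting HybridParallelPlugin when the moe_dp_group bucket turns out to be empty.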
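
The DeepSeek and Mixtral policy updates share the same head-count arithmetic: query heads (and key/value heads when the config defines them) are divided by sp_size under all_to_all sequence parallelism and then by tp_size under tensor parallelism. A hedged sketch of that arithmetic follows; it uses HF-style config field names, and the helper itself is illustrative rather than part of the policies.

from typing import Optional, Tuple


def shard_head_counts(
    num_attention_heads: int,
    num_key_value_heads: Optional[int],
    sp_size: int,
    tp_size: int,
    sp_mode: Optional[str],
    enable_tp: bool,
) -> Tuple[int, Optional[int]]:
    num_q_heads, num_kv_heads = num_attention_heads, num_key_value_heads  # kv heads may be None
    if sp_mode == "all_to_all":
        num_q_heads //= sp_size
        if num_kv_heads:
            num_kv_heads //= sp_size
    if enable_tp:
        num_q_heads //= tp_size
        if num_kv_heads:
            num_kv_heads //= tp_size
    return num_q_heads, num_kv_heads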
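
The colossalai/moe/_operation.py hunks switch the expert-parallel gradient scalers from out-of-place to in-place scaling. The class below is a compressed sketch of that pattern, not the library classes verbatim: the forward is reduced to an identity that records ep_size (an assumption, since only the backward bodies appear in the diff), and the backward mutates the incoming gradient in place, which presumes that buffer is not consumed elsewhere.

import torch


class _ScaleGradByEP(torch.autograd.Function):
    """Identity in forward; multiplies the gradient by ep_size in backward."""

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, ep_size: int) -> torch.Tensor:
        ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        if ctx.ep_size != 1:
            grad_output.mul_(ctx.ep_size)  # in-place, avoids allocating a new tensor
        return grad_output, None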