diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py
index bb2a32d01..7c8a5fe65 100644
--- a/examples/language/mixtral/benchmark.py
+++ b/examples/language/mixtral/benchmark.py
@@ -11,6 +11,7 @@ from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
+from transformers import AutoConfig
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
 import colossalai
@@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -85,7 +87,7 @@ def main():
     parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
 
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -120,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            "pp_style": "interleaved",
+            # "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
@@ -129,7 +131,29 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     if args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = MoeHybridParallelPlugin(
             ep_size=args.ep,
             tp_size=args.tp,
@@ -148,6 +172,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     else:
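
Note on the new `zbv` path (a reviewer's sketch, not part of the patch): when `--pp_style zbv` is selected, the benchmark builds a zero-bubble-V pipeline schedule up front by feeding `PipelineGraph` a rough per-microbatch cost and memory model, then hands the resulting nodes to `MoeHybridParallelPlugin`. The standalone snippet below replays that construction with illustrative values; `hidden_size`, `num_attention_heads`, `max_length`, `pp`, and `n_micro` stand in for what the benchmark reads from the model config and CLI arguments, and the `f`/`b`/`w` coefficients simply mirror the heuristic in the diff.

```python
# Standalone sketch of the zbv schedule construction (illustrative values;
# in the benchmark these come from the HF config and CLI arguments).
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

hidden_size = 4096          # config.hidden_size
num_attention_heads = 32    # config.num_attention_heads
max_length = 4096           # args.max_length
pp = 4                      # args.pp (number of pipeline stages)
n_micro = 8                 # args.batch_size // args.mbs (microbatches)

# Heuristic per-microbatch activation-memory model from the patch:
# forward (f) allocates, backward-weight (w) releases a fixed share, and
# backward-input (b) is defined so that f + b + w nets to zero per microbatch.
mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=pp,
    n_micro=n_micro,
    f_cost=1000,  # relative forward compute cost
    b_cost=1000,  # relative backward-input cost
    w_cost=1000,  # relative backward-weight cost
    c_cost=1,     # relative communication cost
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()  # per-stage node sequences consumed by the ZBV scheduler
```

Because `mem_b = -mem_w - mem_f`, one microbatch's forward, backward-input, and backward-weight passes sum to zero net activation memory by construction, which appears to be what lets the V-schedule search bound peak memory per stage.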