[feat] support zbv in mixtral benchmark;

pull/6082/head
duanjunwen 2024-10-09 03:58:01 +00:00
parent cc500b3e25
commit 3f5bec8dc4
1 changed file with 27 additions and 2 deletions

@@ -11,6 +11,7 @@ from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
+from transformers import AutoConfig
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
 import colossalai
@@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -85,7 +87,7 @@ def main():
     parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -120,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            "pp_style": "interleaved",
+            # "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
@@ -129,7 +131,29 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     if args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = MoeHybridParallelPlugin(
             ep_size=args.ep,
             tp_size=args.tp,
@@ -148,6 +172,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     else:
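
For reference, the new ZBV branch above can be exercised on its own. The sketch below is illustrative and not part of this commit: the Mixtral-like config values (hidden_size, num_attention_heads), sequence length, stage count, and microbatch count are assumed placeholders, while PipelineGraph and get_v_schedule() are the same APIs the benchmark imports.

# Minimal standalone sketch of the ZBV schedule construction in this commit.
# Assumed values: hidden_size=4096, num_attention_heads=32 (Mixtral-8x7B-like),
# max_length=4096, 4 pipeline stages, batch_size=32 with microbatch size 1.
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

hidden_size = 4096
num_attention_heads = 32
max_length = 4096
n_stage = 4        # pipeline stages (args.pp in the benchmark)
n_micro = 32 // 1  # microbatches = batch_size // mbs

# Per-microbatch memory deltas fed to the zero-bubble solver: a forward (F)
# adds mem_f of activation memory, and the two backward phases, input-grad (B)
# and weight-grad (W), release it again (mem_f + mem_b + mem_w == 0).
mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=n_stage,
    n_micro=n_micro,
    f_cost=1000,  # relative time of an F pass
    b_cost=1000,  # relative time of a B (input-grad) pass
    w_cost=1000,  # relative time of a W (weight-grad) pass
    c_cost=1,     # relative time of a stage-to-stage communication
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()

The resulting scheduler_nodes value is what the benchmark hands to MoeHybridParallelPlugin(scheduler_nodes=...), so launching with the "3d" plugin and --pp_style zbv runs the V-shaped zero-bubble pipeline schedule instead of 1f1b or interleaved.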