mirror of https://github.com/hpcaitech/ColossalAI
[feat] support zbv in mixtral benchmark;
parent cc500b3e25
commit 3f5bec8dc4
@@ -11,6 +11,7 @@ from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
+from transformers import AutoConfig
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
 import colossalai
@@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -85,7 +87,7 @@ def main():
     parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
 
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -120,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            "pp_style": "interleaved",
+            # "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
@@ -129,7 +131,29 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     if args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = MoeHybridParallelPlugin(
             ep_size=args.ep,
             tp_size=args.tp,
@@ -148,6 +172,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     else:
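
For reference, below is a minimal standalone sketch (not part of the commit) of the zero-bubble V-schedule construction that the new zbv branch performs. The hidden_size, num_attention_heads, sequence length and parallel sizes are placeholder values standing in for whatever MixtralConfig and CLI arguments the benchmark resolves at runtime; the cost constants and the 34 / 5 / 32 memory coefficients simply mirror the hard-coded values in the hunk above.

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Placeholder values standing in for the benchmark's MixtralConfig and CLI args.
hidden_size = 4096
num_attention_heads = 32
max_length = 4096
pp_size = 4               # args.pp
batch_size, mbs = 32, 1   # global batch size and micro-batch size

# Per-microbatch activation-memory deltas used by the scheduler: forward adds
# mem_f, the backward (input-grad) pass frees mem_b, and the deferred
# weight-gradient pass frees mem_w, so the three deltas sum to zero.
mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f
assert mem_f + mem_b + mem_w == 0

# Build the V-shaped zero-bubble schedule; f/b/w costs are relative pass
# durations and c_cost the inter-stage communication cost, as hard-coded above.
scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=batch_size // mbs,
    f_cost=1000,
    b_cost=1000,
    w_cost=1000,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()

The resulting node list is what the final hunk passes to MoeHybridParallelPlugin through the new scheduler_nodes argument; for any other pp_style the benchmark keeps scheduler_nodes as None, leaving the existing 1f1b and interleaved paths unchanged.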