@@ -11,6 +11,7 @@ from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
+from transformers import AutoConfig
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
 import colossalai
@@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -85,7 +87,7 @@ def main():
     parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
 
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
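A note on the new choice: "zbv" selects the zero-bubble V-schedule, in which each pipeline stage holds two model chunks, so in practice the flag is paired with two chunks per stage (an assumption about typical usage, not enforced by this hunk). A minimal, self-contained sketch of how the changed flag parses (names from this diff; the surrounding parser setup is assumed):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)

# type=eval turns the string "2" into the int 2
args = parser.parse_args(["--pp_style", "zbv", "--n_chunks", "2"])
assert args.pp_style == "zbv" and args.n_chunks == 2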
@@ -120,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            "pp_style": "interleaved",
+            # "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
@@ -129,7 +131,29 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     if args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = MoeHybridParallelPlugin(
             ep_size=args.ep,
             tp_size=args.tp,
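The three memory figures above are the per-microbatch activation accounting the V-schedule planner works with: a forward (F) allocates mem_f, the weight-gradient pass (W) frees 32 * hidden_size, and the activation-gradient backward (B) frees the remainder, so every F+B+W triple nets to zero. A quick sanity check with assumed Mixtral-like dimensions (the concrete numbers are illustrative, not from this patch):

hidden_size = 4096          # assumed config.hidden_size
num_attention_heads = 32    # assumed config.num_attention_heads
max_length = 4096           # assumed args.max_length

mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length  # 794624
mem_w = -32 * hidden_size                                        # -131072
mem_b = -mem_w - mem_f                                           # -663552

# By construction each microbatch is memory-neutral once its F, B and W
# have all run, which is what lets the planner bound peak activation memory.
assert mem_f + mem_b + mem_w == 0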
@@ -148,6 +172,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     else:
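For reference, the schedule construction can be exercised outside the benchmark. A minimal sketch, assuming the PipelineGraph signature shown in this patch and illustrative sizes (4 stages, 8 microbatches, the memory figures from the note above):

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Costs are relative weights: the compute phases (F/B/W) are set equal here,
# as in the benchmark, and cross-stage communication (c_cost) is near-free.
scheduler_nodes = PipelineGraph(
    n_stage=4,        # assumed pipeline size (args.pp)
    n_micro=8,        # assumed args.batch_size // args.mbs
    f_cost=1000,
    b_cost=1000,
    w_cost=1000,
    c_cost=1,
    f_mem=794624,
    b_mem=-663552,
    w_mem=-131072,
).get_v_schedule()

# The result is what the hunk above hands to
# MoeHybridParallelPlugin(scheduler_nodes=...); for the other pp styles it
# stays None, so existing 1f1b and interleaved runs are unaffected.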