Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

291 lines
10 KiB

import argparse
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
[MoE/ZeRO] Moe refactor with zero refactor (#5821) * [moe] removed openmoe-coupled code and rectify mixstral code (#5471) * [Feauture] MoE refractor; Intergration with Mixtral (#5682) * cherry pick from refractor-moe branch * tests passed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support ep + zero --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add mixtral auto policy & move pipeline forward code to modeling folder * [moe refactor] modify kernel test without Route Class * [moe refactor] add moe tensor test path environment variable to github workflow * fix typos * fix moe test bug due to the code rebase * [moe refactor] fix moe zero test, and little bug in low level zero * fix typo * add moe tensor path to github workflow * remove some useless code * fix typo & unify global variable XX_AXIS logic without using -1 * fix typo & prettifier the code * remove print code & support zero 2 test * remove useless code * reanme function * fix typo * fix typo * Further improve the test code * remove print code * [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test * [moe refactor] skip some unit test which will be refactored later * [moe refactor] fix unit import error * [moe refactor] fix circular import issues * [moe refactor] remove debug code * [moe refactor] update github workflow * [moe/zero] refactor low level optimizer (#5767) * [zero] refactor low level optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] MoE refactor with newest version of ZeRO (#5801) * [zero] remove redundant members in BucketStore (#5802) * [zero] align api with previous version * [Moe/Zero] Update MoeHybridParallelPlugin with 
refactored ZeRO and Fix Zero bug (#5819) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * [hotfix]Solve the compatibility issue of zero refactor (#5823) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * Modify function parameter names to resolve compatibility issues * [zero] fix missing hook removal (#5824) * [MoE] Resolve .github conflict (#5829) * [Fix/Example] Fix Llama Inference Loading Data Type (#5763) * [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3 * [release] update version (#5752) * [release] update version * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [test] fix ddp plugin test * [test] fix gptj and rpc test * [devops] fix cuda ext compatibility * [inference] fix flash decoding test * [inference] fix flash decoding test * fix (#5765) * [test] Fix/fix testcase (#5770) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [Hotfix] Add missing init file in inference.executor (#5774) * [CI/tests] simplify some test case to reduce testing time (#5755) * [ci/tests] simplify some test case to reduce testing time * [ci/tests] continue to remove test case to reduce ci time cost * restore some test config * [ci/tests] continue 
to reduce ci time cost * [misc] update dockerfile (#5776) * [misc] update dockerfile * [misc] update dockerfile * [devops] fix docker ci (#5780) * [Inference]Add Streaming LLM (#5745) * Add Streaming LLM * add some parameters to llama_generation.py * verify streamingllm config * add test_streamingllm.py * modified according to the opinions of review * add Citation * change _block_tables tolist * [hotfix] fix llama flash attention forward (#5777) * [misc] Accelerate CI for zero and dist optim (#5758) * remove fp16 from lamb * remove d2h copy in checking states --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Test/CI] remove test cases to reduce CI duration (#5753) * [test] smaller gpt2 test case * [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py * [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py * [test] reduce test cases tests/test_zero/test_gemini/test_optim.py * Revert "[test] smaller gpt2 test case" Some tests might depend on the size of model (num of chunks) This reverts commit df705a5210b8901645992adf276e320e48766ebf. * [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py * [CI] smaller test model for two mwo the two modifid cases * [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there * [hotfix] fix testcase in test_fx/test_tracer (#5779) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [fix] fix test_deepfm_model & test_dlrf_model; * [fix] fix test_hf_albert & test_hf_gpt; * [gemini] optimize reduce scatter d2h copy (#5760) * [gemini] optimize reduce scatter d2h copy * [fix] fix missing reduce variable * [refactor] remove legacy async reduce scatter code * [gemini] missing sync * Revert "[refactor] remove legacy async reduce scatter code" This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979. 
* [gemini] further optimize with async all reduce * [fix] pass flag from manager to chunk * Allow building cuda extension without a device. (#5535) Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are. * [misc] fix dist logger (#5782) * [install]fix setup (#5786) * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update requirements (#5787) * [shardformer] fix import (#5788) * upgrade colossal-chat support tp_group>1, add sp for sft * upgrade ppo dpo rm script * run pre-commit * moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy * fix training script * fix ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix transformers version * remove duplicated test * fix datasets version * remove models that require huggingface auth from ci * remove local data path * update ci * remove baichuan from template test due to transformer version conflict * merge * Refactor modeling by adding attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix tests and naming Signed-off-by: char-1ee <xingjianli59@gmail.com> * Pass inference model shard configs for module init Signed-off-by: char-1ee <xingjianli59@gmail.com> * Clean up Signed-off-by: char-1ee <xingjianli59@gmail.com> * replace the customized dataloader setup with the build-in one * replace the customized dataloader setup with the build-in one * Remove flash attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * fix readme * Fix test import Signed-off-by: char-1ee <xingjianli59@gmail.com> * update sft trainning script * [Inference]refactor baichuan (#5791) * refactor baichuan * remove unused code and add TODO for lazyinit * [test] fix chatglm test kit (#5793) * 
[shardformer] fix modeling of bloom and falcon (#5796) * [test] fix qwen2 pytest distLarge (#5797) * [Inference] Fix flash-attn import and add model test (#5794) * Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> * [Gemini] Use async stream to prefetch and h2d data moving (#5781) * use async stream to prefetch and h2d data moving * Remove redundant code * [gemini] quick fix on possible async operation (#5803) * [gemini] quick fix on possible async operation * [gemini] quick fix on possible async operation * [shardformer] upgrade transformers to 4.39.3 (#5815) * [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 <lhx0217@gmail.com> * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix * [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 <lhx0217@gmail.com> * Support 4d parallel + flash attention (#5789) * support tp + sp + pp * remove comments --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> 
Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> * [zero] fix hook bug * [zero] add low level optimizer back (#5839) * [zero] fix param & refactor * [zero] add back original low level opt * [zero] remove moe related * [zero] pass zero tests * [zero] refactor * [chore] add del func back * [zero] comments and naming (#5840) * [zero] modify api (#5843) * [zero] modify api * [test] remove _grad_store access in tests * [test] fix (#5857) * [CI] skip openmoe CI check * [CI] fox pre-commit * [zero] remove redundant memebr init (#5862) * [misc] remove useless code, modify the pg mesh implementation * [misc] remove useless code, modify the pg mesh implementation * [misc] use tempfile * resolve conflict with main branch * [misc] use tempfile in test_moe_checkpoint.py * [misc] remove useless code, add assertion about sequence parallel, move logger into function * [misc] remove useless code --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> 
Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
5 months ago
from utils import load_checkpoint, move_to_cuda, save_checkpoint
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@torch.no_grad()
def get_global_loss(loss, booster):
    """Return ``loss`` averaged over the booster plugin's ``dp_group``.

    The incoming tensor is not modified: a detached copy is sum-reduced
    across ``booster.plugin.dp_group`` and then divided in place by
    ``booster.plugin.dp_size``.

    Args:
        loss: Loss tensor computed on this rank.
        booster: Object exposing ``plugin.dp_group`` and ``plugin.dp_size``.

    Returns:
        A new tensor holding the reduced, averaged loss.
    """
    plugin = booster.plugin
    averaged = loss.clone().detach()
    dist.all_reduce(tensor=averaged, op=dist.ReduceOp.SUM, group=plugin.dp_group)
    averaged.div_(plugin.dp_size)
    return averaged
class RandomDataset(Dataset):
    """Synthetic dataset of random token ids for causal-LM training runs.

    All samples are generated once in ``__init__`` on the device returned by
    ``get_current_device()``. The attention mask is all ones and the labels
    of a sample are its own input ids.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
        # ``tokenizer`` is accepted for call-site compatibility but is not used.
        self.num_samples = num_samples
        self.max_length = max_length
        shape = (num_samples, max_length)
        self.input_ids = torch.randint(0, vocab_size, shape, device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        """Return the number of samples requested at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a dict of tensors."""
        return dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_mask[idx],
            labels=self.input_ids[idx],
        )
def parse_args():
    """Parse the command-line options of the fine-tuning script.

    Options are read from ``sys.argv`` via ``argparse``.

    ``--num_epoch`` and ``--num_epochs`` are two spellings of a single
    setting. The returned namespace exposes the value under both
    ``num_epoch`` and ``num_epochs`` so that every consumer sees the same
    epoch count (if both flags are passed, the last one wins).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()

    # basic settings
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    # One option with two spellings: previously these were independent flags,
    # which let the epoch loop and the LR schedule disagree on the epoch count.
    parser.add_argument(
        "--num_epoch",
        "--num_epochs",
        dest="num_epoch",
        type=int,
        default=1,
        help="Number of epochs.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help=" The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # lr scheduler
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")

    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")

    # communicate overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )

    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )

    args = parser.parse_args()
    # Keep the legacy attribute name alive for code that reads `args.num_epochs`.
    args.num_epochs = args.num_epoch
    return args
def main():
    """Fine-tune a Mixtral model on random data with the MoE hybrid-parallel plugin.

    Steps, in order: parse CLI options, initialize ColossalAI through
    ``colossalai.launch_from_torch``, build the plugin, model, dataloader,
    optimizer and LR scheduler, wrap them with ``Booster.boost``, optionally
    resume from ``--load_checkpoint``, then run the training loop. A
    checkpoint is written every ``--save_interval`` steps and the model is
    saved at the end of every epoch.

    Raises:
        ValueError: If ``--plugin`` names an unsupported plugin.
    """
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    if args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=args.pp_size,
            ep_size=args.ep_size,
            microbatch_size=args.microbatch_size,
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
            precision=args.precision,
            zero_stage=args.zero_stage,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master("Finish init model")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
    total_steps = args.num_epochs * len(dataloader)
    # Default warmup is 2.5% of the scheduled steps when --warmup_steps is not given.
    warmup_steps = args.warmup_steps if args.warmup_steps is not None else int(total_steps * 0.025)
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finish init booster")

    # Load ckpt
    if args.load_checkpoint is not None:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master("Finish load optimizer")

    # With pipeline parallelism the loss is only read on the last stage, so the
    # progress bar is shown there; otherwise it is shown on the master rank.
    show_progress = is_pp_last_stage if use_pipeline else coordinator.is_master()

    # Start finetuning
    coordinator.print_on_master("Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not show_progress,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        # Compare as text so the check holds whether the local
                        # rank is stored as a str or an int; comparing an int
                        # against the literal "0" would never be true.
                        # NOTE(review): confirm the stored type for the
                        # ColossalAI version in use.
                        if str(coordinator._local_rank) == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # NOTE: expert load balancing (--load_balance /
                # --load_balance_interval) is currently disabled; the original
                # call is kept here for reference.
                # Apply load balance
                # if (
                #     args.load_balance
                #     and args.load_balance_interval > 0
                #     and (step + 1) % args.load_balance_interval == 0
                # ):
                #     coordinator.print_on_master(f"Apply load balance")
                #     apply_load_balance(model, optimizer)

                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # save checkpoint at the end of each epochs
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master("Finish training")
# Script entry point: run training only when this file is executed directly.
if __name__ == "__main__":
    main()