ColossalAI/applications/ColossalMoE/train.py

385 lines
14 KiB
Python
Raw Normal View History

2023-12-14 09:52:05 +00:00
import argparse
import torch
import torch.distributed as dist
2023-12-18 02:37:07 +00:00
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
2023-12-14 09:52:05 +00:00
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
2023-12-25 08:05:42 +00:00
from colossal_moe.models.mixtral_layer import replace_moe_layer
2023-12-14 09:52:05 +00:00
from torch.utils.data import Dataset
from tqdm import tqdm
2023-12-15 08:32:32 +00:00
from transformers import AutoTokenizer
2023-12-14 09:52:05 +00:00
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
2023-12-15 08:32:32 +00:00
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER, apply_load_balance
2023-12-25 08:05:42 +00:00
# from colossalai.nn.optimizer import HybridAdam
from torch.optim import Adam as HybridAdam
2023-12-14 09:52:05 +00:00
from colossalai.utils import get_current_device
2023-12-25 08:05:42 +00:00
import argparse
import os
from functools import partial
from typing import Dict
import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
2023-12-14 09:52:05 +00:00
2023-12-25 08:05:42 +00:00
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
2023-12-14 09:52:05 +00:00
def move_to_cuda(batch, device):
    """Return a new dict whose values are the values of ``batch`` moved to ``device``.

    Each value must expose a ``.to(device)`` method; keys and their order
    are preserved. The input mapping itself is not modified.
    """
    moved = {}
    for name, value in batch.items():
        moved[name] = value.to(device)
    return moved
2023-12-25 08:05:42 +00:00
def load_ckpt(repo_name: str, model, booster: Booster):
    """Download a checkpoint snapshot and load it into ``model`` via ``booster``.

    The snapshot directory returned by ``snapshot_download(repo_name)`` is
    probed for a known weight file or shard index, and the first match is
    handed to ``booster.load_model``.

    Args:
        repo_name: Repository identifier passed to ``snapshot_download``.
        model: The model to load weights into.
        booster: Booster whose ``load_model`` performs the actual loading.

    Raises:
        ValueError: If none of the known checkpoint files exists in the
            downloaded snapshot directory.
    """
    ckpt_dir = snapshot_download(repo_name)

    # Probed in priority order. The legacy ``pytorch_model`` names come first
    # so repositories that resolved before still resolve to the same file;
    # the safetensors names cover repositories that ship only that format.
    candidate_names = (
        "pytorch_model.bin",  # single ckpt
        "pytorch_model.bin.index.json",  # shard ckpt
        "model.safetensors",  # single ckpt (safetensors)
        "model.safetensors.index.json",  # shard ckpt (safetensors)
    )
    for file_name in candidate_names:
        candidate = os.path.join(ckpt_dir, file_name)
        if os.path.exists(candidate):
            booster.load_model(model, candidate)
            return

    raise ValueError(f"Invalid checkpoint path: {ckpt_dir}")
2023-12-14 09:52:05 +00:00
class RandomDataset(Dataset):
    """Synthetic dataset of uniformly random token ids.

    All samples are generated once in ``__init__`` on the device returned by
    ``get_current_device()``. The attention mask is all ones and the labels
    are the input ids themselves. The ``tokenizer`` argument is accepted but
    not used.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        # One row per sample, ``max_length`` token ids drawn from [0, vocab_size).
        shape = (num_samples, max_length)
        self.input_ids = torch.randint(0, vocab_size, shape, device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        sample = {}
        sample["input_ids"] = self.input_ids[idx]
        sample["attention_mask"] = self.attention_mask[idx]
        # Labels are the same token ids as the inputs.
        sample["labels"] = self.input_ids[idx]
        return sample
def parse_args():
    """Build the command-line parser for the training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (basic settings, optimizer,
        plugin/parallelism sizes, kernel switches, loss factors, load
        balancing and communication flags).
    """
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b", "test"],
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["ep", "ep_zero", "hybrid"],
        help="Parallel methods. ep_zero is recommended for general cases. ep provides the least memory consumption and hybrid suits large scale training.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help=" The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="yizhongw/self_instruct",
        help="dataset name from `datasets` repo.",
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="super_natural_instructions",
        help="task of corresponding dataset.",
    )

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # ep_zero plugin
    parser.add_argument(
        "--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
    )
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # loss
    # The help texts of the two router factors were swapped; each now
    # describes the option it is attached to.
    parser.add_argument(
        "--router_aux_loss_factor",
        type=float,
        default=0.01,
        help="Moe router aux loss. You can refer to STMoE for details.",
    )
    parser.add_argument(
        "--router_z_loss_factor",
        type=float,
        default=0.0001,
        help="Moe router z loss. You can refer to STMoE for details.",
    )
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument(
        "--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")

    # communicate overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )

    args = parser.parse_args()
    return args
def _report_memory(tag, model):
    """Print the parameter count of ``model`` and the allocated CUDA memory.

    Both figures are divided by ``1000 ** 3``. The message format (including
    the ``GB`` suffix on the parameter count) is kept exactly as the inline
    debug prints it replaces. It is printed by every process that calls it.
    """
    num_params = sum(p.numel() for p in model.parameters()) / 1000.0 ** 3
    allocated = torch.cuda.memory_allocated() / 1000.0 ** 3
    print(f"{tag} param num: {num_params}GB, memory: {allocated}GB")


def main():
    """Fine-tune a Mixtral MoE model on random token data with ColossalAI.

    Steps, in order: parse the CLI arguments, launch ColossalAI from the
    torch launcher environment, build a ``MoeHybridParallelPlugin`` and
    configure ``MOE_MANAGER`` according to ``--plugin``, construct the model
    from the pretrained Mixtral config, wrap model/optimizer/dataloader with
    the booster, load the pretrained checkpoint, then run the training loop
    with optional expert load balancing and periodic checkpointing.

    Raises:
        ValueError: If ``--plugin`` is not one of the handled values.
    """
    args = parse_args()
    pretrained_name = "mistralai/Mixtral-8x7B-v0.1"

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": MixtralForCausalLMPolicy(),
        "enable_fused_normalization": args.use_layernorm_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "zero_stage": args.zero_stage,
        "checkpoint_io": MixtralMoECheckpointIO,
    }
    mgr_dict = {}
    if args.plugin == "ep":
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    elif args.plugin == "ep_zero":
        dp_size = dist.get_world_size()
        use_ep_inside = False
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            extra_dp_size=args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size // args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **mgr_dict,
        )
    elif args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            pp_size=args.pp_size,
            microbatch_size=args.microbatch_size,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=args.dp_size,
            fixed_ep_size=args.ep_size,
            fixed_pp_size=args.pp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    config = MixtralConfig.from_pretrained(pretrained_name)
    config.use_cache = False
    # NOTE(review): this overrides the expert count of the pretrained config
    # before the model is built and before replace_moe_layer runs; presumably
    # a memory-saving setup for the initial allocation - confirm.
    config.num_local_experts = 1
    model = MixtralForCausalLM(config)
    # Cast only for the half-precision modes. Previously every precision other
    # than bf16 (including fp32) was cast to fp16.
    half_dtypes = {"bf16": torch.bfloat16, "fp16": torch.float16}
    if args.precision in half_dtypes:
        model = model.to(half_dtypes[args.precision])
    model = model.to(get_current_device())
    replace_moe_layer(model)
    _report_memory("0-2", model)
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )
    torch.cuda.synchronize()
    _report_memory("1", model)

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
    torch.cuda.synchronize()
    _report_memory("2-1", model)
    load_ckpt(pretrained_name, model, booster)
    torch.cuda.synchronize()
    _report_memory("2", model)
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finish init booster")

    # Start finetuning
    coordinator.print_on_master("Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master(),
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass (the pipeline call receives the optimizer
                    # and pulls its own data from the iterator)
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # Only the last pipeline stage reads the loss
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        pbar.set_postfix({"loss": loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    _report_memory("3", model)
                    # Backward
                    booster.backward(loss, optimizer)
                    _report_memory("4", model)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                optimizer.zero_grad()

                # Apply load balance
                if (
                    args.load_balance
                    and args.load_balance_interval > 0
                    and (step + 1) % args.load_balance_interval == 0
                ):
                    coordinator.print_on_master("Apply load balance")
                    apply_load_balance(model, optimizer)

                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    booster.save_model(model, args.output_path, shard=True)

        # save checkpoint at the end of each epoch
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)

    # Finish training
    coordinator.print_on_master("Finish training")
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()