pull/5190/head
Xuanlei Zhao 2023-12-25 16:05:42 +08:00
parent 7c5b1a585f
commit aa2e091dc6
17 changed files with 407 additions and 1095 deletions

View File

@ -1,14 +1,4 @@
## OpenMoE
[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in JAX, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to use and build on it. The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates fine-tuning and inference.
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/MOE_training.png" width=800/>
</p>
* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/openmoe)
[[blog]](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
# Mixtral
## Usage
@ -23,116 +13,14 @@ CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
Then install dependencies.
```bash
cd ColossalAI/examples/language/openmoe
pip install -r requirements.txt
cd ColossalAI/applications/ColossalMoE
pip install -e .
```
Additionally, we recommend using torch 1.13.1. We have tested our code on torch 1.13.1 and found it compatible with our code and flash attention.
Additionally, we recommend using torch 1.13.1. We have tested our code on torch 1.13.1 and found it compatible with our code.
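For example, to pin that version with pip (a minimal sketch assuming a CUDA 11.7 build; swap the wheel index for your CUDA toolkit):
```bash
pip install torch==1.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
```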
### 2. Install kernels (Optional)
We use the `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
```bash
# install triton via pip
pip install triton
# install flash attention via pip
pip install flash-attn==2.0.5
# install apex from source
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 741bdf50825a97664db08574981962d66436d16a
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
```
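As an optional sanity check, the kernel packages can be imported directly (this assumes the packages above were installed; skip any you did not install):
```bash
# optional: verify the kernel packages are importable
python -c "import triton; print('triton', triton.__version__)"
python -c "import flash_attn; print('flash-attn', flash_attn.__version__)"
python -c "from apex.normalization import FusedLayerNorm; print('apex ok')"
```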
### 3. Train
You can use colossalai run to launch single-node training:
### 2. Inference
You can use colossalai run to launch inference:
```bash
colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
```
You can also use colossalai run to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```text
hostname1
hostname2
hostname3
hostname4
```
Each hostname refers to the IP address of one of your nodes. Make sure the master node can access all nodes (including itself) via passwordless SSH.
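If passwordless SSH is not configured yet, a minimal sketch using standard OpenSSH tooling (run on the master node; repeat `ssh-copy-id` for every hostname in your hostfile, including the master itself):
```bash
# generate a key pair once (skip if ~/.ssh/id_rsa already exists)
ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa
# copy the public key to a node from the hostfile
ssh-copy-id hostname1
```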
Here are the details of the CLI arguments; an example command that combines several of them is shown after the list:
- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the lowest memory consumption, and `hybrid` suits large-scale training.
- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
- Number of epochs: `--num_epochs`. The default value is 1.
- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
- Max length: `--max_length`. Max sequence length. Default to 2048.
- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. Any dataset from `datasets` with the same data format is supported.
- Task Name: `--task_name`. Task of corresponding dataset. Default to `super_natural_instructions`.
- Learning rate: `--lr`. The default value is 1e-5.
- Weight decay: `--weight_decay`. The default value is 0.
- Zero stage: `--zero_stage`. ZeRO stage. 2 is recommended for `ep` and 1 for `ep_zero`.
- Extra dp size: `--extra_dp_size`. Extra moe param dp size for ep_zero plugin. Recommended to be 2 or 4.
- Use kernel: `--use_kernel`. Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.
- Use layernorm kernel: `--use_layernorm_kernel`. Use layernorm kernel. Need to install apex. Raise error if not installed.
- Router aux loss factor: `--router_aux_loss_factor`. MoE router auxiliary loss factor. You can refer to STMoE for details.
- Router z loss factor: `--router_z_loss_factor`. MoE router z-loss factor. You can refer to STMoE for details.
- Label smoothing: `--label_smoothing`. Label smoothing.
- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
- Load balance: `--load_balance`. Expert load balancing. Defaults to False. Enabling it is recommended.
- Load balance interval: `--load_balance_interval`. Expert load balance interval.
- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
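For illustration, a single-node example that combines several of the arguments above (the values are placeholders, not a tuned recipe; adjust them to your hardware and dataset):
```bash
# illustrative only: 8 GPUs, ep_zero plugin, bf16, load balancing and communication overlap enabled
colossalai run --standalone --nproc_per_node 8 train.py \
    --model_name 8b \
    --plugin ep_zero \
    --zero_stage 1 \
    --extra_dp_size 2 \
    --batch_size 1 \
    --max_length 2048 \
    --lr 1e-5 \
    --precision bf16 \
    --load_balance \
    --comm_overlap
```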
### 4. Shell Script Examples
For your convenience, we provide some shell scripts to train with various configurations. Here we show an example of how to run OpenMoE training.
#### a. Running environment
This experiment was performed on a single computing node with 8 A800 80GB GPUs for OpenMoE-8B. The GPUs are fully connected via NVLink.
#### b. Running command
We demonstrate how to run the three plugins in `train.sh`. You can choose any one of them and use your own arguments.
```bash
bash train.sh
```
#### c. Multi-Nodes Training
To run on multiple nodes, you can modify the script as:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
train.py --OTHER_CONFIGURATIONS
```
## Reference
```
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
```
```bibtex
@misc{openmoe2023,
author = {Fuzhao Xue and Zian Zheng and Yao Fu and Jinjie Ni and Zangwei Zheng and Wangchunshu Zhou and Yang You},
title = {OpenMoE: Open Mixture-of-Experts Language Models},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
}
```
bash infer.sh
```

View File

@ -1,297 +0,0 @@
import argparse
import json
import os
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
class RandomDataset(Dataset):
def __init__(
self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
):
self.num_samples = num_samples
self.max_length = max_length
if os.path.exists("./mock_data.json"):
self.input_ids = []
self.attention_mask = []
with open("./mock_data.json", "r") as f:
data = json.load(f)
for v in data.values():
d = v["text"]
encode = tokenizer(
"<pad>" + d,
return_tensors="pt",
add_special_tokens=False,
max_length=max_length,
truncation=True,
padding="max_length",
)
self.input_ids.append(encode["input_ids"])
self.attention_mask.append(encode["attention_mask"])
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
repeat_times = num_samples // self.input_ids.shape[0] + 1
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
else:
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--seq_length",
type=int,
default=2048,
help="sequence length for the training dataloader.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
help="parallel plugin",
)
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size")
parser.add_argument("--dp_size", type=int, default=1, help="dp size")
parser.add_argument("--ep_size", type=int, default=2, help="ep size")
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
parser.add_argument("--extra_dp_size", type=int, default=1)
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
)
# bench
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
# load balance
parser.add_argument("--load_balance", action="store_true")
# overlap communication
parser.add_argument("--overlap_comm", action="store_true")
# hierarchical all-to-all
parser.add_argument("--hierarchical_alltoall", action="store_true")
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_kernel,
"enable_jit_fused": args.use_kernel,
"precision": "bf16",
"zero_stage": args.zero_stage,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
zero_stage=args.zero_stage,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin}")
# Build OpenMoe model
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=args.load_balance,
enable_kernel=args.use_kernel,
enable_comm_overlap=args.overlap_comm,
enable_hierarchical_alltoall=args.hierarchical_alltoall,
)
with skip_init():
model = OpenMoeForCausalLM(config)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
max_length=args.seq_length,
tokenizer=tokenizer,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
# Set optimizer
optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dp_size,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
load_ckpt(repo_name, model, booster)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Start finetuning
coordinator.print_on_master(f"Start training")
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter) - 1
example_data = next(train_dataloader_iter)
with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
for step in pbar:
performance_evaluator.on_step_start(step)
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(example_data["input_ids"])
if (step == args.warmup // 2) and args.load_balance:
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
performance_evaluator.on_fit_end()
if __name__ == "__main__":
main()

View File

@ -1,78 +0,0 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 8 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance
echo -e "\n\n EP-ZERO + Overlap \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall
# hybrid
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 128 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--use_kernel \
--plugin hybrid \
--pp_size 2 \
--dp_size 1 \
--ep_size 4 \
--zero_stage 1 \
--microbatch_size 32

View File

@ -1,57 +0,0 @@
#!/bin/bash
set -xue
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCH_DISTRIBUTED_DETAIL=DEBUG
export GLOO_SOCKET_IFNAME=eth0
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 12 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 20 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall

View File

@ -1,139 +0,0 @@
import argparse
import functools
import os
import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
from colossalai.moe.manager import MOE_MANAGER
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def fsdp_main(rank, world_size, args):
# initialize the process group
dist.init_process_group("nccl")
MOE_MANAGER.setup(parallel=None)
dp_size = dist.get_world_size()
dataset = RandomDataset(
max_length=args.seq_length,
num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
)
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
torch.cuda.set_device(rank)
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=False,
enable_kernel=False,
enable_comm_overlap=False,
)
torch.set_default_dtype(torch.float16)
model = OpenMoeForCausalLM(config)
torch.set_default_dtype(torch.float32)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
OpenMoeDecoderLayer,
},
)
model = FSDP(
model,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
model.train()
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dist.get_world_size(),
)
for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
performance_evaluator.on_step_start(step)
input_ids, attention_mask, labels = (
data["input_ids"].cuda(),
data["attention_mask"].cuda(),
data["labels"].cuda(),
)
optimizer.zero_grad()
output = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
chunk_head=False,
)
loss = output["loss"]
loss.backward()
optimizer.step()
performance_evaluator.on_step_end(input_ids)
performance_evaluator.on_fit_end()
if dist.get_rank() == 0:
print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="base or 8b",
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--seq_length", type=int, default=2048)
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
args = parser.parse_args()
torch.manual_seed(42)
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
fsdp_main(local_rank, world_size, args)

View File

@ -1,44 +0,0 @@
#!/bin/bash
set -xue
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCH_DISTRIBUTED_DETAIL=DEBUG
export GLOO_SOCKET_IFNAME=eth0
MODEL="8b"
BATCH_SIZE=1
SEQ_LENGTH=2048
WARMUP=8
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# single node
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE
# multi node
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
$example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE

View File

@ -1,126 +0,0 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from colossalai.logging import DistributedLogger
def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
B = 1024**3
M = 1024**2
K = 1024
outputs = "Model param count: "
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
if model_param >= B:
outputs += f"{model_param / B:.2f} B\n"
elif model_param >= M:
outputs += f"{model_param / M:.2f} M\n"
elif model_param >= K:
outputs += f"{model_param / K:.2f} K\n"
else:
outputs += f"{model_param}\n"
logger.info(outputs, ranks=[0])
def get_model_numel(model: nn.Module) -> None:
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
return model_param
def divide(x: float, y: float) -> float:
if y == 0:
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
assert self.start_time is not None
self.duration += time() - self.start_time
self.start_time = None
def reset(self) -> None:
self.duration = 0.0
class PerformanceEvaluator:
"""
Callback for evaluating the performance of the model.
Args:
model_numel: The number of parameters of the model.
enable_grad_checkpoint: Whether gradient checkpointing is enabled.
ignore_steps: The number of warmup steps to ignore when measuring performance.
dp_world_size: The data-parallel world size.
"""
def __init__(
self,
model_numel: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None,
) -> None:
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
self.dp_world_size = dp_world_size
self.world_size = dist.get_world_size()
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
torch.cuda.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
torch.cuda.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
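# rough model FLOP count: ~2 * num_params FLOPs per token for a forward pass;
# the factor 3 covers forward + backward (2N + 4N), and gradient checkpointing
# adds roughly one extra forward pass (+2N)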
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.world_size // self.dp_world_size
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
if dist.get_rank() == 0:
print(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}"
)
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")

View File

@ -4,7 +4,46 @@ from pathlib import Path
import torch.distributed as dist
import torch.nn as nn
import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import OptimizerWrapper
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import (
get_dp_group,
get_dp_rank,
get_dp_size,
get_ep_group,
get_ep_rank,
get_ep_size,
is_moe_tensor,
)
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
@ -15,39 +54,51 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.no_grad()
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
model_param_dict = dict(model.named_parameters())
for name, param in list(state_dict.items()):
if ".experts." in name:
if ".experts.gate.weight" in name:
new_name = name.replace(".experts.gate.weight", ".experts.gate_weight")
if ".gate.weight" in name:
new_name = "module." + name.replace(".gate.weight", ".gate_weight")
state_dict[new_name] = state_dict.pop(name)
else:
str_idx = name.index(".experts.")
int(name.split(".")[-3])
if ".w1." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
model_param = model_param_dict[model_param_name]
assert is_moe_tensor(model_param)
elif ".experts." in name:
# If this is an MoE tensor: in our MoE module all experts are concatenated
# into one tensor, while Mixtral stores each expert separately.
# We insert the loaded expert weight into its slot of the concatenated tensor.
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = 8 // ep_size
range(ep_rank * expert_num, (ep_rank + 1) * expert_num)
state_dict[name] = param
# get model param
str_idx = name.index(".experts.")
expert_idx = int(name.split(".")[-3])
if ".w1." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
model_param_name = "module." + model_param_name
model_param = model_param_dict[model_param_name]
assert is_moe_tensor(model_param)
# get expert range
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = 8 // ep_size
expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
# insert new param
if expert_idx in expert_range:
new_param = model_param
new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
state_dict[model_param_name] = new_param
state_dict.pop(name)
else:
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
for name, param in list(state_dict.items()):
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
assert new_name in model_param_dict, f"{new_name} not in model"
assert name in model_param_dict, f"{name} not in model. model param dict: {model_param_dict.keys()}"
dist.barrier()
return state_dict
@ -124,3 +175,53 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@torch.no_grad()
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
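# Convert back to the HuggingFace layout: the Colossal-AI module stores all local
# experts concatenated in wi_gate / wi_up / wo, so we gather the shards across the
# EP group, split them into individual experts, rename them to w1 / w3 / w2, and
# finally drop the "module." prefix.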
for name, param in model.named_parameters():
if ".experts." in name:
if ".experts.gate_weight" in name:
new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
state_dict[new_name] = state_dict.pop(name)
elif ".experts." in name and is_moe_tensor(param):
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
assert all_param.shape[0] == 8
for i in range(8):
if ".wi_gate" in name:
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
elif ".wi_up" in name:
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
elif ".wo" in name:
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
new_name = new_name.replace("module.", "")
new_param = all_param[i].transpose(-1, -2)
state_dict[new_name] = new_param.cpu()
state_dict.pop(name)
for name, param in list(state_dict.items()):
new_name = name.replace("module.", "")
state_dict[new_name] = state_dict.pop(name)
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, state_dict, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralDecoderLayer
from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP
@ -39,7 +39,7 @@ class MixtralSparseMLP:
# get the attributes of the module
moe_kwargs = dict(
num_experts=module.num_experts,
num_experts=8,
hidden_size=module.hidden_dim,
intermediate_size=module.ffn_dim,
router_top_k=module.top_k,
@ -62,53 +62,18 @@ class MixtralSparseMLP:
device = module.gate.weight.device
sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
# cat all experts
w1 = None
w2 = None
w3 = None
for i in module.experts:
# origin
wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0)
wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0)
wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0)
# cat
w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0)
w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0)
w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0)
# get local experts
if is_moe_tensor(sparse_mlp.experts.wi_gate):
ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate)
expert_num = sparse_mlp.experts.wi_gate.shape[0]
expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num)
else:
expert_slice = slice(None)
w1 = w1[expert_slice].clone().detach()
w2 = w2[expert_slice].clone().detach()
w3 = w3[expert_slice].clone().detach()
assert (
w1.shape == sparse_mlp.experts.wi_gate.shape
), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}"
assert (
w2.shape == sparse_mlp.experts.wo.shape
), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}"
assert (
w3.shape == sparse_mlp.experts.wi_up.shape
), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}"
# assign the new params to the colossal moe module
sparse_mlp.experts.wi_gate.data = w1
sparse_mlp.experts.wi_up.data = w3
sparse_mlp.experts.wo.data = w2
sparse_mlp.gate_weight = module.gate.weight
# TODO: fix
# The old weights are still referenced somewhere, so we cannot delete them.
# Repoint the old weights' data so the original memory can be released.
# The old references will not be used again, so any pointer is fine.
for i in module.experts:
i.w1.weight.data = w1
i.w2.weight.data = w2
i.w3.weight.data = w3
return sparse_mlp
def replace_moe_layer(model: nn.Module) -> nn.Module:
"""
Recursively replace each MixtralDecoderLayer's block_sparse_moe with a Colossal-AI SparseMLP.
Args:
model (torch.nn.Module): The model (or submodule) whose MoE layers should be replaced.
"""
if isinstance(model, MixtralDecoderLayer):
model.block_sparse_moe = MixtralSparseMLP.from_native_module(model.block_sparse_moe)
else:
for _, child in model.named_children():
replace_moe_layer(child)

View File

@ -48,18 +48,6 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
# use colossal moe module
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=MixtralSparseMLP,
),
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(

View File

@ -1,30 +1,184 @@
from argparse import ArgumentParser
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.models.mixtral_layer import replace_moe_layer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.utils import get_current_device
import argparse
import os
from functools import partial
from typing import Dict
import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(device)
return batch
def load_ckpt(repo_name: str, model, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# shard ckpt
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
def parse_args():
parser = ArgumentParser()
parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
return parser.parse_args()
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="8x7b",
choices=["8x7b"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["ep"],
help="Parallel methos.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument(
"--seed", type=int, default=42, help="A seed for reproducible training."
)
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
args = parser.parse_args()
return args
def inference(args):
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
def main():
args = parse_args()
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
model = model.eval().bfloat16()
print(f"param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB")
model = model.to(torch.cuda.current_device())
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt")
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"checkpoint_io": MixtralMoECheckpointIO,
"zero_stage": 1,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Build mixtral model
model_name = "mistralai/Mixtral-8x7B-v0.1"
config = MixtralConfig.from_pretrained(model_name)
config.num_local_experts = 1  # don't change this; it does not affect the model
with skip_init():
model = MixtralForCausalLM(config)
model = (
model.to(torch.bfloat16)
if args.precision == "bf16"
else model.to(torch.float16)
)
model = model.to(get_current_device())
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Replace moe
with skip_init():
replace_moe_layer(model)
model.eval()
coordinator.print_on_master(f"Finish replace moe module")
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, _, _, _, _ = booster.boost(model=model)
coordinator.print_on_master(f"Finish init booster")
# load ckpt
load_ckpt(model_name, model, booster)
coordinator.print_on_master(f"Finish load ckpt")
text = ["Hello my name is"]
inputs = tokenizer(text, return_tensors="pt").to(torch.cuda.current_device())
outputs = model.module.generate(**inputs, max_new_tokens=20)
outputs = tokenizer.batch_decode(outputs)[0]
print(outputs)
if __name__ == "__main__":
args = parse_args()
inference(args)
main()

View File

@ -1 +1,7 @@
python infer.py --model "base"
NUM_GPU=2
MODEL="8x7b"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
--model_name $MODEL \
--plugin "ep" \

View File

@ -1,37 +0,0 @@
pip install -r requirements.txt
# inference
python infer.py --model "test"
# train
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep" \
--batch_size 1
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 1 \
--extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 2 \
--extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--model_name "test" \
--plugin "hybrid" \
--num_epoch 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 2 \
--zero_stage 1 \
--batch_size 1

View File

@ -14,6 +14,9 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
sys.path.append(
os.path.join(
@ -22,10 +25,6 @@ sys.path.append(
)
)
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
@ -71,65 +70,55 @@ def run_fwd_bwd(
def get_config():
config = LlamaConfig(
config = MixtralConfig(
vocab_size=300,
hidden_size=16,
intermediate_size=32,
hidden_size=32,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
head_dim=4,
dropout_rate=0.0,
hidden_act="swiglu",
)
set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
return config
def get_model(parallel):
config = get_config()
model = OpenMoeForCausalLM(config)
model = MixtralForCausalLM(config).to(torch.bfloat16)
optim = torch.optim.Adam(model.parameters())
args = dict(
precision="bf16",
tp_size=1,
zero_stage=1,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoECheckpointIO,
)
if parallel == None:
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
**args,
)
elif parallel == "ep":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
**args,
)
elif parallel == "ep_zero":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
**args,
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
**args,
)
booster = Booster(plugin=plugin)
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
return model, booster, optim
def _test_moe_checkpoint(rank, parallel):
def _test_moe_checkpoint(parallel):
if parallel == None:
MOE_MANAGER.setup(
parallel=None,
@ -153,18 +142,12 @@ def _test_moe_checkpoint(rank, parallel):
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
# param ckpt
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# unshard
booster1.save_model(model1, "./tmp_ckpt1.pth")
booster3.load_model(model3, "./tmp_ckpt1.pth")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
# optim ckpt
criterion = lambda x: x.mean()
@ -181,18 +164,12 @@ def _test_moe_checkpoint(rank, parallel):
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2")
# unshard
booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
# check
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
if dist.get_rank() == 0:
shutil.rmtree("./tmp_ckpt1")
shutil.rmtree("./tmp_ckpt2")
os.remove("./tmp_ckpt1.pth")
os.remove("./tmp_ckpt2.pth")
def _run_dist(rank, world_size, port, parallel):
@ -204,16 +181,16 @@ def _run_dist(rank, world_size, port, parallel):
port=port,
backend="nccl",
)
_test_moe_checkpoint(rank, parallel)
_test_moe_checkpoint(parallel)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@pytest.mark.parametrize("parallel", ["ep", "ep_zero"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid")
test_moe_checkpoint(world_size=4, parallel="ep")

View File

@ -1,38 +0,0 @@
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001
# ep zero
torchrun --standalone --nproc_per_node $NUM_GPU train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "ep_zero" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--extra_dp_size 2
# ep
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "ep_zero" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1
# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "hybrid" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1 \
# --pp_size 2 \
# --dp_size 1 \
# --ep_size 2 \

View File

@ -4,6 +4,7 @@ import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.models.mixtral_layer import replace_moe_layer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@ -15,13 +16,47 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.nn.optimizer import HybridAdam
# from colossalai.nn.optimizer import HybridAdam
from torch.optim import Adam as HybridAdam
from colossalai.utils import get_current_device
import argparse
import os
from functools import partial
from typing import Dict
import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
@ -240,26 +275,30 @@ def main():
# num_key_value_heads=4,
# use_cache=False,
# )
config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config.use_cache = False
init_ctx = LazyInitContext(default_device=get_current_device())
with init_ctx:
model = MixtralForCausalLM.from_pretrained(
"/home/lczxl/.cache/huggingface/hub/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/f1ca00645f0b1565c7f9a1c863d2be6ebf896b04",
config=config,
).bfloat16()
config.num_local_experts = 1
# torch.set_default_tensor_type(torch.float16)
model = MixtralForCausalLM(config)
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
replace_moe_layer(model)
# torch.set_default_tensor_type(torch.float32)
print(f"0-2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
collate_fn = None
dataloader = plugin.prepare_dataloader(
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
)
torch.cuda.synchronize()
print(f"1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
@ -267,6 +306,12 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
torch.cuda.synchronize()
print(f"2-1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
load_ckpt("mistralai/Mixtral-8x7B-v0.1", model, booster)
torch.cuda.synchronize()
print(f"2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
@ -303,8 +348,12 @@ def main():
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
print(f"3 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
# Backward
booster.backward(loss, optimizer)
print(f"4 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
pbar.set_postfix({"loss": loss.item()})
optimizer.step()

View File

@ -13,7 +13,7 @@ LR=0.00001
# --plugin "ep_zero" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1 \
# --zero_stage 2 \
# --extra_dp_size 2
# ep