mirror of https://github.com/hpcaitech/ColossalAI
update
parent 7c5b1a585f
commit aa2e091dc6
@@ -1,14 +1,4 @@
-## OpenMoE
+# Mixtral
 
-[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in JAX, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use this model. The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates finetuning and inference methods.
-
-<p align="center">
-<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/MOE_training.png" width=800/>
-</p>
-
-* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/openmoe)
-[[blog]](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
 
 ## Usage
 
@@ -23,116 +13,14 @@ CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
 Then install dependencies.
 
 ```bash
-cd ColossalAI/examples/language/openmoe
-pip install -r requirements.txt
+cd ColossalAI/applications/ColossalMoE
+pip install -e .
 ```
 
-Additionally, we recommend you use torch 1.13.1. We've tested our code on torch 1.13.1 and found it compatible with our code and flash attention.
+Additionally, we recommend you use torch 1.13.1. We've tested our code on torch 1.13.1 and found it compatible with our code.
 
-### 2. Install kernels (Optional)
+### 2. Inference
+You can use `colossalai run` to launch inference:
-We have utilized `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not necessary, but we recommend installing them to fully utilize your hardware.
-```
-# install triton via pip
-pip install triton
-
-# install flash attention via pip
-pip install flash-attn==2.0.5
-
-# install apex from source
-git clone https://github.com/NVIDIA/apex.git
-cd apex
-git checkout 741bdf50825a97664db08574981962d66436d16a
-pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
-```
-
-### 3. Train
-You can use `colossalai run` to launch single-node training:
 ```bash
-colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
+bash infer.sh
 ```
-You can also use `colossalai run` to launch multi-node training:
-```bash
-colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
-```
-
-Here is a sample hostfile:
-
-```text
-hostname1
-hostname2
-hostname3
-hostname4
-```
-
-Each hostname is the IP address of one of your nodes. Make sure the master node can access all nodes (including itself) via SSH without a password.
-
-Here are the details of the CLI arguments:
-
-- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
-- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the least memory consumption, and `hybrid` suits large-scale training.
-- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
-- Number of epochs: `--num_epochs`. The default value is 1.
-- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
-- Save interval: `-i`, `--save_interval`. The interval (in steps) for saving checkpoints. The default value is 1000.
-- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
-- Max length: `--max_length`. Max sequence length. Defaults to 2048.
-- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. Any dataset from `datasets` with the same data format is supported.
-- Task name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
-- Learning rate: `--lr`. The default value is 1e-5.
-- Weight decay: `--weight_decay`. The default value is 0.
-- Zero stage: `--zero_stage`. ZeRO stage. 2 is recommended for `ep` and 1 for `ep_zero`.
-- Extra dp size: `--extra_dp_size`. Extra data-parallel size for MoE parameters with the `ep_zero` plugin. Recommended to be 2 or 4.
-- Use kernel: `--use_kernel`. Use kernel optimizations. Flash attention and triton need to be installed to enable all kernel optimizations; skipped if not installed.
-- Use layernorm kernel: `--use_layernorm_kernel`. Use the layernorm kernel. Requires apex; raises an error if not installed.
-- Router aux loss factor: `--router_aux_loss_factor`. MoE router auxiliary loss factor. You can refer to ST-MoE for details.
-- Router z loss factor: `--router_z_loss_factor`. MoE router z-loss factor. You can refer to ST-MoE for details.
-- Label smoothing: `--label_smoothing`. Label smoothing.
-- Z loss factor: `--z_loss_factor`. The z-loss factor on the final output classification.
-- Load balance: `--load_balance`. Expert load balancing. Defaults to False. Recommended to enable.
-- Load balance interval: `--load_balance_interval`. Expert load balancing interval.
-- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
-
-### 4. Shell Script Examples
-
-For your convenience, we provide some shell scripts for training with various configurations. Here we show an example of how to run training for OpenMoE.
-
-#### a. Running environment
-This experiment was performed on a single computing node with 8 A800 80GB GPUs in total for OpenMoE-8B. The GPUs are fully connected with NVLink.
-
-#### b. Running command
-We demonstrate how to run the three plugins in `train.sh`. You can choose any one of them and use your own args.
-
-```bash
-bash train.sh
-```
-
-#### c. Multi-Node Training
-
-To run on multiple nodes, you can modify the script as:
-```bash
-colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
-    train.py --OTHER_CONFIGURATIONS
-```
-
-## Reference
-```
-@article{bian2021colossal,
-  title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
-  author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
-  journal={arXiv preprint arXiv:2110.14883},
-  year={2021}
-}
-```
-
-```bibtex
-@misc{openmoe2023,
-  author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
-  title = {OpenMoE: Open Mixture-of-Experts Language Models},
-  year = {2023},
-  publisher = {GitHub},
-  journal = {GitHub repository},
-  howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
-}
-```
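For orientation, the three booster plugins described in the removed OpenMoE README above map onto ColossalAI objects roughly as in the sketch below. It is condensed from the removed benchmark script further down in this diff; the concrete values (for example `extra_dp_size=2`) are illustrative, not defaults.

```python
# Rough sketch (not the shipped code): how the --plugin choices translate into
# plugin construction and MOE_MANAGER setup, condensed from the removed
# OpenMoE benchmark script later in this diff. Values are illustrative.
import torch.distributed as dist

from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER


def build_plugin(name: str, extra_dp_size: int = 2, pp_size: int = 2, ep_size: int = 2, dp_size: int = 1):
    world_size = dist.get_world_size()
    if name == "ep":
        # pure expert parallelism: experts are sharded over all ranks
        plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1)
        MOE_MANAGER.setup(parallel="EP", max_ep_size=world_size)
    elif name == "ep_zero":
        # expert parallelism plus an extra data-parallel group for MoE params (used with ZeRO-1)
        plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, extra_dp_size=extra_dp_size, use_ep_inside=False)
        MOE_MANAGER.setup(parallel="EP", max_ep_size=world_size // extra_dp_size, use_ep_inside=False)
    elif name == "hybrid":
        # pipeline parallelism combined with fixed-size EP/DP groups
        plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=pp_size, microbatch_size=1)
        MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=dp_size, fixed_ep_size=ep_size, fixed_pp_size=pp_size)
    else:
        raise ValueError(f"Invalid plugin {name}")
    return plugin
```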
@@ -1,297 +0,0 @@
import argparse
import json
import os

import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
    ckpt_path = snapshot_download(repo_name)
    # single ckpt
    if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
    # shard ckpt
    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
    booster.load_model(model, ckpt_path)


class RandomDataset(Dataset):
    def __init__(
        self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
    ):
        self.num_samples = num_samples
        self.max_length = max_length
        if os.path.exists("./mock_data.json"):
            self.input_ids = []
            self.attention_mask = []
            with open("./mock_data.json", "r") as f:
                data = json.load(f)
            for v in data.values():
                d = v["text"]
                encode = tokenizer(
                    "<pad>" + d,
                    return_tensors="pt",
                    add_special_tokens=False,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                )
                self.input_ids.append(encode["input_ids"])
                self.attention_mask.append(encode["attention_mask"])
            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
            repeat_times = num_samples // self.input_ids.shape[0] + 1
            self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
            self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
        else:
            self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
            self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b"],
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=2048,
        help="sequence length for the training dataloader.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        help="parallel plugin",
    )
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size")
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
    parser.add_argument("--extra_dp_size", type=int, default=1)
    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
    )
    # bench
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--active", type=int, default=20)
    # load balance
    parser.add_argument("--load_balance", action="store_true")
    # overlap communication
    parser.add_argument("--overlap_comm", action="store_true")
    # hierarchical all-to-all
    parser.add_argument("--hierarchical_alltoall", action="store_true")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": OpenMoeForCausalLMPolicy(),
        "enable_fused_normalization": args.use_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": "bf16",
        "zero_stage": args.zero_stage,
    }
    mgr_dict = {}
    if args.plugin == "ep":
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    elif args.plugin == "ep_zero":
        dp_size = dist.get_world_size()
        use_ep_inside = False
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            extra_dp_size=args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size // args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **mgr_dict,
        )
    elif args.plugin == "hybrid":
        dp_size = dist.get_world_size() // args.pp_size
        plugin = MoeHybridParallelPlugin(
            pp_size=args.pp_size,
            zero_stage=args.zero_stage,
            microbatch_size=args.microbatch_size,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=args.dp_size,
            fixed_ep_size=args.ep_size,
            fixed_pp_size=args.pp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin}")

    # Build OpenMoe model
    repo_name = "hpcaitech/openmoe-" + args.model_name
    config = LlamaConfig.from_pretrained(repo_name)
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=args.load_balance,
        enable_kernel=args.use_kernel,
        enable_comm_overlap=args.overlap_comm,
        enable_hierarchical_alltoall=args.hierarchical_alltoall,
    )
    with skip_init():
        model = OpenMoeForCausalLM(config)
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
    dataset = RandomDataset(
        num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
        max_length=args.seq_length,
        tokenizer=tokenizer,
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)

    model_numel = get_model_numel(model)
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        enable_grad_checkpoint=True,
        ignore_steps=args.warmup,
        dp_world_size=dp_size,
    )

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    load_ckpt(repo_name, model, booster)
    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finish init booster")

    # Start finetuning
    coordinator.print_on_master("Start training")
    model.train()
    train_dataloader_iter = iter(dataloader)
    total_len = len(train_dataloader_iter) - 1
    example_data = next(train_dataloader_iter)
    with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
        for step in pbar:
            performance_evaluator.on_step_start(step)
            if use_pipeline:
                # Forward pass
                outputs = booster.execute_pipeline(
                    train_dataloader_iter,
                    model,
                    lambda x, y: x.loss,
                    optimizer,
                    return_loss=True,
                    return_outputs=True,
                )
                # Backward and optimize
                if is_pp_last_stage:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                # Forward pass
                data = next(train_dataloader_iter)
                data = move_to_cuda(data, torch.cuda.current_device())
                outputs = model(**data)
                loss = outputs["loss"]
                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(example_data["input_ids"])
            if (step == args.warmup // 2) and args.load_balance:
                coordinator.print_on_master("Apply load balance")
                apply_load_balance(model, optimizer)
    performance_evaluator.on_fit_end()


if __name__ == "__main__":
    main()
@@ -1,78 +0,0 @@
#!/bin/bash

set -xue

NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4

# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
    export PYTHONPATH=$example_dir
else
    export PYTHONPATH=$example_dir:$PYTHONPATH
fi


# ep
echo -e "\n\n Naive EP \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
    --batch_size 8 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --plugin ep \
    --zero_stage 2


# ep_zero
echo -e "\n\n EP-ZERO \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
    --batch_size 16 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --plugin ep_zero \
    --use_kernel \
    --extra_dp_size 2 \
    --zero_stage 1 \
    --load_balance

echo -e "\n\n EP-ZERO + Overlap \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
    --batch_size 16 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --plugin ep_zero \
    --use_kernel \
    --extra_dp_size 2 \
    --zero_stage 1 \
    --load_balance \
    --overlap_alltoall


# hybrid
torchrun --standalone --nproc_per_node $NUM_GPU \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
    --batch_size 128 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --use_kernel \
    --plugin hybrid \
    --pp_size 2 \
    --dp_size 1 \
    --ep_size 4 \
    --zero_stage 1 \
    --microbatch_size 32
@@ -1,57 +0,0 @@
#!/bin/bash

set -xue

export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCH_DISTRIBUTED_DETAIL=DEBUG
export GLOO_SOCKET_IFNAME=eth0

NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4

# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
    export PYTHONPATH=$example_dir
else
    export PYTHONPATH=$example_dir:$PYTHONPATH
fi


# ep
echo -e "\n\n Naive EP \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
    --batch_size 12 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --plugin ep \
    --zero_stage 2


# ep_zero
echo -e "\n\n EP-ZERO \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
    --batch_size 20 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --plugin ep_zero \
    --use_kernel \
    --extra_dp_size 2 \
    --zero_stage 1 \
    --load_balance \
    --overlap_alltoall
@@ -1,139 +0,0 @@
import argparse
import functools
import os

import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel

from colossalai.moe.manager import MOE_MANAGER


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def fsdp_main(rank, world_size, args):
    # initialize the process group
    dist.init_process_group("nccl")

    MOE_MANAGER.setup(parallel=None)

    dp_size = dist.get_world_size()
    dataset = RandomDataset(
        max_length=args.seq_length,
        num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
    )
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
    train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
    torch.cuda.set_device(rank)

    config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=False,
        enable_kernel=False,
        enable_comm_overlap=False,
    )
    torch.set_default_dtype(torch.float16)
    model = OpenMoeForCausalLM(config)
    torch.set_default_dtype(torch.float32)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            OpenMoeDecoderLayer,
        },
    )
    model = FSDP(
        model,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
    model.train()

    model_numel = get_model_numel(model)
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        enable_grad_checkpoint=True,
        ignore_steps=args.warmup,
        dp_world_size=dist.get_world_size(),
    )

    for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
        performance_evaluator.on_step_start(step)
        input_ids, attention_mask, labels = (
            data["input_ids"].cuda(),
            data["attention_mask"].cuda(),
            data["labels"].cuda(),
        )

        optimizer.zero_grad()
        output = model(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            chunk_head=False,
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        performance_evaluator.on_step_end(input_ids)

    performance_evaluator.on_fit_end()
    if dist.get_rank() == 0:
        print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b"],
        help="base or 8b",
    )
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seq_length", type=int, default=2048)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--active", type=int, default=20)
    args = parser.parse_args()

    torch.manual_seed(42)

    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    fsdp_main(local_rank, world_size, args)
@@ -1,44 +0,0 @@
#!/bin/bash

set -xue

export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCH_DISTRIBUTED_DETAIL=DEBUG
export GLOO_SOCKET_IFNAME=eth0

MODEL="8b"
BATCH_SIZE=1
SEQ_LENGTH=2048
WARMUP=8
ACTIVE=4

# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
    export PYTHONPATH=$example_dir
else
    export PYTHONPATH=$example_dir:$PYTHONPATH
fi

# single node
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
    --model_name $MODEL \
    --batch_size $BATCH_SIZE \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE

# multi node
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
    $example_dir/benchmark/benchmark_fsdp.py \
    --model_name $MODEL \
    --batch_size $BATCH_SIZE \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE
@@ -1,126 +0,0 @@
from time import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor

from colossalai.logging import DistributedLogger


def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = "Model param count: "
    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if model_param >= B:
        outputs += f"{model_param / B:.2f} B\n"
    elif model_param >= M:
        outputs += f"{model_param / M:.2f} M\n"
    elif model_param >= K:
        outputs += f"{model_param / K:.2f} K\n"
    else:
        outputs += f"{model_param}\n"
    logger.info(outputs, ranks=[0])


def get_model_numel(model: nn.Module) -> None:
    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return model_param


def divide(x: float, y: float) -> float:
    if y == 0:
        return float("inf")
    elif y == float("inf"):
        return float("nan")
    return x / y


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    if world_size == 1:
        return x
    tensor = torch.tensor([x], device=torch.cuda.current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()


class Timer:
    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:
        assert self.start_time is not None
        self.duration += time() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.0


class PerformanceEvaluator:
    """
    Callback for evaluating the performance of the model.
    Args:
        actor_num_params: The number of parameters of the actor model.
        critic_num_params: The number of parameters of the critic model.
        initial_model_num_params: The number of parameters of the initial model.
        reward_model_num_params: The number of parameters of the reward model.
        enable_grad_checkpoint: Whether to enable gradient checkpointing.
        ignore_episodes: The number of episodes to ignore when calculating the performance.
    """

    def __init__(
        self,
        model_numel: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
    ) -> None:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps
        self.dp_world_size = dp_world_size
        self.world_size = dist.get_world_size()
        self.disable: bool = False
        self.timer = Timer()
        self.num_samples: int = 0
        self.flop: int = 0

    def on_step_start(self, step: int) -> None:
        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
        if self.disable:
            return
        torch.cuda.synchronize()
        self.timer.start()

    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
        if self.disable:
            return
        torch.cuda.synchronize()
        self.timer.end()

        batch_size, seq_len = input_ids.shape

        self.num_samples += batch_size
        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))

    def on_fit_end(self) -> None:
        avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
        mp_world_size = self.world_size // self.dp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        if dist.get_rank() == 0:
            print(
                f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
                f"avg_throughput: {avg_throughput}"
            )
            print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
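The removed `PerformanceEvaluator` above estimates work as `batch_size * seq_len * model_numel * 2 * (3 + grad_ckpt)`: roughly 2 FLOPs per parameter per token for the forward pass, three such passes for forward plus backward, and one extra forward when gradient checkpointing recomputes activations. A worked example of its throughput math, using made-up numbers:

```python
# Worked example of the throughput formula used by PerformanceEvaluator above.
# All numbers here are illustrative, not measurements.
model_numel = 8e9                     # parameters
batch_size, seq_len = 16, 2048
enable_grad_checkpoint = True

flop = batch_size * seq_len * model_numel * 2 * (3 + int(enable_grad_checkpoint))
avg_duration = 12.0                   # seconds spent in the measured steps (made up)
world_size, dp_world_size = 8, 2
mp_world_size = world_size // dp_world_size

tflops_per_gpu = flop / 1e12 / avg_duration / mp_world_size
print(f"{tflops_per_gpu:.1f} TFLOPS per GPU")   # ~43.7 with these numbers
```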
@@ -4,7 +4,46 @@ from pathlib import Path
 import torch.distributed as dist
 import torch.nn as nn
+import copy
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
+from colossalai.checkpoint_io.utils import (
+    StateDictSharder,
+    gather_distributed_param,
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
+    is_safetensors_available,
+    load_shard_state_dict,
+    load_state_dict,
+    load_state_dict_into_model,
+    load_states_into_optimizer,
+    save_config_file,
+    save_param_groups,
+    save_state_dict,
+    save_state_dict_shards,
+    sharded_optimizer_loading_epilogue,
+)
+from colossalai.interface import OptimizerWrapper
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import (
+    get_dp_group,
+    get_dp_rank,
+    get_dp_size,
+    get_ep_group,
+    get_ep_rank,
+    get_ep_size,
+    is_moe_tensor,
+)
 from colossalai.checkpoint_io import CheckpointIndexFile
 from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
 from colossalai.moe import MoECheckpintIO
@@ -15,39 +54,51 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
-        """
-        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
-        """
-        model_param_dict = dict(model.named_parameters())
-        for name, param in list(state_dict.items()):
-            if ".experts." in name:
-                if ".experts.gate.weight" in name:
-                    new_name = name.replace(".experts.gate.weight", ".experts.gate_weight")
-                    state_dict[new_name] = state_dict.pop(name)
-                else:
-                    str_idx = name.index(".experts.")
-                    int(name.split(".")[-3])
-                    if ".w1." in name:
-                        model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
-                    elif ".w2." in name:
-                        model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
-                    elif ".w3." in name:
-                        model_param_name = name.replace(name[str_idx:], ".experts.wo")
-                    model_param = model_param_dict[model_param_name]
-                    assert is_moe_tensor(model_param)
-
-                    ep_rank = get_ep_rank(model_param)
-                    ep_size = get_ep_size(model_param)
-                    expert_num = 8 // ep_size
-                    range(ep_rank * expert_num, (ep_rank + 1) * expert_num)
-
-                    state_dict[name] = param
-
-        for name, param in list(state_dict.items()):
-            new_name = "module." + name
-            state_dict[new_name] = state_dict.pop(name)
-            assert new_name in model_param_dict, f"{new_name} not in model"
-
+    @torch.no_grad()
+    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
+        """
+        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
+        """
+        model_param_dict = dict(model.named_parameters())
+        for name, param in list(state_dict.items()):
+            if ".gate.weight" in name:
+                new_name = "module." + name.replace(".gate.weight", ".gate_weight")
+                state_dict[new_name] = state_dict.pop(name)
+            elif ".experts." in name:
+                # in our moe module the experts are concatenated into one tensor,
+                # while mixtral keeps them separate; insert each loaded expert
+                # into its slot of the concatenated tensor.
+                # get model param
+                str_idx = name.index(".experts.")
+                expert_idx = int(name.split(".")[-3])
+                if ".w1." in name:
+                    model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
+                elif ".w2." in name:
+                    model_param_name = name.replace(name[str_idx:], ".experts.wo")
+                elif ".w3." in name:
+                    model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
+                model_param_name = "module." + model_param_name
+                model_param = model_param_dict[model_param_name]
+                assert is_moe_tensor(model_param)
+                # get expert range
+                ep_rank = get_ep_rank(model_param)
+                ep_size = get_ep_size(model_param)
+                expert_num = 8 // ep_size
+                expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
+                # insert new param
+                if expert_idx in expert_range:
+                    new_param = model_param
+                    new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
+                    state_dict[model_param_name] = new_param
+                state_dict.pop(name)
+            else:
+                new_name = "module." + name
+                state_dict[new_name] = state_dict.pop(name)
+
+        for name, param in list(state_dict.items()):
+            assert name in model_param_dict, f"{name} not in model. model param dict: {model_param_dict.keys()}"
+
         dist.barrier()
         return state_dict
 
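The rewritten `pre_load_model` above maps each HuggingFace per-expert weight (`...experts.<idx>.w1/w2/w3.weight`) onto one row of the fused, expert-parallel tensors (`wi_gate`, `wi_up`, `wo`). The index arithmetic in isolation, assuming the 8 experts hard-coded above:

```python
# Standalone sketch of the expert-placement arithmetic used in pre_load_model.
# Assumes 8 experts in total, as hard-coded in the checkpoint IO above.
def local_slot(expert_idx: int, ep_rank: int, ep_size: int, num_experts: int = 8):
    """Return the local row this expert occupies on ep_rank, or None if it lives on another rank."""
    experts_per_rank = num_experts // ep_size
    lo, hi = ep_rank * experts_per_rank, (ep_rank + 1) * experts_per_rank
    return expert_idx - lo if lo <= expert_idx < hi else None


# With ep_size=4, expert 5 lives on rank 2 at local slot 1; its HF weight of shape
# (out_features, in_features) is transposed before being written into the fused
# tensor, matching `param.transpose(0, 1)` above.
assert local_slot(5, ep_rank=2, ep_size=4) == 1
assert local_slot(5, ep_rank=0, ep_size=4) is None
```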
@@ -124,3 +175,53 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
 
         if self.verbose:
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+    @torch.no_grad()
+    def pre_save_model(self, model: nn.Module) -> dict:
+        state_dict = model.state_dict()
+        for name, param in model.named_parameters():
+            if ".experts." in name:
+                if ".experts.gate_weight" in name:
+                    new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
+                    state_dict[new_name] = state_dict.pop(name)
+                elif ".experts." in name and is_moe_tensor(param):
+                    ep_group = get_ep_group(param)
+                    ep_rank = get_ep_rank(param)
+                    ep_size = get_ep_size(param)
+                    dp_rank = get_dp_rank(param)
+
+                    if dp_rank == 0:
+                        param = param.data.cuda()
+                        all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+                        # gather param from every ep rank
+                        dist.all_gather(all_param, param, group=ep_group)
+                        if ep_rank == 0:
+                            all_param = torch.cat(all_param, dim=0)
+                            assert all_param.shape[0] == 8
+                            for i in range(8):
+                                if ".wi_gate" in name:
+                                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
+                                elif ".wi_up" in name:
+                                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
+                                elif ".wo" in name:
+                                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
+                                new_name = new_name.replace("module.", "")
+                                new_param = all_param[i].transpose(-1, -2)
+                                state_dict[new_name] = new_param.cpu()
+                            state_dict.pop(name)
+
+        for name, param in list(state_dict.items()):
+            new_name = name.replace("module.", "")
+            state_dict[new_name] = state_dict.pop(name)
+
+        if self.pp_size > 1:
+            if self.dp_rank == 0:
+                out = [None for _ in range(self.pp_size)]
+                dist.all_gather_object(out, state_dict, group=self.pp_group)
+                if self.pp_rank == 0:
+                    new_state_dict = {}
+                    for o in out:
+                        new_state_dict.update(o)
+                    state_dict = new_state_dict
+        dist.barrier()
+        return state_dict
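`pre_save_model` performs the inverse transformation when exporting: each expert-parallel rank contributes its fused slice, the slices are concatenated back into all 8 experts, and every expert is transposed back to the HuggingFace `(out_features, in_features)` layout. A rough, self-contained sketch of that reshaping with made-up sizes:

```python
# Rough sketch of the gather-and-split step in pre_save_model; shapes are illustrative.
import torch

ep_size, experts_per_rank, d_in, d_out = 4, 2, 16, 32
# stand-in for what dist.all_gather collects: one fused slice per EP rank
gathered = [torch.randn(experts_per_rank, d_in, d_out) for _ in range(ep_size)]

all_experts = torch.cat(gathered, dim=0)                      # (8, d_in, d_out)
assert all_experts.shape[0] == 8
hf_weights = {f"experts.{i}.w1.weight": all_experts[i].transpose(-1, -2) for i in range(8)}
assert hf_weights["experts.3.w1.weight"].shape == (d_out, d_in)
```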
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralDecoderLayer
 
 from colossalai.lazy import LazyInitContext
 from colossalai.moe import SparseMLP
@@ -39,7 +39,7 @@ class MixtralSparseMLP:
 
         # get the attributes of the module
         moe_kwargs = dict(
-            num_experts=module.num_experts,
+            num_experts=8,
             hidden_size=module.hidden_dim,
             intermediate_size=module.ffn_dim,
             router_top_k=module.top_k,
@@ -62,53 +62,18 @@ class MixtralSparseMLP:
         device = module.gate.weight.device
         sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
 
-        # cat all experts
-        w1 = None
-        w2 = None
-        w3 = None
-        for i in module.experts:
-            # origin
-            wi_1 = i.w1.weight.data.clone().transpose(0, 1).unsqueeze(0)
-            wi_2 = i.w2.weight.data.clone().transpose(0, 1).unsqueeze(0)
-            wi_3 = i.w3.weight.data.clone().transpose(0, 1).unsqueeze(0)
-            # cat
-            w1 = wi_1 if w1 is None else torch.cat([w1, wi_1], dim=0)
-            w2 = wi_2 if w2 is None else torch.cat([w2, wi_2], dim=0)
-            w3 = wi_3 if w3 is None else torch.cat([w3, wi_3], dim=0)
-
-        # get local experts
-        if is_moe_tensor(sparse_mlp.experts.wi_gate):
-            ep_rank = get_ep_rank(sparse_mlp.experts.wi_gate)
-            expert_num = sparse_mlp.experts.wi_gate.shape[0]
-            expert_slice = slice(ep_rank * expert_num, (ep_rank + 1) * expert_num)
-        else:
-            expert_slice = slice(None)
-        w1 = w1[expert_slice].clone().detach()
-        w2 = w2[expert_slice].clone().detach()
-        w3 = w3[expert_slice].clone().detach()
-        assert (
-            w1.shape == sparse_mlp.experts.wi_gate.shape
-        ), f"current shape: {w1.shape}, target shape:{sparse_mlp.experts.wi_gate.shape}"
-        assert (
-            w2.shape == sparse_mlp.experts.wo.shape
-        ), f"current shape: {w2.shape}, target shape:{sparse_mlp.experts.wo.shape}"
-        assert (
-            w3.shape == sparse_mlp.experts.wi_up.shape
-        ), f"current shape: {w3.shape}, target shape:{sparse_mlp.experts.wi_up.shape}"
-
-        # assign new params to the colossal moe module
-        sparse_mlp.experts.wi_gate.data = w1
-        sparse_mlp.experts.wi_up.data = w3
-        sparse_mlp.experts.wo.data = w2
-        sparse_mlp.gate_weight = module.gate.weight
-
-        # TODO: fix
-        # The old weight is referenced somewhere so we can not del it.
-        # Change the data pointer of the old weight to release memory.
-        # The pointer will not be used and can be any pointer.
-        for i in module.experts:
-            i.w1.weight.data = w1
-            i.w2.weight.data = w2
-            i.w3.weight.data = w3
-
         return sparse_mlp
 
 
+def replace_moe_layer(model: nn.Module) -> nn.Module:
+    """
+    Replace the Mixtral MoE block with the ColossalAI SparseMLP, recursively.
+
+    Args:
+        model (torch.nn.Module): The model (or submodule) to convert in place.
+    """
+    if isinstance(model, MixtralDecoderLayer):
+        model.block_sparse_moe = MixtralSparseMLP.from_native_module(model.block_sparse_moe)
+    else:
+        for _, child in model.named_children():
+            replace_moe_layer(child)
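The block removed above used to concatenate the eight HuggingFace expert `nn.Linear` weights into the fused `SparseMLP` tensors at conversion time; after this change the fused tensors are filled from the checkpoint instead (see `pre_load_model` earlier in this diff). The shape bookkeeping that the removed code performed, shown in isolation with illustrative sizes:

```python
# Shape bookkeeping of the removed concatenation step; sizes are illustrative.
import torch

num_experts, hidden, ffn = 8, 16, 64
# HF Mixtral stores each expert projection as an nn.Linear weight: (out_features, in_features)
hf_w1 = [torch.randn(ffn, hidden) for _ in range(num_experts)]

# fused layout expected by SparseMLP.experts.wi_gate: (num_experts, in_features, out_features)
wi_gate = torch.stack([w.transpose(0, 1) for w in hf_w1], dim=0)
assert wi_gate.shape == (num_experts, hidden, ffn)

# under expert parallelism each rank keeps a contiguous slice of experts
ep_rank, ep_size = 1, 4
per_rank = num_experts // ep_size
local = wi_gate[ep_rank * per_rank:(ep_rank + 1) * per_rank]
assert local.shape == (per_rank, hidden, ffn)
```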
@@ -48,18 +48,6 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
 
-        # use colossal moe module
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=MixtralSparseMLP,
-                ),
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
-
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             self.append_or_create_submodule_replacement(
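With the submodule-replacement description removed from the policy, the MoE swap now happens explicitly before boosting, via `replace_moe_layer` (see `infer.py` below). A condensed sketch of the resulting call order; it reorganizes the script slightly, so treat it as illustrative rather than the shipped entry point:

```python
# Condensed sketch of the call order after this change (mirrors infer.py below).
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossalai.booster import Booster
from colossalai.moe.utils import skip_init
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM


def build_model(plugin, model_name: str = "mistralai/Mixtral-8x7B-v0.1"):
    config = MixtralConfig.from_pretrained(model_name)
    # per infer.py, this value does not affect the final model: the HF MoE blocks
    # are replaced below, so only the temporary init footprint changes
    config.num_local_experts = 1
    with skip_init():
        model = MixtralForCausalLM(config)
        replace_moe_layer(model)   # swap each MixtralSparseMoeBlock for the fused SparseMLP
    booster = Booster(plugin=plugin)
    model, *_ = booster.boost(model=model)   # boost returns a 5-tuple; only the wrapped model is needed here
    return model, booster
```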
@ -1,30 +1,184 @@
|
||||||
from argparse import ArgumentParser
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
import torch.distributed as dist
|
||||||
|
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||||
|
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||||
|
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.moe import MOE_MANAGER, apply_load_balance
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from datasets import load_dataset
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import T5Tokenizer
|
||||||
|
from transformers.models.llama import LlamaConfig
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.moe.layers import apply_load_balance
|
||||||
|
from colossalai.moe.manager import MOE_MANAGER
|
||||||
|
from colossalai.moe.utils import skip_init
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_cuda(batch, device):
|
||||||
|
for k, v in batch.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
batch[k] = v.to(device)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def load_ckpt(repo_name: str, model, booster: Booster):
|
||||||
|
ckpt_path = snapshot_download(repo_name)
|
||||||
|
# shard ckpt
|
||||||
|
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
|
||||||
|
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
|
||||||
|
booster.load_model(model, ckpt_path)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = ArgumentParser()
|
# basic settings
|
||||||
parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
|
parser = argparse.ArgumentParser()
|
||||||
return parser.parse_args()
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
type=str,
|
||||||
|
default="8x7b",
|
||||||
|
choices=["8x7b"],
|
||||||
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plugin",
|
||||||
|
type=str,
|
||||||
|
default="hybrid",
|
||||||
|
choices=["ep"],
|
||||||
|
help="Parallel methos.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_path",
|
||||||
|
type=str,
|
||||||
|
default="./outputs",
|
||||||
|
help="The path of your saved model after finetuning.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
default="bf16",
|
||||||
|
choices=["fp32", "bf16", "fp16"],
|
||||||
|
help="The mixed precision training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=42, help="A seed for reproducible training."
|
||||||
|
)
|
||||||
|
|
||||||
|
# kernel
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_kernel",
|
||||||
|
action="store_true",
|
||||||
|
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_layernorm_kernel",
|
||||||
|
action="store_true",
|
||||||
|
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
def inference(args):
|
def main():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
args = parse_args()
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
# Launch ColossalAI
|
||||||
model = model.eval().bfloat16()
|
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||||
print(f"param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB")
|
coordinator = DistCoordinator()
|
||||||
model = model.to(torch.cuda.current_device())
|
|
||||||
|
|
||||||
text = "Hello my name is"
|
# Set plugin
|
||||||
inputs = tokenizer(text, return_tensors="pt")
|
booster_kwargs = {}
|
||||||
|
hybrid_dict = {
|
||||||
|
"tp_size": 1,
|
||||||
|
"custom_policy": MixtralForCausalLMPolicy(),
|
||||||
|
"enable_fused_normalization": args.use_layernorm_kernel,
|
||||||
|
"enable_jit_fused": args.use_kernel,
|
||||||
|
"precision": args.precision,
|
||||||
|
"checkpoint_io": MixtralMoECheckpointIO,
|
||||||
|
"zero_stage": 1,
|
||||||
|
}
|
||||||
|
mgr_dict = {}
|
||||||
|
if args.plugin == "ep":
|
||||||
|
dp_size = dist.get_world_size()
|
||||||
|
plugin = MoeHybridParallelPlugin(
|
||||||
|
pp_size=1,
|
||||||
|
**hybrid_dict,
|
||||||
|
)
|
||||||
|
MOE_MANAGER.setup(
|
||||||
|
parallel="EP",
|
||||||
|
max_ep_size=dp_size,
|
||||||
|
**mgr_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||||
|
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||||
|
|
||||||
outputs = model.generate(**inputs, max_new_tokens=20)
|
# Build mixtral model
|
||||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
model_name = "mistralai/Mixtral-8x7B-v0.1"
|
||||||
|
config = MixtralConfig.from_pretrained(model_name)
|
||||||
|
config.num_local_experts = 1 # dont change this. it will not affect model
|
||||||
|
with skip_init():
|
||||||
|
model = MixtralForCausalLM(config)
|
||||||
|
model = (
|
||||||
|
model.to(torch.bfloat16)
|
||||||
|
if args.precision == "bf16"
|
||||||
|
else model.to(torch.float16)
|
||||||
|
)
|
||||||
|
model = model.to(get_current_device())
|
||||||
|
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||||
|
|
||||||
|
# Replace moe
|
||||||
|
with skip_init():
|
||||||
|
replace_moe_layer(model)
|
||||||
|
model.eval()
|
||||||
|
coordinator.print_on_master(f"Finish replace moe module")
|
||||||
|
|
||||||
|
# Prepare tokenizer and dataloader
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
# Set booster
|
||||||
|
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||||
|
model, _, _, _, _ = booster.boost(model=model)
|
||||||
|
coordinator.print_on_master(f"Finish init booster")
|
||||||
|
|
||||||
|
# load ckpt
|
||||||
|
load_ckpt(model_name, model, booster)
|
||||||
|
coordinator.print_on_master(f"Finish load ckpt")
|
||||||
|
|
||||||
|
text = ["Hello my name is"]
|
||||||
|
inputs = tokenizer(text, return_tensors="pt").to(torch.cuda.current_device())
|
||||||
|
outputs = model.module.generate(**inputs, max_new_tokens=20)
|
||||||
|
outputs = tokenizer.batch_decode(outputs)[0]
|
||||||
|
print(outputs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
main()
|
||||||
inference(args)
|
|
||||||
|
|
|
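For reference, the tail of the new `main()` is a plain tokenize → generate → decode round trip once the boosted model is in place. Below is a minimal sketch of that flow; it swaps in a tiny stand-in checkpoint (`sshleifer/tiny-gpt2`, an assumption for illustration only, not the Mixtral weights) so it runs on CPU without ColossalAI or a boosted module.

```python
# Minimal sketch of the tokenize -> generate -> decode flow used at the end of main().
# NOTE: "sshleifer/tiny-gpt2" is a stand-in model chosen for illustration; the real
# script generates with the boosted Mixtral module (model.module.generate).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()

text = ["Hello my name is"]
inputs = tokenizer(text, return_tensors="pt")          # batch with a single prompt
outputs = model.generate(**inputs, max_new_tokens=20)  # greedy decoding by default
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```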
@@ -1 +1,7 @@
-python infer.py --model "base"
+NUM_GPU=2
+MODEL="8x7b"
+
+# ep
+torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
+    --model_name $MODEL \
+    --plugin "ep" \
@@ -1,37 +0,0 @@
-pip install -r requirements.txt
-
-# inference
-python infer.py --model "test"
-
-# train
-torchrun --standalone --nproc_per_node 4 train.py \
-    --num_epoch 1 \
-    --model_name "test" \
-    --plugin "ep" \
-    --batch_size 1
-
-torchrun --standalone --nproc_per_node 4 train.py \
-    --num_epoch 1 \
-    --model_name "test" \
-    --plugin "ep_zero" \
-    --batch_size 1 \
-    --zero_stage 1 \
-    --extra_dp_size 2 \
-
-torchrun --standalone --nproc_per_node 4 train.py \
-    --num_epoch 1 \
-    --model_name "test" \
-    --plugin "ep_zero" \
-    --batch_size 1 \
-    --zero_stage 2 \
-    --extra_dp_size 2 \
-
-torchrun --standalone --nproc_per_node 4 train.py \
-    --model_name "test" \
-    --plugin "hybrid" \
-    --num_epoch 1 \
-    --pp_size 2 \
-    --dp_size 1 \
-    --ep_size 2 \
-    --zero_stage 1 \
-    --batch_size 1
@@ -14,6 +14,9 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
+from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy

 sys.path.append(
     os.path.join(
@@ -22,10 +25,6 @@ sys.path.append(
     )
 )

-OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
-set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
-OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
-

 def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
     input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
@@ -71,65 +70,55 @@ def run_fwd_bwd(


 def get_config():
-    config = LlamaConfig(
+    config = MixtralConfig(
         vocab_size=300,
-        hidden_size=16,
-        intermediate_size=32,
+        hidden_size=32,
+        intermediate_size=128,
         num_hidden_layers=2,
-        num_attention_heads=2,
-        head_dim=4,
         dropout_rate=0.0,
-        hidden_act="swiglu",
     )
-    set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
     return config


 def get_model(parallel):
     config = get_config()
-    model = OpenMoeForCausalLM(config)
+    model = MixtralForCausalLM(config).to(torch.bfloat16)
     optim = torch.optim.Adam(model.parameters())
+    args = dict(
+        precision="bf16",
+        tp_size=1,
+        zero_stage=1,
+        custom_policy=MixtralForCausalLMPolicy(),
+        checkpoint_io=MixtralMoECheckpointIO,
+    )
     if parallel == None:
         plugin = MoeHybridParallelPlugin(
-            precision="bf16",
-            tp_size=1,
             pp_size=1,
-            zero_stage=2,
-            custom_policy=OpenMoeForCausalLMPolicy(),
+            **args,
         )
     elif parallel == "ep":
         plugin = MoeHybridParallelPlugin(
-            precision="bf16",
-            tp_size=1,
             pp_size=1,
-            zero_stage=2,
-            custom_policy=OpenMoeForCausalLMPolicy(),
+            **args,
         )
     elif parallel == "ep_zero":
         plugin = MoeHybridParallelPlugin(
-            precision="bf16",
-            tp_size=1,
             pp_size=1,
-            zero_stage=2,
             extra_dp_size=2,
-            custom_policy=OpenMoeForCausalLMPolicy(),
+            **args,
         )
     elif parallel == "hybrid":
         plugin = MoeHybridParallelPlugin(
-            precision="bf16",
-            tp_size=1,
             pp_size=2,
-            zero_stage=1,
             microbatch_size=1,
-            custom_policy=OpenMoeForCausalLMPolicy(),
+            **args,
         )
     booster = Booster(plugin=plugin)
     model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
     return model, booster, optim


-def _test_moe_checkpoint(rank, parallel):
+def _test_moe_checkpoint(parallel):
     if parallel == None:
         MOE_MANAGER.setup(
             parallel=None,
@@ -153,18 +142,12 @@ def _test_moe_checkpoint(rank, parallel):
         )
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
-    model3, booster3, optim3 = get_model(parallel)

     # param ckpt
     # shard
     booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
     booster2.load_model(model2, "./tmp_ckpt1")
-    # unshard
-    booster1.save_model(model1, "./tmp_ckpt1.pth")
-    booster3.load_model(model3, "./tmp_ckpt1.pth")
     # check
     check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-    check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)

     # optim ckpt
     criterion = lambda x: x.mean()
@@ -181,18 +164,12 @@ def _test_moe_checkpoint(rank, parallel):
     booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
     dist.barrier()
     booster2.load_optimizer(optim2, "./tmp_ckpt2")
-    # unshard
-    booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
-    booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
     # check
     check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
-    check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)

     if dist.get_rank() == 0:
         shutil.rmtree("./tmp_ckpt1")
         shutil.rmtree("./tmp_ckpt2")
-        os.remove("./tmp_ckpt1.pth")
-        os.remove("./tmp_ckpt2.pth")


 def _run_dist(rank, world_size, port, parallel):
@@ -204,16 +181,16 @@ def _run_dist(rank, world_size, port, parallel):
         port=port,
         backend="nccl",
     )
-    _test_moe_checkpoint(rank, parallel)
+    _test_moe_checkpoint(parallel)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
+@pytest.mark.parametrize("parallel", ["ep", "ep_zero"])
 @rerun_if_address_is_in_use()
 def test_moe_checkpoint(world_size, parallel):
     spawn(_run_dist, world_size, parallel=parallel)


 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="hybrid")
+    test_moe_checkpoint(world_size=4, parallel="ep")
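The checkpoint test above follows a simple pattern: save from one boosted model, load the result into a second one, then assert the two state dicts match. Below is a stripped-down sketch of that round trip using plain PyTorch in place of the Booster and `check_state_dict_equal` helpers; the toy module and the temporary file name are made up for illustration.

```python
# Sketch of the save -> load -> compare round trip the checkpoint test relies on,
# using plain torch.save/torch.load instead of booster.save_model/load_model.
import os
import tempfile

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    torch.manual_seed(0)  # hypothetical stand-in for get_model(parallel)
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))


model1, model2 = make_model(), make_model()
with torch.no_grad():  # perturb model2 so the load actually has to restore something
    model2[0].weight.add_(1.0)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tmp_ckpt1.pth")
    torch.save(model1.state_dict(), path)
    model2.load_state_dict(torch.load(path))

# Rough equivalent of check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
for key, value in model1.state_dict().items():
    assert torch.equal(value, model2.state_dict()[key]), key
print("state dicts match")
```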
@@ -1,38 +0,0 @@
-
-NUM_GPU=8
-MODEL="8b"
-SEQ_LENGTH=2048
-BATCH_SIZE=1
-LR=0.00001
-
-# ep zero
-torchrun --standalone --nproc_per_node $NUM_GPU train.py \
-    --num_epoch 1 \
-    --model_name $MODEL \
-    --plugin "ep_zero" \
-    --batch_size $BATCH_SIZE \
-    --lr $LR \
-    --zero_stage 1 \
-    --extra_dp_size 2
-
-# ep
-# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
-#     --num_epoch 1 \
-#     --model_name $MODEL \
-#     --plugin "ep_zero" \
-#     --batch_size $BATCH_SIZE \
-#     --lr $LR \
-#     --zero_stage 1
-
-# hybrid
-# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
-#     --num_epoch 1 \
-#     --model_name $MODEL \
-#     --plugin "hybrid" \
-#     --batch_size $BATCH_SIZE \
-#     --lr $LR \
-#     --zero_stage 1 \
-#     --pp_size 2 \
-#     --dp_size 1 \
-#     --ep_size 2 \
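The deleted launch script only varies a handful of flags (`--plugin`, `--zero_stage`, `--extra_dp_size`, batch size, learning rate) across its commented-out variants. As a hedged sketch of how `train.py` presumably consumes those flags, the snippet below builds a comparable argparse parser; the exact argument set, choices, and defaults are assumptions for illustration, not a copy of the real parser.

```python
# Hypothetical sketch of the argument parsing behind the launch flags above.
# The real train.py defines its own parser; names and defaults here are assumptions.
import argparse


def parse_train_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_epoch", type=int, default=1)
    parser.add_argument("--model_name", type=str, default="8x7b")
    parser.add_argument("--plugin", type=str, choices=["ep", "ep_zero", "hybrid"], default="ep_zero")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--zero_stage", type=int, choices=[1, 2], default=1)
    parser.add_argument("--extra_dp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--dp_size", type=int, default=1)
    parser.add_argument("--ep_size", type=int, default=1)
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Mirrors: torchrun ... train.py --plugin "ep_zero" --zero_stage 1 --extra_dp_size 2
    print(parse_train_args(["--plugin", "ep_zero", "--zero_stage", "1", "--extra_dp_size", "2"]))
```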
@@ -4,6 +4,7 @@ import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from colossal_moe.models.mixtral_layer import replace_moe_layer
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -15,13 +16,47 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.moe import MOE_MANAGER, apply_load_balance
-from colossalai.nn.optimizer import HybridAdam
+# from colossalai.nn.optimizer import HybridAdam
+from torch.optim import Adam as HybridAdam
 from colossalai.utils import get_current_device
+import argparse
+import os
+from functools import partial
+from typing import Dict
+
+import torch
+import torch.distributed as dist
+from datasets import load_dataset
+from huggingface_hub import snapshot_download
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import T5Tokenizer
+from transformers.models.llama import LlamaConfig
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.moe.layers import apply_load_balance
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import skip_init
+from colossalai.utils import get_current_device


 def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}


+def load_ckpt(repo_name: str, model, booster: Booster):
+    ckpt_path = snapshot_download(repo_name)
+    # single ckpt
+    if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
+        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
+    # shard ckpt
+    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
+        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+    else:
+        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
+    booster.load_model(model, ckpt_path)
+
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
@@ -240,26 +275,30 @@ def main():
     # num_key_value_heads=4,
     # use_cache=False,
     # )
-    config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
     config.use_cache = False
-    init_ctx = LazyInitContext(default_device=get_current_device())
-    with init_ctx:
-        model = MixtralForCausalLM.from_pretrained(
-            "/home/lczxl/.cache/huggingface/hub/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/f1ca00645f0b1565c7f9a1c863d2be6ebf896b04",
-            config=config,
-        ).bfloat16()
+    config.num_local_experts = 1
+    # torch.set_default_tensor_type(torch.float16)
+    model = MixtralForCausalLM(config)
+    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
+    model = model.to(get_current_device())
+    replace_moe_layer(model)
+    # torch.set_default_tensor_type(torch.float32)
+    print(f"0-2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
     coordinator.print_on_master(f"Finish init model with config:\n{config}")

     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()

     # Prepare tokenizer and dataloader
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
     dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
     collate_fn = None
     dataloader = plugin.prepare_dataloader(
         dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
     )
+    torch.cuda.synchronize()
+    print(f"1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")

     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
@@ -267,6 +306,12 @@ def main():
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    torch.cuda.synchronize()
+    print(f"2-1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    load_ckpt("mistralai/Mixtral-8x7B-v0.1", model, booster)
+    torch.cuda.synchronize()
+    print(f"2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     coordinator.print_on_master(f"Finish init booster")
@@ -303,8 +348,12 @@ def main():
             data = move_to_cuda(data, torch.cuda.current_device())
             outputs = model(**data)
             loss = outputs["loss"]
+            print(f"3 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+
             # Backward
             booster.backward(loss, optimizer)
+            print(f"4 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+
             pbar.set_postfix({"loss": loss.item()})

             optimizer.step()
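Stripped of the Booster wrapper and the memory-probe prints, the inner loop changed above is a standard forward → backward → step cycle over batches that return a `"loss"` entry. Below is a minimal sketch of that pattern with plain PyTorch stand-ins: a toy model, random token batches, and `torch.optim.Adam` in place of the aliased `HybridAdam`; `loss.backward()` replaces `booster.backward(loss, optimizer)`.

```python
# Minimal sketch of the training step pattern in the diff above:
#   outputs = model(**data); loss = outputs["loss"]; backward; optimizer.step()
# Toy model and random data are stand-ins for the boosted Mixtral and RandomDataset.
import torch
import torch.nn as nn
from tqdm import tqdm


class ToyLM(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden: int = 32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, labels):
        logits = self.head(self.emb(input_ids))
        loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return {"loss": loss}


model = ToyLM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.0)

for step in tqdm(range(5)):
    data = {
        "input_ids": torch.randint(0, 100, (1, 16)),
        "labels": torch.randint(0, 100, (1, 16)),
    }
    outputs = model(**data)
    loss = outputs["loss"]
    loss.backward()  # booster.backward(loss, optimizer) in the real script
    optimizer.step()
    optimizer.zero_grad()
```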
@@ -13,7 +13,7 @@ LR=0.00001
 #     --plugin "ep_zero" \
 #     --batch_size $BATCH_SIZE \
 #     --lr $LR \
-#     --zero_stage 1 \
+#     --zero_stage 2 \
 #     --extra_dp_size 2

 # ep