[example] add llama2 example (#4527)

* [example] transfer llama-1 example

* [example] fit llama-2

* [example] refactor scripts folder

* [example] fit new gemini plugin

* [cli] fix multinode runner

* [example] fit gemini optim checkpoint

* [example] refactor scripts

* [example] update requirements

* [example] update requirements

* [example] rename llama to llama2

* [example] update readme and pretrain script

* [example] refactor scripts
pull/4546/head
Hongxin Liu 2023-08-28 17:59:11 +08:00 committed by GitHub
parent 839847b7d7
commit 0b00def881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1087 additions and 19 deletions

View File

@ -265,6 +265,10 @@ def launch_multi_processes(args: Config) -> None:
# establish remote connection
runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
# overwrite master addr when num_nodes > 1 and not specified
if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
args.master_addr = active_device_pool.hostinfo_list[0].hostname
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
cmd = get_launch_command(master_addr=args.master_addr,

View File

@ -2,7 +2,13 @@ import warnings
HAS_MEM_EFF_ATTN = False
try:
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
HAS_MEM_EFF_ATTN = True
except ImportError:
warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
@ -16,13 +22,6 @@ if HAS_MEM_EFF_ATTN:
from typing import Optional
import torch
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
from .utils import SeqLenInfo

View File

@ -1,11 +0,0 @@
# Pretraining LLaMA: best practices for building LLaMA-like base models
<p id="ColossalChat-Speed" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>
- 65-billion-parameter large model pretraining accelerated by 38%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
> Since the main branch is being updated, in order to maintain the stability of the code, this example is temporarily kept as an [independent branch](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama).

View File

@ -0,0 +1,176 @@
# Pretraining LLaMA-2: best practices for building LLaMA-2-like base models
## Dataset
Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed.
A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).
RedPajama-Data-1T consists of seven data slices:
| | RedPajama | LLaMA |
|---------------|--------------|---------------|
| CommonCrawl | 878 billion | 852 billion |
| C4 | 175 billion | 190 billion |
| Github | 59 billion | 100 billion |
| Books | 26 billion | 25 billion |
| ArXiv | 28 billion | 33 billion |
| Wikipedia | 24 billion | 25 billion |
| StackExchange | 20 billion | 27 billion |
| Total | 1.2 trillion | 1.25 trillion |
## Training
We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps.
| params | learning rate | batch size |
|--------|---------------|------------|
| 6.7B | 3.0e-4 | 4M |
| 13.0B | 3.0e-4 | 4M |
| 32.5B | 1.5e-4 | 4M |
| 65.2B | 1.5e-4 | 4M |
## Usage
### 1. Installation
Please install the latest ColossalAI from source.
```bash
CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```
Then install other dependencies.
```bash
pip install -r requirements.txt
```
Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention.
### 2. Download the dataset
The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`.
### 3. Command line arguments
Yon can use colossalai run to launch multi-nodes training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
pretrain.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```text
hostname1
hostname2
hostname3
hostname4
```
Make sure master node can access all nodes (including itself) by ssh without password.
Here is details about CLI arguments:
- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
### 4. Shell Script Examples
For your convenience, we provide some shell scripts to run benchmark with various configurations.
You can find them in `scripts/benchmark_7B` and `scripts/benchmark_70B` directory. The main command should be in the format of:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
benchmark.py --OTHER_CONFIGURATIONS
```
Here we will show an example of how to run training
llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
#### a. Running environment
This experiment was performed on 4 computing nodes with 32 A800 GPUs in total. The nodes are
connected with RDMA and GPUs within one node are fully connected with NVLink.
#### b. Running command
```bash
cd scripts/benchmark_7B
```
First, put your host file (`hosts.txt`) in this directory with your real host ip or host name.
Here is a sample `hosts.txt`:
```text
hostname1
hostname2
hostname3
hostname4
```
Then add environment variables to script if needed.
Finally, run the following command to start training:
```bash
bash gemini.sh
```
#### c. Results
If you run the above command successfully, you will get the following results:
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.
## Reference
```
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
```
```bibtex
@software{openlm2023openllama,
author = {Geng, Xinyang and Liu, Hao},
title = {OpenLLaMA: An Open Reproduction of LLaMA},
month = May,
year = 2023,
url = {https://github.com/openlm-research/open_llama}
}
```
```bibtex
@software{together2023redpajama,
author = {Together Computer},
title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
month = April,
year = 2023,
url = {https://github.com/togethercomputer/RedPajama-Data}
}
```
```bibtex
@article{touvron2023llama,
title={Llama: Open and efficient foundation language models},
author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
journal={arXiv preprint arXiv:2302.13971},
year={2023}
}
```

View File

@ -0,0 +1,83 @@
from types import MethodType
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
SUPPORT_XFORMERS = False
SUPPORT_FLASH2 = False
try:
import xformers.ops as xops
SUPPORT_XFORMERS = True
except ImportError:
pass
try:
from flash_attn import flash_attn_func
SUPPORT_FLASH2 = True
except ImportError:
pass
SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2
def llama_flash_attention(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if SUPPORT_FLASH2:
attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
else:
attn_output = xops.memory_efficient_attention(query_states,
key_states,
value_states,
attn_bias=xops.LowerTriangularMask())
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def replace_xformers(model: nn.Module):
for module in model.modules():
if isinstance(module, LlamaAttention):
module.forward = MethodType(llama_flash_attention, module)

View File

@ -0,0 +1,211 @@
import argparse
import resource
from contextlib import nullcontext
import torch
from attn import SUPPORT_FLASH, replace_xformers
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
# ==============================
# Constants
# ==============================
MODEL_CONFIGS = {
'7b':
LlamaConfig(max_position_embeddings=4096),
'13b':
LlamaConfig(hidden_size=5120,
intermediate_size=13824,
num_hidden_layers=40,
num_attention_heads=40,
max_position_embeddings=4096),
'70b':
LlamaConfig(hidden_size=8192,
intermediate_size=28672,
num_hidden_layers=80,
num_attention_heads=64,
max_position_embeddings=4096,
num_key_value_heads=8),
}
def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
parser.add_argument('-p',
'--plugin',
choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'],
default='gemini',
help='Choose which plugin to use')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size')
parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run')
parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore')
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
parser.add_argument('-w',
'--warmup_ratio',
type=float,
default=0.8,
help='warm up ratio of non-model data. Only for gemini-auto')
parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb')
parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers')
parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini')
parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini')
parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini')
parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size')
parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size')
parser.add_argument('--mbs', type=int, default=1)
parser.add_argument('--zero', type=int, default=0)
args = parser.parse_args()
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
def empty_init():
pass
# ==============================
# Initialize Booster
# ==============================
use_empty_init = True
if args.plugin == 'gemini':
plugin = GeminiPlugin(precision='bf16',
shard_param_frac=args.shard_param_frac,
offload_optim_frac=args.offload_optim_frac,
offload_param_frac=args.offload_param_frac)
elif args.plugin == 'gemini_auto':
plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio)
elif args.plugin == 'fsdp':
if use_empty_init:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16),
param_init_fn=empty_init(),
)
else:
plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16))
elif args.plugin == 'fsdp_cpu':
if use_empty_init:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16),
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
)
else:
plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16),
cpu_offload=CPUOffload(offload_params=True))
elif args.plugin == '3d':
plugin = HybridParallelPlugin(tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero,
enable_fused_normalization=True,
num_microbatches=args.mbs,
precision='bf16')
elif args.plugin == '3d_cpu':
plugin = HybridParallelPlugin(tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=True,
num_microbatches=args.mbs,
initial_scale=2**8,
precision='bf16')
else:
raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin)
# ==============================
# Initialize Dataset and Dataloader
# ==============================
dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
config = MODEL_CONFIGS[args.config]
dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size,
max_length=args.max_length,
vocab_size=config.vocab_size)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
# ==============================
# Initialize Model and Optimizer
# ==============================
init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin,
(GeminiPlugin, HybridParallelPlugin)) else nullcontext()
with init_ctx:
model = LlamaForCausalLM(config)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if args.xformers:
assert SUPPORT_FLASH, 'Use flash attention while xfomers is not installed'
replace_xformers(model)
model_numel = get_model_numel(model)
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
performance_evaluator = PerformanceEvaluator(model_numel,
args.grad_checkpoint,
args.ignore_steps,
dp_world_size=dp_size)
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
coordinator.print_on_master(
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
booster.execute_pipeline(data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=False)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
else:
for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())):
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
performance_evaluator.on_fit_end()
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,119 @@
import json
import random
from typing import Iterator, Optional
import numpy as np
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from colossalai.utils import get_current_device
class StatefulDistributedSampler(DistributedSampler):
def __init__(self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index:]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def prepare_dataloader(dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
def load_json(file_path: str):
with open(file_path, 'r') as f:
return json.load(f)
def save_json(data, file_path: str):
with open(file_path, 'w') as f:
json.dump(data, f, indent=4)
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
'input_ids': self.input_ids[idx],
'attention_mask': self.attention_mask[idx],
'labels': self.input_ids[idx]
}

View File

@ -0,0 +1,32 @@
from contextlib import contextmanager
import torch
import torch.nn as nn
@contextmanager
def low_precision_init(target_dtype: torch.dtype = torch.float16):
dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(target_dtype)
yield
finally:
torch.set_default_dtype(dtype)
def get_model_numel(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f'{numel / B:.2f} B'
elif numel >= M:
return f'{numel / M:.2f} M'
elif numel >= K:
return f'{numel / K:.2f} K'
else:
return f'{numel}'

View File

@ -0,0 +1,102 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.cluster import DistCoordinator
def divide(x: float, y: float) -> float:
if y == 0:
return float('inf')
elif y == float('inf'):
return float('nan')
return x / y
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
assert self.start_time is not None
self.duration += time() - self.start_time
self.start_time = None
def reset(self) -> None:
self.duration = 0.
class PerformanceEvaluator:
"""
Callback for valuate the performance of the model.
Args:
actor_num_params: The number of parameters of the actor model.
critic_num_params: The number of parameters of the critic model.
initial_model_num_params: The number of parameters of the initial model.
reward_model_num_params: The number of parameters of the reward model.
enable_grad_checkpoint: Whether to enable gradient checkpointing.
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(self,
model_numel: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None) -> None:
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
self.coordinator = DistCoordinator()
self.dp_world_size = dp_world_size or self.coordinator.world_size
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
torch.cuda.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
torch.cuda.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.coordinator.world_size // self.dp_world_size
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
self.coordinator.print_on_master(
f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, '
f'avg_throughput: {avg_throughput}')
self.coordinator.print_on_master(
f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}')

View File

@ -0,0 +1,275 @@
import argparse
import os
import resource
from contextlib import nullcontext
from functools import partial
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from attn import SUPPORT_XFORMERS, replace_xformers
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
MODEL_CONFIGS = {
'7b':
LlamaConfig(max_position_embeddings=4096),
'13b':
LlamaConfig(hidden_size=5120,
intermediate_size=13824,
num_hidden_layers=40,
num_attention_heads=40,
max_position_embeddings=4096),
'70b':
LlamaConfig(hidden_size=8192,
intermediate_size=28672,
num_hidden_layers=80,
num_attention_heads=64,
max_position_embeddings=4096,
num_key_value_heads=8),
}
def get_model_numel(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f'{numel / B:.2f} B'
elif numel >= M:
return f'{numel / M:.2f} M'
elif numel >= K:
return f'{numel / K:.2f} K'
else:
return f'{numel}'
def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
texts = [sample['text'] for sample in batch]
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
data['labels'] = data['input_ids'].clone()
return data
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
batch_size: int, coordinator: DistCoordinator, save_dir: str):
save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
running_states = {
'epoch': epoch,
'step': step,
'sample_start_index': step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, 'running_states.json'))
def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
load_dir: str) -> Tuple[int, int, int]:
booster.load_model(model, os.path.join(load_dir, 'model'))
booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
running_states = load_json(os.path.join(load_dir, 'running_states.json'))
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
parser.add_argument('-p',
'--plugin',
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
default='gemini',
help='Choose which plugin to use')
parser.add_argument('-d',
'--dataset',
type=str,
default='togethercomputer/RedPajama-Data-1T-Sample',
help='Data set path')
parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay')
parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps')
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
args = parser.parse_args()
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# ==============================
# Initialize Tensorboard
# ==============================
if coordinator.is_master():
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == 'gemini':
plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
elif args.plugin == 'gemini_auto':
plugin = GeminiPlugin(precision=args.mixed_precision,
placement_policy='auto',
initial_scale=2**16,
max_norm=args.grad_clip)
elif args.plugin == 'zero2':
plugin = LowLevelZeroPlugin(stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip)
elif args.plugin == 'zero2_cpu':
plugin = LowLevelZeroPlugin(stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip)
else:
raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin)
# ==============================
# Initialize Tokenizer, Dataset and Dataloader
# ==============================
tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
tokenizer.pad_token = tokenizer.unk_token
dataset = load_dataset(args.dataset)
train_ds = dataset['train']
dataloader = prepare_dataloader(train_ds,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
# ==============================
# Initialize Model, Optimizer and LR Scheduler
# ==============================
config = MODEL_CONFIGS[args.config]
init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
with init_ctx:
model = LlamaForCausalLM(config)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if args.flash_attention:
assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed'
replace_xformers(model)
model_numel = get_model_numel(model)
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
lr_scheduler = CosineAnnealingWarmupLR(optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr)
default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
optimizer,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
coordinator.print_on_master(
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
# load checkpoint if specified
start_epoch = 0
start_step = 0
sampler_start_idx = 0
if args.load is not None:
coordinator.print_on_master('Loading checkpoint')
start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
num_steps_per_epoch = len(dataloader)
# if resume training, set the sampler start index to the correct value
dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch)
with tqdm(enumerate(dataloader),
desc=f'Epoch {epoch}',
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step) as pbar:
for step, batch in pbar:
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
all_reduce_mean(loss)
pbar.set_postfix({'loss': loss.item()})
if coordinator.is_master():
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f'Saving checkpoint')
save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
args.save_dir)
coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,9 @@
colossalai>=0.3.0
datasets
numpy
torch>=1.12.0,<=2.0.0
tqdm
transformers
flash-attn>=2.0.0,<=2.0.5
SentencePiece==0.1.99
tensorboard==2.14.0

View File

@ -0,0 +1,17 @@
#!/bin/bash
# TODO: fix this
echo "3D parallel for LLaMA-2 is not ready yet"
exit 1
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4

View File

@ -0,0 +1,13 @@
#!/bin/bash
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -g -x -b 2

View File

@ -0,0 +1,13 @@
#!/bin/bash
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p gemini_auto -g -x -b 2

View File

@ -0,0 +1,13 @@
#!/bin/bash
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16

View File

@ -0,0 +1,13 @@
#!/bin/bash
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16