mirror of https://github.com/hpcaitech/ColossalAI
[example] add llama2 example (#4527)
* [example] transfer llama-1 example * [example] fit llama-2 * [example] refactor scripts folder * [example] fit new gemini plugin * [cli] fix multinode runner * [example] fit gemini optim checkpoint * [example] refactor scripts * [example] update requirements * [example] update requirements * [example] rename llama to llama2 * [example] update readme and pretrain script * [example] refactor scriptspull/4546/head
parent
839847b7d7
commit
0b00def881
|
@ -265,6 +265,10 @@ def launch_multi_processes(args: Config) -> None:
|
|||
# establish remote connection
|
||||
runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
|
||||
|
||||
# overwrite master addr when num_nodes > 1 and not specified
|
||||
if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
|
||||
args.master_addr = active_device_pool.hostinfo_list[0].hostname
|
||||
|
||||
# execute distributed launching command
|
||||
for node_id, hostinfo in enumerate(active_device_pool):
|
||||
cmd = get_launch_command(master_addr=args.master_addr,
|
||||
|
|
|
@ -2,7 +2,13 @@ import warnings
|
|||
|
||||
HAS_MEM_EFF_ATTN = False
|
||||
try:
|
||||
from xformers.ops.fmha import memory_efficient_attention
|
||||
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
HAS_MEM_EFF_ATTN = True
|
||||
except ImportError:
|
||||
warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
|
||||
|
@ -16,13 +22,6 @@ if HAS_MEM_EFF_ATTN:
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
# Pretraining LLaMA: best practices for building LLaMA-like base models
|
||||
|
||||
<p id="ColossalChat-Speed" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
|
||||
</p>
|
||||
|
||||
- 65-billion-parameter large model pretraining accelerated by 38%
|
||||
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
|
||||
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
|
||||
|
||||
> Since the main branch is being updated, in order to maintain the stability of the code, this example is temporarily kept as an [independent branch](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama).
|
|
@ -0,0 +1,176 @@
|
|||
# Pretraining LLaMA-2: best practices for building LLaMA-2-like base models
|
||||
|
||||
## Dataset
|
||||
|
||||
Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed.
|
||||
|
||||
A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).
|
||||
|
||||
RedPajama-Data-1T consists of seven data slices:
|
||||
|
||||
| | RedPajama | LLaMA |
|
||||
|---------------|--------------|---------------|
|
||||
| CommonCrawl | 878 billion | 852 billion |
|
||||
| C4 | 175 billion | 190 billion |
|
||||
| Github | 59 billion | 100 billion |
|
||||
| Books | 26 billion | 25 billion |
|
||||
| ArXiv | 28 billion | 33 billion |
|
||||
| Wikipedia | 24 billion | 25 billion |
|
||||
| StackExchange | 20 billion | 27 billion |
|
||||
| Total | 1.2 trillion | 1.25 trillion |
|
||||
|
||||
## Training
|
||||
|
||||
We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps.
|
||||
|
||||
| params | learning rate | batch size |
|
||||
|--------|---------------|------------|
|
||||
| 6.7B | 3.0e-4 | 4M |
|
||||
| 13.0B | 3.0e-4 | 4M |
|
||||
| 32.5B | 1.5e-4 | 4M |
|
||||
| 65.2B | 1.5e-4 | 4M |
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Installation
|
||||
|
||||
Please install the latest ColossalAI from source.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
|
||||
```
|
||||
|
||||
Then install other dependencies.
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention.
|
||||
|
||||
### 2. Download the dataset
|
||||
|
||||
The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`.
|
||||
|
||||
### 3. Command line arguments
|
||||
|
||||
Yon can use colossalai run to launch multi-nodes training:
|
||||
```bash
|
||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
||||
pretrain.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
|
||||
Here is a sample hostfile:
|
||||
|
||||
```text
|
||||
hostname1
|
||||
hostname2
|
||||
hostname3
|
||||
hostname4
|
||||
```
|
||||
|
||||
Make sure master node can access all nodes (including itself) by ssh without password.
|
||||
|
||||
Here is details about CLI arguments:
|
||||
|
||||
- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported.
|
||||
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
|
||||
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama.
|
||||
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
|
||||
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
|
||||
- Learning rate: `--lr`. The default value is 3e-4.
|
||||
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
|
||||
- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
|
||||
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
|
||||
- Max length: `-l`, `--max_length`. The default value is 4096.
|
||||
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
|
||||
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
|
||||
- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`.
|
||||
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
|
||||
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
|
||||
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
|
||||
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
|
||||
|
||||
|
||||
### 4. Shell Script Examples
|
||||
|
||||
For your convenience, we provide some shell scripts to run benchmark with various configurations.
|
||||
|
||||
You can find them in `scripts/benchmark_7B` and `scripts/benchmark_70B` directory. The main command should be in the format of:
|
||||
```bash
|
||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
||||
benchmark.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
Here we will show an example of how to run training
|
||||
llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
|
||||
|
||||
#### a. Running environment
|
||||
This experiment was performed on 4 computing nodes with 32 A800 GPUs in total. The nodes are
|
||||
connected with RDMA and GPUs within one node are fully connected with NVLink.
|
||||
|
||||
#### b. Running command
|
||||
|
||||
```bash
|
||||
cd scripts/benchmark_7B
|
||||
```
|
||||
|
||||
First, put your host file (`hosts.txt`) in this directory with your real host ip or host name.
|
||||
|
||||
Here is a sample `hosts.txt`:
|
||||
```text
|
||||
hostname1
|
||||
hostname2
|
||||
hostname3
|
||||
hostname4
|
||||
```
|
||||
|
||||
Then add environment variables to script if needed.
|
||||
|
||||
Finally, run the following command to start training:
|
||||
|
||||
```bash
|
||||
bash gemini.sh
|
||||
```
|
||||
#### c. Results
|
||||
If you run the above command successfully, you will get the following results:
|
||||
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.
|
||||
|
||||
|
||||
## Reference
|
||||
```
|
||||
@article{bian2021colossal,
|
||||
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
|
||||
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
|
||||
journal={arXiv preprint arXiv:2110.14883},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@software{openlm2023openllama,
|
||||
author = {Geng, Xinyang and Liu, Hao},
|
||||
title = {OpenLLaMA: An Open Reproduction of LLaMA},
|
||||
month = May,
|
||||
year = 2023,
|
||||
url = {https://github.com/openlm-research/open_llama}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@software{together2023redpajama,
|
||||
author = {Together Computer},
|
||||
title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
|
||||
month = April,
|
||||
year = 2023,
|
||||
url = {https://github.com/togethercomputer/RedPajama-Data}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{touvron2023llama,
|
||||
title={Llama: Open and efficient foundation language models},
|
||||
author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
|
||||
journal={arXiv preprint arXiv:2302.13971},
|
||||
year={2023}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,83 @@
|
|||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
SUPPORT_XFORMERS = False
|
||||
SUPPORT_FLASH2 = False
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
SUPPORT_XFORMERS = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
SUPPORT_FLASH2 = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2
|
||||
|
||||
|
||||
def llama_flash_attention(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K]
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if SUPPORT_FLASH2:
|
||||
attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
|
||||
else:
|
||||
attn_output = xops.memory_efficient_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xops.LowerTriangularMask())
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def replace_xformers(model: nn.Module):
|
||||
for module in model.modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(llama_flash_attention, module)
|
|
@ -0,0 +1,211 @@
|
|||
import argparse
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from attn import SUPPORT_FLASH, replace_xformers
|
||||
from data_utils import RandomDataset
|
||||
from model_utils import format_numel_str, get_model_numel
|
||||
from performance_evaluator import PerformanceEvaluator
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
# ==============================
|
||||
# Constants
|
||||
# ==============================
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
'7b':
|
||||
LlamaConfig(max_position_embeddings=4096),
|
||||
'13b':
|
||||
LlamaConfig(hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=40,
|
||||
max_position_embeddings=4096),
|
||||
'70b':
|
||||
LlamaConfig(hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
max_position_embeddings=4096,
|
||||
num_key_value_heads=8),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'],
|
||||
default='gemini',
|
||||
help='Choose which plugin to use')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size')
|
||||
parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run')
|
||||
parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore')
|
||||
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
|
||||
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
|
||||
parser.add_argument('-w',
|
||||
'--warmup_ratio',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='warm up ratio of non-model data. Only for gemini-auto')
|
||||
parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb')
|
||||
parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers')
|
||||
parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini')
|
||||
parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini')
|
||||
parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini')
|
||||
parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size')
|
||||
parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size')
|
||||
parser.add_argument('--mbs', type=int, default=1)
|
||||
parser.add_argument('--zero', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
def empty_init():
|
||||
pass
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
use_empty_init = True
|
||||
if args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(precision='bf16',
|
||||
shard_param_frac=args.shard_param_frac,
|
||||
offload_optim_frac=args.offload_optim_frac,
|
||||
offload_param_frac=args.offload_param_frac)
|
||||
elif args.plugin == 'gemini_auto':
|
||||
plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio)
|
||||
elif args.plugin == 'fsdp':
|
||||
if use_empty_init:
|
||||
plugin = TorchFSDPPlugin(
|
||||
mixed_precision=MixedPrecision(param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16),
|
||||
param_init_fn=empty_init(),
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16))
|
||||
elif args.plugin == 'fsdp_cpu':
|
||||
if use_empty_init:
|
||||
plugin = TorchFSDPPlugin(
|
||||
mixed_precision=MixedPrecision(param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16),
|
||||
cpu_offload=CPUOffload(offload_params=True),
|
||||
param_init_fn=empty_init(),
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16),
|
||||
cpu_offload=CPUOffload(offload_params=True))
|
||||
elif args.plugin == '3d':
|
||||
plugin = HybridParallelPlugin(tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
zero_stage=args.zero,
|
||||
enable_fused_normalization=True,
|
||||
num_microbatches=args.mbs,
|
||||
precision='bf16')
|
||||
elif args.plugin == '3d_cpu':
|
||||
plugin = HybridParallelPlugin(tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
zero_stage=args.zero,
|
||||
cpu_offload=True,
|
||||
enable_fused_normalization=True,
|
||||
num_microbatches=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision='bf16')
|
||||
else:
|
||||
raise ValueError(f'Unknown plugin {args.plugin}')
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ==============================
|
||||
# Initialize Dataset and Dataloader
|
||||
# ==============================
|
||||
dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
|
||||
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size,
|
||||
max_length=args.max_length,
|
||||
vocab_size=config.vocab_size)
|
||||
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
|
||||
|
||||
# ==============================
|
||||
# Initialize Model and Optimizer
|
||||
# ==============================
|
||||
init_ctx = LazyInitContext(
|
||||
default_device=get_current_device()) if isinstance(plugin,
|
||||
(GeminiPlugin, HybridParallelPlugin)) else nullcontext()
|
||||
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(config)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
if args.xformers:
|
||||
assert SUPPORT_FLASH, 'Use flash attention while xfomers is not installed'
|
||||
replace_xformers(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
|
||||
performance_evaluator = PerformanceEvaluator(model_numel,
|
||||
args.grad_checkpoint,
|
||||
args.ignore_steps,
|
||||
dp_world_size=dp_size)
|
||||
|
||||
optimizer = HybridAdam(model.parameters())
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||
torch.set_default_dtype(torch.float)
|
||||
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
coordinator.print_on_master(
|
||||
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
|
||||
|
||||
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||
data_iter = iter(dataloader)
|
||||
for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()):
|
||||
performance_evaluator.on_step_start(step)
|
||||
booster.execute_pipeline(data_iter,
|
||||
model,
|
||||
criterion=lambda outputs, inputs: outputs[0],
|
||||
optimizer=optimizer,
|
||||
return_loss=False)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
|
||||
else:
|
||||
for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())):
|
||||
performance_evaluator.on_step_start(step)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
performance_evaluator.on_step_end(**batch)
|
||||
|
||||
performance_evaluator.on_fit_end()
|
||||
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,119 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
|
||||
def __init__(self,
|
||||
dataset: Dataset,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False) -> None:
|
||||
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
||||
self.start_index: int = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index:]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def prepare_dataloader(dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
**kwargs):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
|
||||
|
||||
|
||||
Args:
|
||||
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
|
||||
|
||||
def load_json(file_path: str):
|
||||
with open(file_path, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_json(data, file_path: str):
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
'input_ids': self.input_ids[idx],
|
||||
'attention_mask': self.attention_mask[idx],
|
||||
'labels': self.input_ids[idx]
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def low_precision_init(target_dtype: torch.dtype = torch.float16):
|
||||
dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(target_dtype)
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f'{numel / B:.2f} B'
|
||||
elif numel >= M:
|
||||
return f'{numel / M:.2f} M'
|
||||
elif numel >= K:
|
||||
return f'{numel / K:.2f} K'
|
||||
else:
|
||||
return f'{numel}'
|
|
@ -0,0 +1,102 @@
|
|||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def divide(x: float, y: float) -> float:
|
||||
if y == 0:
|
||||
return float('inf')
|
||||
elif y == float('inf'):
|
||||
return float('nan')
|
||||
return x / y
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class Timer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
|
||||
def end(self) -> None:
|
||||
assert self.start_time is not None
|
||||
self.duration += time() - self.start_time
|
||||
self.start_time = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.
|
||||
|
||||
|
||||
class PerformanceEvaluator:
|
||||
"""
|
||||
Callback for valuate the performance of the model.
|
||||
Args:
|
||||
actor_num_params: The number of parameters of the actor model.
|
||||
critic_num_params: The number of parameters of the critic model.
|
||||
initial_model_num_params: The number of parameters of the initial model.
|
||||
reward_model_num_params: The number of parameters of the reward model.
|
||||
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
||||
ignore_episodes: The number of episodes to ignore when calculating the performance.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_numel: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_steps: int = 0,
|
||||
dp_world_size: Optional[int] = None) -> None:
|
||||
self.model_numel = model_numel
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_steps = ignore_steps
|
||||
|
||||
self.coordinator = DistCoordinator()
|
||||
self.dp_world_size = dp_world_size or self.coordinator.world_size
|
||||
self.disable: bool = False
|
||||
self.timer = Timer()
|
||||
self.num_samples: int = 0
|
||||
self.flop: int = 0
|
||||
|
||||
def on_step_start(self, step: int) -> None:
|
||||
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.start()
|
||||
|
||||
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.end()
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
self.num_samples += batch_size
|
||||
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
|
||||
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
|
||||
mp_world_size = self.coordinator.world_size // self.dp_world_size
|
||||
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
|
||||
self.coordinator.print_on_master(
|
||||
f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, '
|
||||
f'avg_throughput: {avg_throughput}')
|
||||
self.coordinator.print_on_master(
|
||||
f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}')
|
|
@ -0,0 +1,275 @@
|
|||
import argparse
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from attn import SUPPORT_XFORMERS, replace_xformers
|
||||
from data_utils import load_json, prepare_dataloader, save_json
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
'7b':
|
||||
LlamaConfig(max_position_embeddings=4096),
|
||||
'13b':
|
||||
LlamaConfig(hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=40,
|
||||
max_position_embeddings=4096),
|
||||
'70b':
|
||||
LlamaConfig(hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
max_position_embeddings=4096,
|
||||
num_key_value_heads=8),
|
||||
}
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f'{numel / B:.2f} B'
|
||||
elif numel >= M:
|
||||
return f'{numel / M:.2f} M'
|
||||
elif numel >= K:
|
||||
return f'{numel / K:.2f} K'
|
||||
else:
|
||||
return f'{numel}'
|
||||
|
||||
|
||||
def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
|
||||
texts = [sample['text'] for sample in batch]
|
||||
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
|
||||
data['labels'] = data['input_ids'].clone()
|
||||
return data
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
|
||||
batch_size: int, coordinator: DistCoordinator, save_dir: str):
|
||||
save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
|
||||
os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
|
||||
running_states = {
|
||||
'epoch': epoch,
|
||||
'step': step,
|
||||
'sample_start_index': step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, 'running_states.json'))
|
||||
|
||||
|
||||
def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
|
||||
load_dir: str) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, 'model'))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
|
||||
running_states = load_json(os.path.join(load_dir, 'running_states.json'))
|
||||
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
|
||||
|
||||
|
||||
def main():
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
|
||||
default='gemini',
|
||||
help='Choose which plugin to use')
|
||||
parser.add_argument('-d',
|
||||
'--dataset',
|
||||
type=str,
|
||||
default='togethercomputer/RedPajama-Data-1T-Sample',
|
||||
help='Data set path')
|
||||
parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
|
||||
parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
|
||||
parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay')
|
||||
parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps')
|
||||
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
|
||||
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
|
||||
parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
|
||||
parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
|
||||
parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
|
||||
parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
|
||||
parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
|
||||
parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
|
||||
parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
|
||||
args = parser.parse_args()
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# ==============================
|
||||
if coordinator.is_master():
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
|
||||
elif args.plugin == 'gemini_auto':
|
||||
plugin = GeminiPlugin(precision=args.mixed_precision,
|
||||
placement_policy='auto',
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip)
|
||||
elif args.plugin == 'zero2':
|
||||
plugin = LowLevelZeroPlugin(stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip)
|
||||
elif args.plugin == 'zero2_cpu':
|
||||
plugin = LowLevelZeroPlugin(stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip)
|
||||
else:
|
||||
raise ValueError(f'Unknown plugin {args.plugin}')
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ==============================
|
||||
# Initialize Tokenizer, Dataset and Dataloader
|
||||
# ==============================
|
||||
tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
|
||||
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
|
||||
dataset = load_dataset(args.dataset)
|
||||
train_ds = dataset['train']
|
||||
dataloader = prepare_dataloader(train_ds,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
|
||||
|
||||
# ==============================
|
||||
# Initialize Model, Optimizer and LR Scheduler
|
||||
# ==============================
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
init_ctx = LazyInitContext(
|
||||
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
|
||||
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(config)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.flash_attention:
|
||||
assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed'
|
||||
replace_xformers(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer,
|
||||
total_steps=args.num_epochs * len(dataloader),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr)
|
||||
default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
|
||||
optimizer,
|
||||
dataloader=dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
coordinator.print_on_master(
|
||||
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
|
||||
|
||||
# load checkpoint if specified
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
sampler_start_idx = 0
|
||||
if args.load is not None:
|
||||
coordinator.print_on_master('Loading checkpoint')
|
||||
start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
|
||||
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
|
||||
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
# if resume training, set the sampler start index to the correct value
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
with tqdm(enumerate(dataloader),
|
||||
desc=f'Epoch {epoch}',
|
||||
disable=not coordinator.is_master(),
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step) as pbar:
|
||||
for step, batch in pbar:
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
all_reduce_mean(loss)
|
||||
pbar.set_postfix({'loss': loss.item()})
|
||||
if coordinator.is_master():
|
||||
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
|
||||
|
||||
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f'Saving checkpoint')
|
||||
save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
|
||||
args.save_dir)
|
||||
coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
|
||||
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,9 @@
|
|||
colossalai>=0.3.0
|
||||
datasets
|
||||
numpy
|
||||
torch>=1.12.0,<=2.0.0
|
||||
tqdm
|
||||
transformers
|
||||
flash-attn>=2.0.0,<=2.0.5
|
||||
SentencePiece==0.1.99
|
||||
tensorboard==2.14.0
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
|
||||
# TODO: fix this
|
||||
echo "3D parallel for LLaMA-2 is not ready yet"
|
||||
exit 1
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4
|
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -g -x -b 2
|
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p gemini_auto -g -x -b 2
|
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16
|
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16
|
Loading…
Reference in New Issue