[example] update llama example (#5626)

* [plugin] support dp inside for hybrid parallel

* [example] update llama benchmark

* [example] update llama benchmark

* [example] update llama readme

* [example] update llama readme
pull/5631/head
Hongxin Liu 2024-04-23 14:12:20 +08:00 committed by GitHub
parent 862fbaaa62
commit 4de4e31818
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 72 additions and 783 deletions

View File

@ -424,6 +424,7 @@ class GeminiPlugin(DPPluginBase):
)
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
self.dp_size = self.zero_size * self.extra_dp_size
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,

View File

@ -34,7 +34,6 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@ -987,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
) -> None:
super().__init__()
assert (
@ -1034,7 +1034,12 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
@ -1048,7 +1053,7 @@ class HybridParallelPlugin(PipelinePluginBase):
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=PP_AXIS,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
)
@ -1072,13 +1077,13 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -1169,7 +1174,7 @@ class HybridParallelPlugin(PipelinePluginBase):
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
@ -1317,7 +1322,10 @@ class HybridParallelPlugin(PipelinePluginBase):
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
dataset,
num_replicas=self.pg_mesh.size(self.dp_axis),
rank=self.pg_mesh.coordinate(self.dp_axis),
shuffle=shuffle,
)
# Deterministic dataloader

View File

@ -1,4 +1,4 @@
# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models
# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models
### LLaMA2
<p align="center">
@ -16,38 +16,10 @@
- 65-billion-parameter large model pretraining accelerated by 38%
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
## Dataset
Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed.
A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).
RedPajama-Data-1T consists of seven data slices:
| | RedPajama | LLaMA |
|---------------|--------------|---------------|
| CommonCrawl | 878 billion | 852 billion |
| C4 | 175 billion | 190 billion |
| Github | 59 billion | 100 billion |
| Books | 26 billion | 25 billion |
| ArXiv | 28 billion | 33 billion |
| Wikipedia | 24 billion | 25 billion |
| StackExchange | 20 billion | 27 billion |
| Total | 1.2 trillion | 1.25 trillion |
## Training
We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $\beta_1=0.9$ and $\beta_2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps.
| params | learning rate | batch size |
|--------|---------------|------------|
| 6.7B | 3.0e-4 | 4M |
| 13.0B | 3.0e-4 | 4M |
| 32.5B | 1.5e-4 | 4M |
| 65.2B | 1.5e-4 | 4M |
## Usage
> ⚠ This example only has benchmarking script. For training/finetuning, please refer to the [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).
### 1. Installation
Please install the latest ColossalAI from source.
@ -62,52 +34,6 @@ Then install other dependencies.
pip install -r requirements.txt
```
Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention.
### 2. Download the dataset
The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`.
### 3. Command line arguments
You can use `colossalai run` to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
pretrain.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```text
hostname1
hostname2
hostname3
hostname4
```
Make sure master node can access all nodes (including itself) by ssh without password.
Here are the details of the CLI arguments:
- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It supports any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
### 4. Shell Script Examples
For your convenience, we provide some shell scripts to run benchmark with various configurations.
@ -193,40 +119,3 @@ If you run the above command successfully, you will get the following results:
year={2023}
}
```
# Fine-tune Llama2
We also provide an example to fine-tune llama2 in `finetune.py`.
Make sure master node can access all nodes (including itself) by ssh without password.
Here are the details of the CLI arguments:
- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
```shell
torchrun --standalone --nproc_per_node 8 finetune.py \
--plugin "hybrid_parallel" \
--dataset "yizhongw/self_instruct" \
--model_path "/path/llama" \
--task_name "super_natural_instructions" \
--save_dir "/path/output"
```

View File

@ -1 +0,0 @@
../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py

View File

@ -3,14 +3,13 @@ import resource
from contextlib import nullcontext
import torch
from attn import replace_with_flash_attention
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
@ -19,6 +18,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
from examples.language.data_utils import RandomDataset
from examples.language.model_utils import format_numel_str, get_model_numel
from examples.language.performance_evaluator import PerformanceEvaluator
@ -78,6 +78,7 @@ def main():
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
args = parser.parse_args()
colossalai.launch_from_torch({})
@ -86,6 +87,19 @@ def main():
def empty_init():
pass
# ckpt config for LLaMA3-70B on 64 H100 GPUs
ckpt_config = (
PipelineGradientCheckpointConfig(
num_stages=args.pp,
num_model_chunks=1,
num_model_layers=80,
num_layers_per_stage=[19, 20, 20, 21],
num_ckpt_layers_per_stage=[19, 19, 19, 13],
)
if args.custom_ckpt
else None
)
# ==============================
# Initialize Booster
# ==============================
@ -98,6 +112,8 @@ def main():
offload_param_frac=args.offload_param_frac,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@ -106,26 +122,34 @@ def main():
warmup_non_model_data_ratio=args.warmup_ratio,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
)
elif args.plugin == "fsdp":
if use_empty_init:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
),
param_init_fn=empty_init(),
)
else:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
)
)
elif args.plugin == "fsdp_cpu":
if use_empty_init:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
@ -133,7 +157,9 @@ def main():
else:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
)
@ -141,12 +167,13 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style="interleaved",
zero_stage=args.zero,
num_model_chunks=2,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
gradient_checkpoint_config=ckpt_config,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
@ -155,6 +182,7 @@ def main():
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
@ -167,9 +195,12 @@ def main():
# ==============================
# Initialize Dataset and Dataloader
# ==============================
dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
config = MODEL_CONFIGS[args.config]
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
@ -184,14 +215,17 @@ def main():
else nullcontext()
)
init_kwargs = {}
if config.model_type == "chatglm":
init_kwargs["empty_init"] = False
with init_ctx:
model = LlamaForCausalLM(config)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if args.xformers:
replace_with_flash_attention(model)
if config.model_type == "chatglm":
model.transformer.encoder.gradient_checkpointing = True
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

View File

@ -1,313 +0,0 @@
import argparse
import math
import os
import resource
from contextlib import nullcontext
from functools import partial
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from attn import replace_with_flash_attention
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
def get_model_numel(model: nn.Module) -> int:
    """Return the total number of elements over all parameters of ``model``.

    Every tensor yielded by ``model.parameters()`` contributes its
    ``numel()`` to the total; a module without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def format_numel_str(numel: int) -> str:
    """Render an element count as a short human-readable string.

    The count is scaled by the largest power of 1024 it reaches
    (``1024**3`` -> "B", ``1024**2`` -> "M", ``1024`` -> "K") and printed
    with two decimals; counts below 1024 are returned unscaled.
    """
    # Ordered from largest to smallest so the first match is the best unit.
    scales = (
        (1024**3, "B"),
        (1024**2, "M"),
        (1024, "K"),
    )
    for threshold, suffix in scales:
        if numel >= threshold:
            return f"{numel / threshold:.2f} {suffix}"
    return f"{numel}"
def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
    """Collate a batch of prompt/completion samples into model inputs.

    Args:
        batch: iterable of mappings, each providing "prompt" and "completion".
        tokenizer: callable tokenizer; it is invoked unconditionally, so the
            ``None`` default is only a placeholder and must be overridden.
        max_length: length every sequence is padded/truncated to.

    Returns:
        A dict with the tokenizer's outputs moved to CUDA, plus a "labels"
        entry that is a clone of "input_ids".
    """
    texts = []
    for sample in batch:
        texts.append(sample["prompt"] + sample["completion"])

    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    data = {}
    for key, value in encoded.items():
        data[key] = value.cuda()
    # Labels are a copy of the inputs themselves.
    data["labels"] = data["input_ids"].clone()
    return data
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average ``tensor`` across all ranks, in place.

    The tensor is first sum-reduced over the default process group and then
    divided in place by the world size through its ``.data`` view, which is
    what gets returned.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    averaged = tensor.data
    averaged.div_(dist.get_world_size())
    return averaged
def save(
    booster: Booster,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
):
    """Write a training checkpoint under ``save_dir/epoch{epoch}-step{step}``.

    Model and optimizer are saved through the booster with ``shard=True``,
    the LR scheduler is saved next to them, and the master rank additionally
    writes ``running_states.json`` holding the epoch, the step and the sample
    start index (``step * batch_size``).
    """
    ckpt_root = os.path.join(save_dir, f"epoch{epoch}-step{step}")
    model_dir = os.path.join(ckpt_root, "model")
    os.makedirs(model_dir, exist_ok=True)

    booster.save_model(model, model_dir, shard=True)
    booster.save_optimizer(optimizer, os.path.join(ckpt_root, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(ckpt_root, "lr_scheduler"))

    # Only the master rank records the bookkeeping file.
    if not coordinator.is_master():
        return
    running_states = dict(epoch=epoch, step=step, sample_start_index=step * batch_size)
    save_json(running_states, os.path.join(ckpt_root, "running_states.json"))
def load(
    booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    """Restore model, optimizer and LR scheduler from ``load_dir``.

    Returns:
        ``(epoch, step, sample_start_index)`` as read from
        ``running_states.json`` inside ``load_dir``.
    """
    model_path = os.path.join(load_dir, "model")
    booster.load_model(model, model_path)

    optimizer_path = os.path.join(load_dir, "optimizer")
    booster.load_optimizer(optimizer, optimizer_path)

    scheduler_path = os.path.join(load_dir, "lr_scheduler")
    booster.load_lr_scheduler(lr_scheduler, scheduler_path)

    states = load_json(os.path.join(load_dir, "running_states.json"))
    epoch = states["epoch"]
    step = states["step"]
    sample_start_index = states["sample_start_index"]
    return epoch, step, sample_start_index
def _criterion(outputs, inputs):
    """Return the loss stored on ``outputs``; ``inputs`` is accepted but unused."""
    loss = outputs.loss
    return loss
def main():
    """Fine-tune a LLaMA causal language model from the command line.

    Steps, in order: parse CLI arguments, launch the distributed environment,
    build the selected booster plugin, create the model/tokenizer/dataloader,
    build the optimizer and cosine LR scheduler, boost everything, load the
    pretrained weights from ``--model_path``, optionally resume from
    ``--load``, then run the training loop with periodic checkpointing.

    Raises:
        ValueError: if ``--plugin`` names an unknown plugin.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path")
    parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run")
    parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
    parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
    parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
    args = parser.parse_args()

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "gemini":
        plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
        )
    elif args.plugin == "hybrid_parallel":
        # modify the param accordingly, default configuration is for llama2-7b
        plugin = HybridParallelPlugin(
            tp_size=4,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_jit_fused=False,
            zero_stage=0,
            precision="fp32",
            initial_scale=1,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    # Progress/TensorBoard output comes from the master rank, or from the last
    # pipeline stage when pipeline parallelism is active.
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)

    # ==============================
    # Initialize Tensorboard
    # ==============================
    if print_flag:
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

    # ==============================
    # Initialize Model, Optimizer and LR Scheduler
    # ==============================
    config = LlamaConfig.from_pretrained(args.model_path)
    # use lazy init when using GeminiPlugin
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, GeminiPlugin)
        else nullcontext()
    )

    with init_ctx:
        model = LlamaForCausalLM(config)

    # ==============================
    # Initialize Tokenizer, Dataset and Dataloader
    # ==============================
    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
    tokenizer.pad_token = tokenizer.unk_token

    dataset = load_dataset(args.dataset, args.task_name)
    train_ds = dataset["train"]
    dataloader = prepare_dataloader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length),
    )

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
    if args.flash_attention:
        replace_with_flash_attention(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
    total_step = args.num_epochs * len(dataloader)
    # Warm up over 3% of all steps, then decay towards 10% of the base LR.
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr
    )
    # The default dtype is switched only for the duration of boost() and
    # restored to float32 right after.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
    )
    torch.set_default_dtype(torch.float)

    booster.load_model(model, args.model_path)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    # load checkpoint if specified
    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load is not None:
        coordinator.print_on_master("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
        coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")

    num_steps_per_epoch = len(dataloader)

    # if resume training, set the sampler start index to the correct value
    dataloader.sampler.set_start_index(sampler_start_idx)
    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)

        # Iterate over absolute step indices within the epoch. The number of
        # iterations is unchanged (num_steps_per_epoch - start_step), but after
        # a mid-epoch resume `step` now continues from `start_step` instead of
        # restarting at 0, so the logged global step and the step /
        # sample_start_index stored in later checkpoints stay correct.
        with tqdm(
            range(start_step, num_steps_per_epoch),
            desc=f"Epoch {epoch}",
            disable=not print_flag,
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # The pipeline schedule runs forward and backward itself.
                    outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
                    loss = outputs["loss"]
                else:
                    batch = next(dataloader_iter)
                    outputs = model(**batch)
                    loss = outputs[0]
                    booster.backward(loss, optimizer)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                if not use_pipeline:
                    all_reduce_mean(loss)
                if print_flag:
                    pbar.set_postfix({"loss": loss.item()})
                    writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)

                if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving checkpoint")
                    save(
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        args.batch_size,
                        coordinator,
                        args.save_dir,
                    )
                    coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
        dataloader.sampler.set_start_index(0)
        start_step = 0

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

    # Close the TensorBoard writer on the rank that created it so pending
    # events are flushed to disk.
    if print_flag:
        writer.close()
# Run the fine-tuning entry point only when this file is executed as a script.
if __name__ == "__main__":
main()

View File

@ -1,328 +0,0 @@
import argparse
import os
import resource
from contextlib import nullcontext
from functools import partial
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from attn import replace_with_flash_attention
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
# Preset LLaMA configurations keyed by a size label. Each entry is a
# LlamaConfig built with only the keyword overrides shown; every field not
# listed keeps whatever default LlamaConfig provides.
MODEL_CONFIGS = {
# "7b": only the maximum position embeddings are overridden.
"7b": LlamaConfig(max_position_embeddings=4096),
# "13b": wider hidden/intermediate sizes, 40 layers and 40 attention heads.
"13b": LlamaConfig(
hidden_size=5120,
intermediate_size=13824,
num_hidden_layers=40,
num_attention_heads=40,
max_position_embeddings=4096,
),
# "70b": 80 layers, 64 attention heads and only 8 key/value heads.
"70b": LlamaConfig(
hidden_size=8192,
intermediate_size=28672,
num_hidden_layers=80,
num_attention_heads=64,
max_position_embeddings=4096,
num_key_value_heads=8,
),
}
def get_model_numel(model: nn.Module) -> int:
    """Count the elements of every parameter tensor in ``model``.

    Returns 0 for a module that owns no parameters.
    """
    element_count = 0
    for parameter in model.parameters():
        element_count += parameter.numel()
    return element_count
def format_numel_str(numel: int) -> str:
    """Format an element count with a 1024-based "K"/"M"/"B" suffix.

    The largest unit the count reaches is used and the scaled value is shown
    with two decimals; values below 1024 are printed as-is.
    """
    # Checked from the largest unit down; the first one reached wins.
    for exponent, suffix in ((3, "B"), (2, "M"), (1, "K")):
        unit = 1024**exponent
        if numel >= unit:
            return f"{numel / unit:.2f} {suffix}"
    return f"{numel}"
def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
    """Collate a batch of raw-text samples into model inputs.

    Args:
        batch: iterable of mappings, each providing a "text" entry.
        tokenizer: callable tokenizer; it is invoked unconditionally, so the
            ``None`` default is only a placeholder and must be overridden.
        max_length: length every sequence is padded/truncated to.

    Returns:
        A dict with the tokenizer's outputs moved to CUDA, plus a "labels"
        entry that is a clone of "input_ids".
    """
    texts = []
    for sample in batch:
        texts.append(sample["text"])

    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    data = {}
    for key, value in encoded.items():
        data[key] = value.cuda()
    # Labels are a copy of the inputs themselves.
    data["labels"] = data["input_ids"].clone()
    return data
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Replace ``tensor`` with its mean over all ranks and return it.

    A sum all-reduce over the default process group is followed by an
    in-place division by the world size on the tensor's ``.data`` view.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    reduced = tensor.data
    reduced.div_(dist.get_world_size())
    return reduced
def save(
    booster: Booster,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
):
    """Persist a checkpoint in ``save_dir/epoch{epoch}-step{step}``.

    The booster writes the sharded model and optimizer plus the LR scheduler;
    the master rank also writes ``running_states.json`` with the epoch, step
    and sample start index (``step * batch_size``).
    """
    target_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
    model_dir = os.path.join(target_dir, "model")
    os.makedirs(model_dir, exist_ok=True)

    booster.save_model(model, model_dir, shard=True)
    booster.save_optimizer(optimizer, os.path.join(target_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(target_dir, "lr_scheduler"))

    # The bookkeeping file is written by the master rank alone.
    if not coordinator.is_master():
        return
    running_states = dict(epoch=epoch, step=step, sample_start_index=step * batch_size)
    save_json(running_states, os.path.join(target_dir, "running_states.json"))
def load(
    booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    """Restore model, optimizer and LR scheduler from a checkpoint directory.

    Args:
        booster: Booster used to load each component.
        model: Model to load state into.
        optimizer: Optimizer to load state into.
        lr_scheduler: LR scheduler to load state into.
        load_dir: Checkpoint directory containing ``model``, ``optimizer``,
            ``lr_scheduler`` and ``running_states.json``.

    Returns:
        ``(epoch, step, sample_start_index)`` as read from
        ``running_states.json``.
    """
    model_path = os.path.join(load_dir, "model")
    optimizer_path = os.path.join(load_dir, "optimizer")
    scheduler_path = os.path.join(load_dir, "lr_scheduler")

    booster.load_model(model, model_path)
    booster.load_optimizer(optimizer, optimizer_path)
    booster.load_lr_scheduler(lr_scheduler, scheduler_path)

    states = load_json(os.path.join(load_dir, "running_states.json"))
    epoch = states["epoch"]
    step = states["step"]
    sample_start_index = states["sample_start_index"]
    return epoch, step, sample_start_index
def _criterion(outputs, inputs):
    """Loss callback handed to ``booster.execute_pipeline``.

    Args:
        outputs: Model outputs exposing a ``loss`` attribute.
        inputs: Unused; present to match the callback signature.

    Returns:
        ``outputs.loss`` unchanged.
    """
    loss = outputs.loss
    return loss
def _parse_args():
    """Define the command-line options of the pretraining script and parse them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path"
    )
    parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    # NOTE: the option name is misspelled ("weigth"); it is kept as is because
    # existing launch commands and ``args.weigth_decay`` depend on it.
    parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
    parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
    parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
    return parser.parse_args()


def _build_plugin(args):
    """Create the booster plugin named by ``args.plugin``.

    Raises:
        ValueError: If ``args.plugin`` is not one of the known plugin names.
    """
    if args.plugin == "gemini":
        return GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
    if args.plugin == "gemini_auto":
        return GeminiPlugin(
            precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
        )
    if args.plugin == "zero2":
        return LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
        )
    if args.plugin == "zero2_cpu":
        return LowLevelZeroPlugin(
            stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
        )
    if args.plugin == "hybrid_parallel":
        # modify the param accordingly, default configuration is for llama2-7b
        return HybridParallelPlugin(
            tp_size=4,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_jit_fused=False,
            zero_stage=0,
            precision=args.mixed_precision,
            initial_scale=1,
        )
    raise ValueError(f"Unknown plugin {args.plugin}")


def main():
    """Entry point of the LLaMA pretraining example.

    Parses the command line, launches the distributed environment, builds the
    selected booster plugin, prepares the tokenizer / dataset / dataloader,
    creates the model, optimizer and LR scheduler, optionally resumes from a
    checkpoint, and runs the training loop with periodic checkpointing and
    TensorBoard loss logging.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    args = _parse_args()

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    plugin = _build_plugin(args)
    booster = Booster(plugin=plugin)

    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    # Progress / loss reporting happens on the master process, or on the last
    # pipeline stage when pipeline parallelism is in use.
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)

    # ==============================
    # Initialize Tensorboard
    # ==============================
    writer = None
    if print_flag:
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

    # ==============================
    # Initialize Tokenizer, Dataset and Dataloader
    # ==============================
    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
    tokenizer.pad_token = tokenizer.unk_token

    dataset = load_dataset(args.dataset)
    train_ds = dataset["train"]
    dataloader = prepare_dataloader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length),
    )

    # ==============================
    # Initialize Model, Optimizer and LR Scheduler
    # ==============================
    config = MODEL_CONFIGS[args.config]
    # use lazy init when using GeminiPlugin
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, GeminiPlugin)
        else nullcontext()
    )

    with init_ctx:
        model = LlamaForCausalLM(config)

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
    if args.flash_attention:
        replace_with_flash_attention(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr
    )

    # Boost under the mixed-precision default dtype, and always restore the
    # float32 default afterwards, even if boosting raises.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    try:
        model, optimizer, _, dataloader, lr_scheduler = booster.boost(
            model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
        )
    finally:
        torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    # load checkpoint if specified
    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load is not None:
        coordinator.print_on_master("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
        coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")

    num_steps_per_epoch = len(dataloader)

    # if resume training, set the sampler start index to the correct value
    dataloader.sampler.set_start_index(sampler_start_idx)
    try:
        for epoch in range(start_epoch, args.num_epochs):
            dataloader.sampler.set_epoch(epoch)
            dataloader_iter = iter(dataloader)

            with tqdm(
                range(start_step, num_steps_per_epoch),
                desc=f"Epoch {epoch}",
                disable=not print_flag,
                total=num_steps_per_epoch,
                initial=start_step,
            ) as pbar:
                for step in pbar:
                    if use_pipeline:
                        # No explicit backward call on this path; presumably
                        # execute_pipeline runs backward itself (confirm
                        # against the Booster API).
                        outputs = booster.execute_pipeline(
                            dataloader_iter, model, _criterion, optimizer, return_loss=True
                        )
                        loss = outputs["loss"]
                    else:
                        batch = next(dataloader_iter)
                        outputs = model(**batch)
                        loss = outputs[0]
                        booster.backward(loss, optimizer)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                    if not use_pipeline:
                        all_reduce_mean(loss)
                    if print_flag:
                        pbar.set_postfix({"loss": loss.item()})
                        writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)

                    if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
                        coordinator.print_on_master("Saving checkpoint")
                        save(
                            booster,
                            model,
                            optimizer,
                            lr_scheduler,
                            epoch,
                            step + 1,
                            args.batch_size,
                            coordinator,
                            args.save_dir,
                        )
                        coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
            # the continue epochs are not resumed, so we need to reset the sampler start index and start step
            dataloader.sampler.set_start_index(0)
            start_step = 0
    finally:
        # Close the TensorBoard writer so buffered loss events reach disk even
        # when training stops with an exception.
        if writer is not None:
            writer.close()

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()

View File

@ -1,9 +1,8 @@
colossalai>=0.3.2
colossalai>=0.3.6
datasets
numpy
torch>=1.12.0,<=2.0.0
tqdm
transformers
flash-attn>=2.0.0,<=2.0.5
flash-attn>=2.0.0
SentencePiece==0.1.99
tensorboard==2.14.0