mirror of https://github.com/hpcaitech/ColossalAI
[example] update llama example (#5626)
* [plugin] support dp inside for hybrid parallel
* [example] update llama benchmark
* [example] update llama benchmark
* [example] update llama readme
* [example] update llama readme
parent 862fbaaa62
commit 4de4e31818
@@ -424,6 +424,7 @@ class GeminiPlugin(DPPluginBase):
        )
        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
        self.dp_size = self.zero_size * self.extra_dp_size

        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
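A quick sanity check of the new `dp_size` bookkeeping, with made-up sizes (this sketch assumes the world size factors as zero_size × extra_dp_size × tp_size, which is an assumption of the sketch, not something the diff states):

```python
# Hypothetical sizes for illustration only; none of these come from the plugin itself.
world_size = 32
tp_size = 2
extra_dp_size = 2
zero_size = world_size // (tp_size * extra_dp_size)  # 8 ranks per ZeRO group
dp_size = zero_size * extra_dp_size                   # 16, mirroring the added line
assert dp_size * tp_size == world_size
```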
@@ -34,7 +34,6 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase

DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@@ -987,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
        make_vocab_size_divisible_by: int = 64,
        dp_outside: bool = True,
    ) -> None:
        super().__init__()
        assert (
@@ -1034,7 +1034,12 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
        if dp_outside:
            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
        else:
            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy
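`dp_outside` only decides which mesh axis the data-parallel dimension occupies, which in turn changes which ranks end up in the same data-parallel group. A standalone sketch in plain Python (not the ColossalAI API) for 8 ranks with dp=2, pp=2, tp=2:

```python
import itertools

def groups_along(shape, axis):
    """Ranks of a row-major mesh, grouped so members differ only along `axis`."""
    groups = {}
    for coord in itertools.product(*(range(s) for s in shape)):
        rank = 0
        for size, c in zip(shape, coord):
            rank = rank * size + c
        key = tuple(c for i, c in enumerate(coord) if i != axis)
        groups.setdefault(key, []).append(rank)
    return sorted(groups.values())

# dp outside: mesh (dp, pp, tp); each dp group strides across a whole pp*tp block
print(groups_along((2, 2, 2), axis=0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
# dp inside (pp outside): mesh (pp, dp, tp); dp partners sit next to each other
print(groups_along((2, 2, 2), axis=1))  # [[0, 2], [1, 3], [4, 6], [5, 7]]
```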
@@ -1048,7 +1053,7 @@ class HybridParallelPlugin(PipelinePluginBase):
            assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=PP_AXIS,
                pipeline_axis=self.pp_axis,
                enable_interleave=pp_style == "interleaved",
                num_model_chunks=num_model_chunks,
            )
@@ -1072,13 +1077,13 @@ class HybridParallelPlugin(PipelinePluginBase):
            else:
                raise NotImplementedError()

        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
        self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
        self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
        self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
        if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
            self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
            self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
        else:
            self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
            self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
@@ -1169,7 +1174,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                and self.sequence_parallelism_mode == "all_to_all"
            )
            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
                dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
            else:
                dp_group = self.dp_group
            model = HybridParallelModule(
@@ -1317,7 +1322,10 @@ class HybridParallelPlugin(PipelinePluginBase):
        _kwargs = kwargs.copy()
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        sampler = distributed_sampler_cls(
            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
            dataset,
            num_replicas=self.pg_mesh.size(self.dp_axis),
            rank=self.pg_mesh.coordinate(self.dp_axis),
            shuffle=shuffle,
        )

        # Deterministic dataloader
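The reshaped call maps directly onto PyTorch's `DistributedSampler`; a minimal sketch with a toy dataset and hypothetical dp coordinates standing in for the mesh lookups:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))
dp_world_size, dp_rank = 2, 0  # stand-ins for pg_mesh.size(dp_axis) / pg_mesh.coordinate(dp_axis)

sampler = DistributedSampler(
    dataset,
    num_replicas=dp_world_size,  # one shard of the data per data-parallel rank
    rank=dp_rank,                # this rank's position inside the dp group
    shuffle=True,
)
sampler.set_epoch(0)  # reseed the shuffle deterministically each epoch
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
print([batch[0].tolist() for batch in loader])  # 8 of the 16 samples, in shuffled order
```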
@@ -1,4 +1,4 @@
# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models
# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models

### LLaMA2
<p align="center">
@@ -16,38 +16,10 @@
- 65-billion-parameter large model pretraining accelerated by 38%
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

## Dataset

Different from the original LLaMA, we use the [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed.

A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).

RedPajama-Data-1T consists of seven data slices:

|               | RedPajama    | LLaMA         |
|---------------|--------------|---------------|
| CommonCrawl   | 878 billion  | 852 billion   |
| C4            | 175 billion  | 190 billion   |
| Github        | 59 billion   | 100 billion   |
| Books         | 26 billion   | 25 billion    |
| ArXiv         | 28 billion   | 33 billion    |
| Wikipedia     | 24 billion   | 25 billion    |
| StackExchange | 20 billion   | 27 billion    |
| Total         | 1.2 trillion | 1.25 trillion |

## Training

We follow the hyperparameter settings of the original LLaMA paper: AdamW with $\beta_1=0.9$ and $\beta_2=0.95$, a cosine learning rate schedule whose final learning rate equals 10% of the maximal learning rate, a weight decay of 0.1, gradient clipping of 1.0, and 2,000 warmup steps.

| params | learning rate | batch size |
|--------|---------------|------------|
| 6.7B   | 3.0e-4        | 4M         |
| 13.0B  | 3.0e-4        | 4M         |
| 32.5B  | 1.5e-4        | 4M         |
| 65.2B  | 1.5e-4        | 4M         |

## Usage

> ⚠ This example only contains a benchmarking script. For training/finetuning, please refer to the [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).

### 1. Installation

Please install the latest ColossalAI from source.
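The removed Training section above describes the schedule only in prose; below is a rough sketch of that recipe in plain PyTorch (the 2,000 warmup steps, the 10%-of-peak floor, the betas, weight decay, and gradient clipping come from the text above; everything else, including the model and step count, is illustrative):

```python
import math
import torch

model = torch.nn.Linear(8, 8)  # placeholder model, illustration only
peak_lr, warmup_steps, total_steps = 3.0e-4, 2000, 10_000

optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, betas=(0.9, 0.95), weight_decay=0.1)

def lr_multiplier(step: int) -> float:
    # linear warmup, then cosine decay down to 10% of the peak learning rate
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.1 + 0.45 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)

for step in range(total_steps):
    # forward/backward omitted; clip gradients to 1.0 before stepping in real training
    optimizer.step()
    scheduler.step()
```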
@@ -62,52 +34,6 @@ Then install other dependencies.
pip install -r requirements.txt
```

Additionally, we recommend using torch 1.13.1. We have tested our code on torch 1.13.1 and found it compatible with flash attention.

### 2. Download the dataset

The dataset can be downloaded automatically via `huggingface/datasets`. You can specify the dataset path with `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`.

### 3. Command line arguments

You can use `colossalai run` to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
   pretrain.py --OTHER_CONFIGURATIONS
```

Here is a sample hostfile:

```text
hostname1
hostname2
hostname3
hostname4
```

Make sure the master node can access all nodes (including itself) via passwordless SSH.

Here are the details of the CLI arguments:

- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1; `7b`, `13b`, and `70b` are supported for LLaMA-2.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It supports any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed; enable it when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (in steps) at which checkpoints are saved. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. Requires `flash-attn` to be installed. The default value is `False`. This accelerates training while saving memory; we recommend always enabling it.


### 4. Shell Script Examples

For your convenience, we provide shell scripts to run the benchmark with various configurations.
@@ -193,40 +119,3 @@ If you run the above command successfully, you will get the following results:
  year={2023}
}
```


# Fine-tune Llama2

We also provide an example of fine-tuning Llama2 in `finetune.py`. The collate step it uses is sketched after this list.

Make sure the master node can access all nodes (including itself) via passwordless SSH.

Here are the details of the CLI arguments:

- Pretrained checkpoint path: `--model_path`. The path of your model checkpoint; it can be a local directory or a Hugging Face tag.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
- Task name: `--task_name`. The task to fine-tune on; it also determines which subset of the dataset is loaded. The default value is `super_natural_instructions`.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed; enable it when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (in steps) at which checkpoints are saved. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. Requires `flash-attn` to be installed. The default value is `False`. This accelerates training while saving memory; we recommend always enabling it.


```shell
torchrun --standalone --nproc_per_node 8 finetune.py \
    --plugin "hybrid_parallel" \
    --dataset "yizhongw/self_instruct" \
    --model_path "/path/llama" \
    --task_name "super_natural_instructions" \
    --save_dir "/path/output"
```
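For reference, `finetune.py` collates `yizhongw/self_instruct`-style samples by concatenating prompt and completion and copying the input ids as labels; a condensed, CPU-only sketch of that step (the device transfer in the original script is omitted here):

```python
from transformers import AutoTokenizer

# same tokenizer the example scripts load; padding follows the FastChat convention
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token

def tokenize_batch_for_finetune(batch, max_length: int = 2048):
    texts = [sample["prompt"] + sample["completion"] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
    data["labels"] = data["input_ids"].clone()  # causal LM loss over the whole sequence
    return data

batch = [{"prompt": "Translate to French: cat", "completion": " chat"}]
print(tokenize_batch_for_finetune(batch)["input_ids"].shape)  # torch.Size([1, 2048])
```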
@@ -1 +0,0 @@
../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -3,14 +3,13 @@ import resource
from contextlib import nullcontext

import torch
from attn import replace_with_flash_attention
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

import colossalai
from colossalai.accelerator import get_accelerator
@@ -19,6 +18,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
from examples.language.data_utils import RandomDataset
from examples.language.model_utils import format_numel_str, get_model_numel
from examples.language.performance_evaluator import PerformanceEvaluator
@@ -78,6 +78,7 @@ def main():
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
    args = parser.parse_args()

    colossalai.launch_from_torch({})
@@ -86,6 +87,19 @@ def main():
    def empty_init():
        pass

    # ckpt config for LLaMA3-70B on 64 H100 GPUs
    ckpt_config = (
        PipelineGradientCheckpointConfig(
            num_stages=args.pp,
            num_model_chunks=1,
            num_model_layers=80,
            num_layers_per_stage=[19, 20, 20, 21],
            num_ckpt_layers_per_stage=[19, 19, 19, 13],
        )
        if args.custom_ckpt
        else None
    )

    # ==============================
    # Initialize Booster
    # ==============================
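A small sanity check of the hard-coded layer layout above (plain arithmetic, not the ColossalAI API):

```python
num_model_layers = 80
num_layers_per_stage = [19, 20, 20, 21]       # transformer layers assigned to each pipeline stage
num_ckpt_layers_per_stage = [19, 19, 19, 13]  # layers recomputed (checkpointed) on each stage

assert sum(num_layers_per_stage) == num_model_layers
assert all(ckpt <= total for ckpt, total in zip(num_ckpt_layers_per_stage, num_layers_per_stage))
print(sum(num_ckpt_layers_per_stage) / num_model_layers)  # 0.875 of layers are recomputed
```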
@@ -98,6 +112,8 @@ def main():
            offload_param_frac=args.offload_param_frac,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
@@ -106,26 +122,34 @@ def main():
            warmup_non_model_data_ratio=args.warmup_ratio,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
        )
    elif args.plugin == "fsdp":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
            )
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
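The FSDP branches repeat the same fp16 policy; a hedged refactoring sketch that builds it once, using only the `torch.distributed.fsdp` types the script already imports (`build_fsdp_kwargs` and `fp16_mixed_precision` are hypothetical helper names, not part of the script):

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision

def fp16_mixed_precision() -> MixedPrecision:
    # cast parameters, gradient reduction, and buffers to fp16 under FSDP
    return MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

def build_fsdp_kwargs(cpu_offload: bool = False) -> dict:
    kwargs = {"mixed_precision": fp16_mixed_precision()}
    if cpu_offload:
        kwargs["cpu_offload"] = CPUOffload(offload_params=True)
    return kwargs

# e.g. plugin = TorchFSDPPlugin(**build_fsdp_kwargs(cpu_offload=True))
print(build_fsdp_kwargs(cpu_offload=True))
```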
@@ -133,7 +157,9 @@ def main():
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
            )
@@ -141,12 +167,13 @@ def main():
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style="interleaved",
            zero_stage=args.zero,
            num_model_chunks=2,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            gradient_checkpoint_config=ckpt_config,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
@@ -155,6 +182,7 @@ def main():
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
@@ -167,9 +195,12 @@ def main():
    # ==============================
    # Initialize Dataset and Dataloader
    # ==============================
    dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
    dp_size = getattr(plugin, "dp_size", coordinator.world_size)

    config = MODEL_CONFIGS[args.config]
    if args.config in MODEL_CONFIGS:
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
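The random dataset is sized so that every data-parallel rank can draw `num_steps` full batches; a quick check of that arithmetic with made-up values:

```python
batch_size, num_steps, dp_size = 2, 10, 4       # per-GPU batch, benchmark steps, data-parallel ranks
num_samples = batch_size * num_steps * dp_size  # 80 samples generated in total
assert (num_samples // dp_size) // batch_size == num_steps  # each dp rank gets num_steps full batches
```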
@@ -184,14 +215,17 @@ def main():
        else nullcontext()
    )

    init_kwargs = {}
    if config.model_type == "chatglm":
        init_kwargs["empty_init"] = False

    with init_ctx:
        model = LlamaForCausalLM(config)
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    if args.xformers:
        replace_with_flash_attention(model)
    if config.model_type == "chatglm":
        model.transformer.encoder.gradient_checkpointing = True

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
@ -1,313 +0,0 @@
|
|||
import argparse
|
||||
import math
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from attn import replace_with_flash_attention
|
||||
from data_utils import load_json, prepare_dataloader, save_json
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
|
||||
|
||||
|
||||
def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
|
||||
texts = [sample["prompt"] + sample["completion"] for sample in batch]
|
||||
data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
tensor = tensor.data
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def save(
|
||||
booster: Booster,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
save_dir: str,
|
||||
):
|
||||
save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load(
|
||||
booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
|
||||
) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
||||
running_states = load_json(os.path.join(load_dir, "running_states.json"))
|
||||
return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
return outputs.loss
|
||||
|
||||
|
||||
def main():
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plugin",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
|
||||
default="gemini",
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path")
|
||||
parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run")
|
||||
parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
|
||||
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
|
||||
parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
|
||||
parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
|
||||
parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
|
||||
parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
|
||||
parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
|
||||
)
|
||||
elif args.plugin == "hybrid_parallel":
|
||||
# modify the param accordingly, default configuration is for llama2-7b
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=4,
|
||||
pp_size=2,
|
||||
num_microbatches=None,
|
||||
microbatch_size=1,
|
||||
enable_jit_fused=False,
|
||||
zero_stage=0,
|
||||
precision="fp32",
|
||||
initial_scale=1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# ==============================
|
||||
if print_flag:
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
# ==============================
|
||||
# Initialize Model, Optimizer and LR Scheduler
|
||||
# ==============================
|
||||
|
||||
config = LlamaConfig.from_pretrained(args.model_path)
|
||||
# use lazy init when using GeminiPlugin
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_accelerator().get_current_device())
|
||||
if isinstance(plugin, GeminiPlugin)
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(config)
|
||||
|
||||
# ==============================
|
||||
# Initialize Tokenizer, Dataset and Dataloader
|
||||
# ==============================
|
||||
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
|
||||
dataset = load_dataset(args.dataset, args.task_name)
|
||||
train_ds = dataset["train"]
|
||||
dataloader = prepare_dataloader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length),
|
||||
)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.flash_attention:
|
||||
replace_with_flash_attention(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
|
||||
total_step = args.num_epochs * len(dataloader)
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr
|
||||
)
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
booster.load_model(model, args.model_path)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
||||
)
|
||||
|
||||
# load checkpoint if specified
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
sampler_start_idx = 0
|
||||
if args.load is not None:
|
||||
coordinator.print_on_master("Loading checkpoint")
|
||||
start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
|
||||
coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")
|
||||
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
|
||||
# if resume training, set the sampler start index to the correct value
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
step_nums = num_steps_per_epoch - start_step
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
with tqdm(
|
||||
range(step_nums),
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not print_flag,
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
if use_pipeline:
|
||||
outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
|
||||
loss = outputs["loss"]
|
||||
else:
|
||||
batch = next(dataloader_iter)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
booster.backward(loss, optimizer)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if not use_pipeline:
|
||||
all_reduce_mean(loss)
|
||||
if print_flag:
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
|
||||
|
||||
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f"Saving checkpoint")
|
||||
save(
|
||||
booster,
|
||||
model,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
epoch,
|
||||
step + 1,
|
||||
args.batch_size,
|
||||
coordinator,
|
||||
args.save_dir,
|
||||
)
|
||||
coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,328 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from attn import replace_with_flash_attention
|
||||
from data_utils import load_json, prepare_dataloader, save_json
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
"7b": LlamaConfig(max_position_embeddings=4096),
|
||||
"13b": LlamaConfig(
|
||||
hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=40,
|
||||
max_position_embeddings=4096,
|
||||
),
|
||||
"70b": LlamaConfig(
|
||||
hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
max_position_embeddings=4096,
|
||||
num_key_value_heads=8,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
|
||||
|
||||
|
||||
def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
|
||||
texts = [sample["text"] for sample in batch]
|
||||
data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
tensor = tensor.data
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def save(
|
||||
booster: Booster,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
save_dir: str,
|
||||
):
|
||||
save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load(
|
||||
booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
|
||||
) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
||||
running_states = load_json(os.path.join(load_dir, "running_states.json"))
|
||||
return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
return outputs.loss
|
||||
|
||||
|
||||
def main():
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plugin",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
|
||||
default="gemini",
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path"
|
||||
)
|
||||
parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps")
|
||||
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
|
||||
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
|
||||
parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
|
||||
parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
|
||||
parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
|
||||
parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
|
||||
parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
|
||||
)
|
||||
elif args.plugin == "hybrid_parallel":
|
||||
# modify the param accordingly, default configuration is for llama2-7b
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=4,
|
||||
pp_size=2,
|
||||
num_microbatches=None,
|
||||
microbatch_size=1,
|
||||
enable_jit_fused=False,
|
||||
zero_stage=0,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# ==============================
|
||||
if print_flag:
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
# ==============================
|
||||
# Initialize Tokenizer, Dataset and Dataloader
|
||||
# ==============================
|
||||
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
|
||||
dataset = load_dataset(args.dataset)
|
||||
train_ds = dataset["train"]
|
||||
dataloader = prepare_dataloader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length),
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Initialize Model, Optimizer and LR Scheduler
|
||||
# ==============================
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
# use lazy init when using GeminiPlugin
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_accelerator().get_current_device())
|
||||
if isinstance(plugin, GeminiPlugin)
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(config)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.flash_attention:
|
||||
replace_with_flash_attention(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr
|
||||
)
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
||||
)
|
||||
|
||||
# load checkpoint if specified
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
sampler_start_idx = 0
|
||||
if args.load is not None:
|
||||
coordinator.print_on_master("Loading checkpoint")
|
||||
start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
|
||||
coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")
|
||||
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
|
||||
# if resume training, set the sampler start index to the correct value
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
with tqdm(
|
||||
range(start_step, num_steps_per_epoch),
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not print_flag,
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
if use_pipeline:
|
||||
outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
|
||||
loss = outputs["loss"]
|
||||
else:
|
||||
batch = next(dataloader_iter)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
booster.backward(loss, optimizer)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if not use_pipeline:
|
||||
all_reduce_mean(loss)
|
||||
if print_flag:
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
|
||||
|
||||
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f"Saving checkpoint")
|
||||
save(
|
||||
booster,
|
||||
model,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
epoch,
|
||||
step + 1,
|
||||
args.batch_size,
|
||||
coordinator,
|
||||
args.save_dir,
|
||||
)
|
||||
coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -1,9 +1,8 @@
colossalai>=0.3.2
colossalai>=0.3.6
datasets
numpy
torch>=1.12.0,<=2.0.0
tqdm
transformers
flash-attn>=2.0.0,<=2.0.5
flash-attn>=2.0.0
SentencePiece==0.1.99
tensorboard==2.14.0