diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 5e74c2c4f..d2d02811a 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -265,6 +265,10 @@ def launch_multi_processes(args: Config) -> None: # establish remote connection runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env) + # overwrite master addr when num_nodes > 1 and not specified + if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1": + args.master_addr = active_device_pool.hostinfo_list[0].hostname + # execute distributed launching command for node_id, hostinfo in enumerate(active_device_pool): cmd = get_launch_command(master_addr=args.master_addr, diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py index e83beb8b2..8a8980808 100644 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py @@ -2,7 +2,13 @@ import warnings HAS_MEM_EFF_ATTN = False try: - from xformers.ops.fmha import memory_efficient_attention + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) HAS_MEM_EFF_ATTN = True except ImportError: warnings.warn('please install xformers from https://github.com/facebookresearch/xformers') @@ -16,13 +22,6 @@ if HAS_MEM_EFF_ATTN: from typing import Optional import torch - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) from .utils import SeqLenInfo diff --git a/examples/language/llama/README.md b/examples/language/llama/README.md deleted file mode 100644 index 871804f2c..000000000 --- a/examples/language/llama/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Pretraining LLaMA: best practices for building LLaMA-like base models - -

- -

- -- 65-billion-parameter large model pretraining accelerated by 38% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) -[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) - -> Since the main branch is being updated, in order to maintain the stability of the code, this example is temporarily kept as an [independent branch](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama). diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md new file mode 100644 index 000000000..b64b5d29e --- /dev/null +++ b/examples/language/llama2/README.md @@ -0,0 +1,176 @@ +# Pretraining LLaMA-2: best practices for building LLaMA-2-like base models + +## Dataset + +Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed. + +A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample). + +RedPajama-Data-1T consists of seven data slices: + +| | RedPajama | LLaMA | +|---------------|--------------|---------------| +| CommonCrawl | 878 billion | 852 billion | +| C4 | 175 billion | 190 billion | +| Github | 59 billion | 100 billion | +| Books | 26 billion | 25 billion | +| ArXiv | 28 billion | 33 billion | +| Wikipedia | 24 billion | 25 billion | +| StackExchange | 20 billion | 27 billion | +| Total | 1.2 trillion | 1.25 trillion | + +## Training + +We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps. + +| params | learning rate | batch size | +|--------|---------------|------------| +| 6.7B | 3.0e-4 | 4M | +| 13.0B | 3.0e-4 | 4M | +| 32.5B | 1.5e-4 | 4M | +| 65.2B | 1.5e-4 | 4M | + +## Usage + +### 1. Installation + +Please install the latest ColossalAI from source. + +```bash +CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI +``` + +Then install other dependencies. + +```bash +pip install -r requirements.txt +``` + +Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention. + +### 2. Download the dataset + +The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. + +### 3. Command line arguments + +Yon can use colossalai run to launch multi-nodes training: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +pretrain.py --OTHER_CONFIGURATIONS +``` + +Here is a sample hostfile: + +```text +hostname1 +hostname2 +hostname3 +hostname4 +``` + +Make sure master node can access all nodes (including itself) by ssh without password. + +Here is details about CLI arguments: + +- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported. +- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). +- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. +- Number of epochs: `-e`, `--num_epochs`. The default value is 1. +- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. +- Learning rate: `--lr`. The default value is 3e-4. +- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. +- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000. +- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. +- Max length: `-l`, `--max_length`. The default value is 4096. +- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. +- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. +- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`. +- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. +- Gradient clipping: `--gradient_clipping`. The default value is 1.0. +- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. +- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. + + +### 4. Shell Script Examples + +For your convenience, we provide some shell scripts to run benchmark with various configurations. + +You can find them in `scripts/benchmark_7B` and `scripts/benchmark_70B` directory. The main command should be in the format of: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +benchmark.py --OTHER_CONFIGURATIONS +``` +Here we will show an example of how to run training +llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`. + +#### a. Running environment +This experiment was performed on 4 computing nodes with 32 A800 GPUs in total. The nodes are +connected with RDMA and GPUs within one node are fully connected with NVLink. + +#### b. Running command + +```bash +cd scripts/benchmark_7B +``` + +First, put your host file (`hosts.txt`) in this directory with your real host ip or host name. + +Here is a sample `hosts.txt`: +```text +hostname1 +hostname2 +hostname3 +hostname4 +``` + +Then add environment variables to script if needed. + +Finally, run the following command to start training: + +```bash +bash gemini.sh +``` +#### c. Results +If you run the above command successfully, you will get the following results: +`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`. + + +## Reference +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +```bibtex +@software{openlm2023openllama, + author = {Geng, Xinyang and Liu, Hao}, + title = {OpenLLaMA: An Open Reproduction of LLaMA}, + month = May, + year = 2023, + url = {https://github.com/openlm-research/open_llama} +} +``` + +```bibtex +@software{together2023redpajama, + author = {Together Computer}, + title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset}, + month = April, + year = 2023, + url = {https://github.com/togethercomputer/RedPajama-Data} +} +``` + +```bibtex +@article{touvron2023llama, + title={Llama: Open and efficient foundation language models}, + author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others}, + journal={arXiv preprint arXiv:2302.13971}, + year={2023} +} +``` diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py new file mode 100644 index 000000000..15f76647c --- /dev/null +++ b/examples/language/llama2/attn.py @@ -0,0 +1,83 @@ +from types import MethodType +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv + +SUPPORT_XFORMERS = False +SUPPORT_FLASH2 = False +try: + import xformers.ops as xops + SUPPORT_XFORMERS = True +except ImportError: + pass + +try: + from flash_attn import flash_attn_func + SUPPORT_FLASH2 = True +except ImportError: + pass + +SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2 + + +def llama_flash_attention( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K] + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if SUPPORT_FLASH2: + attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) + else: + attn_output = xops.memory_efficient_attention(query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask()) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def replace_xformers(model: nn.Module): + for module in model.modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(llama_flash_attention, module) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py new file mode 100644 index 000000000..1b947cef9 --- /dev/null +++ b/examples/language/llama2/benchmark.py @@ -0,0 +1,211 @@ +import argparse +import resource +from contextlib import nullcontext + +import torch +from attn import SUPPORT_FLASH, replace_xformers +from data_utils import RandomDataset +from model_utils import format_numel_str, get_model_numel +from performance_evaluator import PerformanceEvaluator +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision +from tqdm import tqdm +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Constants +# ============================== + +MODEL_CONFIGS = { + '7b': + LlamaConfig(max_position_embeddings=4096), + '13b': + LlamaConfig(hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096), + '70b': + LlamaConfig(hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') + parser.add_argument('-p', + '--plugin', + choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'], + default='gemini', + help='Choose which plugin to use') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size') + parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run') + parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore') + parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') + parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') + parser.add_argument('-w', + '--warmup_ratio', + type=float, + default=0.8, + help='warm up ratio of non-model data. Only for gemini-auto') + parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb') + parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers') + parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini') + parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini') + parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini') + parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size') + parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size') + parser.add_argument('--mbs', type=int, default=1) + parser.add_argument('--zero', type=int, default=0) + args = parser.parse_args() + + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + def empty_init(): + pass + + # ============================== + # Initialize Booster + # ============================== + use_empty_init = True + if args.plugin == 'gemini': + plugin = GeminiPlugin(precision='bf16', + shard_param_frac=args.shard_param_frac, + offload_optim_frac=args.offload_optim_frac, + offload_param_frac=args.offload_param_frac) + elif args.plugin == 'gemini_auto': + plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio) + elif args.plugin == 'fsdp': + if use_empty_init: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision(param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16), + param_init_fn=empty_init(), + ) + else: + plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16)) + elif args.plugin == 'fsdp_cpu': + if use_empty_init: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision(param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16), + cpu_offload=CPUOffload(offload_params=True), + param_init_fn=empty_init(), + ) + else: + plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16), + cpu_offload=CPUOffload(offload_params=True)) + elif args.plugin == '3d': + plugin = HybridParallelPlugin(tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + enable_fused_normalization=True, + num_microbatches=args.mbs, + precision='bf16') + elif args.plugin == '3d_cpu': + plugin = HybridParallelPlugin(tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + cpu_offload=True, + enable_fused_normalization=True, + num_microbatches=args.mbs, + initial_scale=2**8, + precision='bf16') + else: + raise ValueError(f'Unknown plugin {args.plugin}') + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + + config = MODEL_CONFIGS[args.config] + dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size, + max_length=args.max_length, + vocab_size=config.vocab_size) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = LazyInitContext( + default_device=get_current_device()) if isinstance(plugin, + (GeminiPlugin, HybridParallelPlugin)) else nullcontext() + + with init_ctx: + model = LlamaForCausalLM(config) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + if args.xformers: + assert SUPPORT_FLASH, 'Use flash attention while xfomers is not installed' + replace_xformers(model) + + model_numel = get_model_numel(model) + coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + performance_evaluator = PerformanceEvaluator(model_numel, + args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size) + + optimizer = HybridAdam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + torch.set_default_dtype(torch.float) + coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master( + f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + + if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + booster.execute_pipeline(data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=False) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + else: + for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) + + performance_evaluator.on_fit_end() + coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + + +if __name__ == '__main__': + main() diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py new file mode 100644 index 000000000..25d0e1bd9 --- /dev/null +++ b/examples/language/llama2/data_utils.py @@ -0,0 +1,119 @@ +import json +import random +from typing import Iterator, Optional + +import numpy as np +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from colossalai.utils import get_current_device + + +class StatefulDistributedSampler(DistributedSampler): + + def __init__(self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index:] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def prepare_dataloader(dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler(dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + +def load_json(file_path: str): + with open(file_path, 'r') as f: + return json.load(f) + + +def save_json(data, file_path: str): + with open(file_path, 'w') as f: + json.dump(data, f, indent=4) + + +class RandomDataset(Dataset): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': self.input_ids[idx] + } diff --git a/examples/language/llama2/model_utils.py b/examples/language/llama2/model_utils.py new file mode 100644 index 000000000..431ff5cfb --- /dev/null +++ b/examples/language/llama2/model_utils.py @@ -0,0 +1,32 @@ +from contextlib import contextmanager + +import torch +import torch.nn as nn + + +@contextmanager +def low_precision_init(target_dtype: torch.dtype = torch.float16): + dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(target_dtype) + yield + finally: + torch.set_default_dtype(dtype) + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f'{numel / B:.2f} B' + elif numel >= M: + return f'{numel / M:.2f} M' + elif numel >= K: + return f'{numel / K:.2f} K' + else: + return f'{numel}' diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py new file mode 100644 index 000000000..711b99c54 --- /dev/null +++ b/examples/language/llama2/performance_evaluator.py @@ -0,0 +1,102 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from torch import Tensor + +from colossalai.cluster import DistCoordinator + + +def divide(x: float, y: float) -> float: + if y == 0: + return float('inf') + elif y == float('inf'): + return float('nan') + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=torch.cuda.current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class Timer: + + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0. + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + assert self.start_time is not None + self.duration += time() - self.start_time + self.start_time = None + + def reset(self) -> None: + self.duration = 0. + + +class PerformanceEvaluator: + """ + Callback for valuate the performance of the model. + Args: + actor_num_params: The number of parameters of the actor model. + critic_num_params: The number of parameters of the critic model. + initial_model_num_params: The number of parameters of the initial model. + reward_model_num_params: The number of parameters of the reward model. + enable_grad_checkpoint: Whether to enable gradient checkpointing. + ignore_episodes: The number of episodes to ignore when calculating the performance. + """ + + def __init__(self, + model_numel: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None) -> None: + self.model_numel = model_numel + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_steps = ignore_steps + + self.coordinator = DistCoordinator() + self.dp_world_size = dp_world_size or self.coordinator.world_size + self.disable: bool = False + self.timer = Timer() + self.num_samples: int = 0 + self.flop: int = 0 + + def on_step_start(self, step: int) -> None: + self.disable = self.ignore_steps > 0 and step < self.ignore_steps + if self.disable: + return + torch.cuda.synchronize() + self.timer.start() + + def on_step_end(self, input_ids: Tensor, **kwargs) -> None: + if self.disable: + return + torch.cuda.synchronize() + self.timer.end() + + batch_size, seq_len = input_ids.shape + + self.num_samples += batch_size + self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) + + def on_fit_end(self) -> None: + avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size) + avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) + mp_world_size = self.coordinator.world_size // self.dp_world_size + avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size + self.coordinator.print_on_master( + f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, ' + f'avg_throughput: {avg_throughput}') + self.coordinator.print_on_master( + f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}') diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py new file mode 100644 index 000000000..b72a30196 --- /dev/null +++ b/examples/language/llama2/pretrain.py @@ -0,0 +1,275 @@ +import argparse +import os +import resource +from contextlib import nullcontext +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from attn import SUPPORT_XFORMERS, replace_xformers +from data_utils import load_json, prepare_dataloader, save_json +from datasets import load_dataset +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.llama.tokenization_llama import LlamaTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +MODEL_CONFIGS = { + '7b': + LlamaConfig(max_position_embeddings=4096), + '13b': + LlamaConfig(hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096), + '70b': + LlamaConfig(hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8), +} + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f'{numel / B:.2f} B' + elif numel >= M: + return f'{numel / M:.2f} M' + elif numel >= K: + return f'{numel / K:.2f} K' + else: + return f'{numel}' + + +def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): + texts = [sample['text'] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + data['labels'] = data['input_ids'].clone() + return data + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int, + batch_size: int, coordinator: DistCoordinator, save_dir: str): + save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}') + os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, 'model'), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler')) + running_states = { + 'epoch': epoch, + 'step': step, + 'sample_start_index': step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, 'running_states.json')) + + +def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, + load_dir: str) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, 'model')) + booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer')) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler')) + running_states = load_json(os.path.join(load_dir, 'running_states.json')) + return running_states['epoch'], running_states['step'], running_states['sample_start_index'] + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') + parser.add_argument('-p', + '--plugin', + choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'], + default='gemini', + help='Choose which plugin to use') + parser.add_argument('-d', + '--dataset', + type=str, + default='togethercomputer/RedPajama-Data-1T-Sample', + help='Data set path') + parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size') + parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') + parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay') + parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps') + parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') + parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') + parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision') + parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval') + parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory') + parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint') + parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping') + parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory') + parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention') + args = parser.parse_args() + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Tensorboard + # ============================== + if coordinator.is_master(): + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == 'gemini': + plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) + elif args.plugin == 'gemini_auto': + plugin = GeminiPlugin(precision=args.mixed_precision, + placement_policy='auto', + initial_scale=2**16, + max_norm=args.grad_clip) + elif args.plugin == 'zero2': + plugin = LowLevelZeroPlugin(stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip) + elif args.plugin == 'zero2_cpu': + plugin = LowLevelZeroPlugin(stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip) + else: + raise ValueError(f'Unknown plugin {args.plugin}') + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Tokenizer, Dataset and Dataloader + # ============================== + tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer') + # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 + tokenizer.pad_token = tokenizer.unk_token + + dataset = load_dataset(args.dataset) + train_ds = dataset['train'] + dataloader = prepare_dataloader(train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length)) + + # ============================== + # Initialize Model, Optimizer and LR Scheduler + # ============================== + config = MODEL_CONFIGS[args.config] + init_ctx = LazyInitContext( + default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + + with init_ctx: + model = LlamaForCausalLM(config) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + if args.flash_attention: + assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed' + replace_xformers(model) + + model_numel = get_model_numel(model) + coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + + optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) + lr_scheduler = CosineAnnealingWarmupLR(optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr) + default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost(model, + optimizer, + dataloader=dataloader, + lr_scheduler=lr_scheduler) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master( + f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + + # load checkpoint if specified + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load is not None: + coordinator.print_on_master('Loading checkpoint') + start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) + coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') + + num_steps_per_epoch = len(dataloader) + # if resume training, set the sampler start index to the correct value + dataloader.sampler.set_start_index(sampler_start_idx) + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch) + with tqdm(enumerate(dataloader), + desc=f'Epoch {epoch}', + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step) as pbar: + for step, batch in pbar: + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + all_reduce_mean(loss) + pbar.set_postfix({'loss': loss.item()}) + if coordinator.is_master(): + writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) + + if args.save_interval > 0 and (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f'Saving checkpoint') + save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator, + args.save_dir) + coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}') + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(0) + start_step = 0 + + coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + + +if __name__ == '__main__': + main() diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt new file mode 100644 index 000000000..3ddf21ffe --- /dev/null +++ b/examples/language/llama2/requirements.txt @@ -0,0 +1,9 @@ +colossalai>=0.3.0 +datasets +numpy +torch>=1.12.0,<=2.0.0 +tqdm +transformers +flash-attn>=2.0.0,<=2.0.5 +SentencePiece==0.1.99 +tensorboard==2.14.0 diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama2/scripts/benchmark_70B/3d.sh new file mode 100644 index 000000000..d50c57042 --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_70B/3d.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# TODO: fix this +echo "3D parallel for LLaMA-2 is not ready yet" +exit 1 + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4 diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini.sh b/examples/language/llama2/scripts/benchmark_70B/gemini.sh new file mode 100644 index 000000000..c80d4d9f2 --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_70B/gemini.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -g -x -b 2 diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh b/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh new file mode 100644 index 000000000..ce3b2f217 --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p gemini_auto -g -x -b 2 diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini.sh b/examples/language/llama2/scripts/benchmark_7B/gemini.sh new file mode 100644 index 000000000..db4968a8d --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_7B/gemini.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16 diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh b/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh new file mode 100644 index 000000000..59ec1c1a7 --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16 diff --git a/examples/language/llama/test_ci.sh b/examples/language/llama2/test_ci.sh similarity index 100% rename from examples/language/llama/test_ci.sh rename to examples/language/llama2/test_ci.sh