initial commit: add colossal llama 2 (#4784)

pull/4786/head^2
Tong Li 2023-09-24 23:12:26 +08:00 committed by GitHub
parent 4146f1c0ce
commit 74aa7d964a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 2162 additions and 2 deletions

View File

@ -0,0 +1,377 @@
<div align="center">
<h1>
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/colossalllam2.jpg?raw=true" width=800/>
</h1>
</div>
## Table of Contents
- [News](#news)
- [Colossal-LLaMA-2-7B](#colossal-llama-2-7b)
- [Performance Evaluation](#performance-evaluation)
- [Examples](#examples)
- [Training Logs](#training-logs)
- [Import from Transformers](#import-from-transformers)
- [Usage](#usage)
- [Install](#install)
- [How to run](#how-to-run)
- [Technical Insight](#technical-insights)
- [Data](#data)
- [Tokenizer](#tokenizer)
- [Training Strategy](#training-strategy)
- [Citations](#citations)
## News
* [2023/09] 🔥 TODO We released **Colossal-LLaMA-2-7B-base** based on LLaMA-2. [Download weights](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base).
## Colossal-LLaMA-2-7B
The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks.
Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others.
### Performance Evaluation
We conducted comprehensive evaluation on 4 dataset and compare our Colossal-Llama-2-7b-base model with various models.
* We use 5-shot for MMLU and calculate scores based on the logits of first predicted token.
* We use 5-shot for CMMLU and calculate scores based on the logits of first predicted token.
* We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combination metric of exact match and the logits of first predicted token. If any of the exact match or logits of first predicted token is correct, the model will get the score.
* We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of first predicted token.
The generation config for all dataset is greedy search.
* We also provided CEval scores from its lastest leaderboard or the official repository of the model.
| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: |
| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
| | | | | | | | | |
| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
| | | | | | | | | |
| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
> The score in parentheses corresponds to the scores in the official repository of the model.
>
> We use zero-shot for ChatGLM models.
>
> Qwen-7B is now inaccessible in Hugging Face, we are using the latest version of it before it was made inaccessible. Only for dataset MMLU, the prompt would be "xxx Answer:"(remove the space after ":") and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models. For example, the logits over " A" can be `-inf` and softmax would be exact `0`.
>
> For other models and other dataset, we calculate logits over "A", "B", "C" and "D".
❗️ More details of the evaluation methods and reproduction of the results, please refer to [TODO: ColossalEval]().
### Examples
| Question Type | Question | <center>Colossal-LLaMA-2-7b-base</center> |
| :------: | :-----------------------: | :----------- |
| Knowledge Quiz<br>(Tourism) | 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋3-4月份到武汉赏樱10月份到武汉看黄陂枫树。武汉简称“汉”别称江城是湖北省省会中部六省唯一的副省级市特大城市中国中部地区的中心城市全国重要的工业基地、科教基地和综合交通枢纽。 |
| Knowledge Quiz<br>(Computer Science) | 在计算机中什么是RAM与ROM有什么区别 | RAM是随机存取内存可以随时读写但是断电之后会丢失数据。存储数据比较快但是需要不断的进行通电供电。ROM是只读内存只能读取数据不能写入数据。断电之后存储的数据也不会丢失但是需要不断进行通电。两者的价格和性能上会有区别而且在使用的时候RAM需要不断的供电而ROM断电数据不丢失。 |
| Knowledge Quiz<br>(Literature) | 古代文学与现代文学有哪些区别和共通之处? | 古代文学与现代文学有三点区别,现代文学更加通俗易懂,具有可读性,而古代文学相对来说更加注重形式,缺乏可读性。现代文学和古代文学,同属于中华文学大花园中的一枝花蕾。但是它们由于历史的原因,有着不同的特点,具体表现在,现代文学有更丰富的描写方法,更强烈的主观色彩和更突出的抒情色彩,而古代文学的主要特征是在创作风格上的多样化。 |
| Emotional Analysis | 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 |
| Translation | 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. |
| Logical Reasoning | All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:<br><br>1. Determine the weight of 1 book.<br>2. Multiply the weight of 1 book by the number of books.<br><br>Step 1: Determine the weight of 1 book.<br>Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.<br>5kg ÷ 10 = 0.5kg<br><br>Step 2: Multiply the weight of 1 book by the number of books.<br>To find the weight of 2 books, we multiply the weight of 1 book by 2.<br>0.5kg × 2 = 1kg<br><br>So, the weight of 2 books is 1kg. |
| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. |
| Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." |
❗️ More examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md).
### Training Logs
We also recorded the training logs for the experiment
<p id="Colossal-LLaMA-2-Multi-stage-training" align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossBySteps.jpeg?raw=true" width=600/>
</p>
<p id="Colossal-LLaMA-2-Multi-stage-training" align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossByTokens.jpeg?raw=true" width=600/>
</p>
### Import from Transformers
To load Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
```Python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", trust_remote_code=True)
input = "离离原上草,"
inputs = tokenizer(input, return_tensors='pt')
inputs = inputs.to('cuda:0')
pred = model.generate(**inputs,
max_new_tokens=256,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
```
You can also download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base).
## Usage
### Install
#### 0. Pre-requisite
1. This experiment was performed on 8 computing nodes with 64 A800 GPUs in total for LLaMA-2-7B (**about 1000 USD cost**). The nodes are connected with RDMA and GPUs within one node are fully connected with NVLink. The script was tested with CUDA 11.7, CUDA version requires 11.7 or higher. You can also complete it in about 5 days on a 8*A100/A800 server.
2. PyTorch. The PyTorch version should be less than 2.0.0 and greater than 1.12.1.
#### 1. Install required packages
```
cd Colossal-LLaMA-2
pip install -r requirements.txt
```
#### 2. Install `xentropy`, `layer_norm` and `rotary`
```bash
git clone git@github.com:Dao-AILab/flash-attention.git
# At the root folder
cd csrc/xentropy && pip install .
# At the root folder
cd csrc/layer_norm && pip install .
# At the root folder
cd csrc/rotary && pip install .
```
### How to run
#### 1. Init Tokenizer Preparation
Initialize new tokenizer with additional Chinese tokens. Additional Chinese tokens are stored in `jsonl` format as follows:
```json
{"piece": "你好"}
{"piece": "人工智能"}
```
Command to initialize new tokenizer:
```bash
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
python colossal_llama2/tokenizer/init_tokenizer.py \
--source_tokenizer_dir "<SOURCE_TOKENIZER_DIR>" \
--target_tokenizer_dir "<TARGET_TOKENIZER_DIR>" \
--expand_tokens_file "<NEW_TOKENS_FILE>.jsonl"
```
Here is details about CLI arguments:
* Source tokenizer directory: `--source_tokenizer_dir`. Directory to the source tokenizer. It should at least contain three files: `special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`.
* Target tokenizer directory: `--target_tokenizer_dir`. Directory to the target tokenizer.
* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer.
#### 2. Init Model Preparation
Initialize the new model checkpoint by calculating the mean values from the original model checkpoint.
Command to initialize new model checkpoint:
```bash
python colossal_llama2/model/init_model.py \
--source_model_and_tokenizer_path "<SOURCE_MODEL_AND_TOKENIZER_DIR>" \
--target_tokenizer_path "<TARGET_TOKENIZER_DIR>" \
--target_model_path "<TARGET_MODEL_DIR>"
```
"<TARGET_MODEL_DIR>" can be the same as "<TARGET_TOKENIZER_DIR>".
Here is details about CLI arguments:
* Source model and tokenizer path: `--source_model_and_tokenizer_path`. Source folder contains both model and tokenizer, for example, LLaMA-2 model in Hugging Face format.
* Target tokenizer path: `--target_tokenizer_path`. Path to the new tokenizer folder generated from previous step.
* Target model path: `--target_model_path`. Path to save the new model in Hugging Face format.
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
#### 3. Data Preparation
Raw data should be formatted as `jsonl` format. Each data point should have the following fields:
* `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty.
* `target` (str, compulsory): Loss will be calculated.
* `category` (str, compulsory): Tags for each data point.
Examples:
```JSON
{"source": "", "target": "Lionel Andrés Messi(Spanish pronunciation: [ljoˈnel anˈdɾes ˈmesi] (i); born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.", "category": "sports"}
{"source": "猜谜语:一身卷卷细毛,吃的青青野草,过了数九寒冬,无私献出白毛。(打一动物)", "target": "白羊", "category": "riddle"}
```
You are allowed to customize the category tags or use `unknown` to define the category.
Command to convert jsonl dataset to arrow format:
```
python prepare_pretrain_dataset.py \
--data_input_dirs "<JOSNL_DIR_1>,<JOSNL_DIR_2>,<JOSNL_DIR_3>" \
--tokenizer_dir "<TOKENIZER_DIR>" \
--data_cache_dir "jsonl_to_arrow_cache" \
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
--max_length 4096 \
--num_spliced_dataset_bins 10
```
Here is details about CLI arguments:
* Source data directory: `data_input_dirs`. Each `<JOSNL_DIR>` can have multiple file in `jsonl` format.
* Tokenzier directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format.
* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally.
* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format.
* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly.
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
#### 4. Command Line Arguments for Training
You can use `colossalai run` to launch multi-nodes training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
pretrain.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```bash
hostname1
hostname2
hostname3
hostname4
```
Make sure master node can access all nodes (including itself) by ssh without password.
Here is details about CLI arguments:
* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2``zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training.
* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
* Checkpoint directory: `--save_dir`. The directoty path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`.
* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
* Configuration file: `--config_file`. The path to save the configuration file.
* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
* Learning rate: `--lr`. The default value is 3e-4.
* Max length: `--max_length`. Max context length. The default value is 4096.
* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
* Warmup steps: `-s`, `--warmup_steps`. The default value is calcuated by 0.025 warmup ratio.
* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size.
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
#### 5. Running Command
An [example bash](train.example.sh) is also provided for the experiment. Here is the steps to run the experiment:
* Create your own hostfile: `cp hostfile.example hostfile`.
* Create your own bash: `cp train.example.sh train.sh`.
* Add your real host ip or host name into the `hostfile`.
* Update global variables and parameters in your `train.sh`.
* Run the experiment by `bash train.sh`
Here is the details about global variables for each experiment:
* `PROJECT_NAME`: Project name for each experiment.
* `PARENT_SAVE_DIR`: Parent folder to save model checkpoint.
* `PARENT_TENSORBOARD_DIR`: Parent folder to save tensorboard logs.
* `PARENT_CONFIG_FILE`: Parent folder to save configuration for each experiment.
* `PRETRAINED_MODEL_PATH`: Path to the local pre-trained model checkpoint.
* `dataset`: Paths to all prepared data. Typically, it's a list of subfolders within the output path of prepare data, `--data_arrow_output_dir`, and if there are multiple subfolders, please list them all. e.g.,
```python
declare -a dataset=(
"<DIR_1>/part-00000"
"<DIR_1>/part-00001"
"<DIR_2>/part-00000"
)
```
## Technical Insights
In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes the continuation of pre-training the LLaMA-2 model using both Chinese and English corpora. The overall pipeline can be described as follows:
<p id="Colossal-LLaMA-2-pipeline" align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/Colossal-LLaMA-2-pipeline.jpeg?raw=true" width=800/>
</p>
### Data
Large language models such as LLaMA-2 have undergone training using a heterogeneous blend of high-quality datasets, yielding promising outcomes. Enhancing LLaMA-2's performance for the Chinese corpus, while preserving its proficiency in English, critically hinges on two pivotal factors: the composition of the dataset, which encompasses both English and Chinese content, and the quality of each constituent dataset.
The following figure shows the data processing pipeline conducted for Colossal-LLaMA-2.
<p id="Colossal-LLaMA-2-data-processing-pipeline" align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/data_processing_pipeline.jpeg?raw=true" width=800/>
</p>
❗️**Important**: We will open-source our data-processing toolkit soon, stay tuned!
### Tokenizer
The original LLaMA-2 vacabulary comprises fewer than a thousand Chinese characters, thus proves inadequate for encoding comprehensive Chinese texts effectively. Secondly, the utilization of byte tokens presents a challenge for transformer encoders to capture the semantic nuances of Chinese characters.
To address the above issues, we extend LLaMA-2 vocabulary from 32,000 to 69,104. To adapt the LLaMA-2 model for use with the Colossal-LLaMA-2 tokenizer, we initialize the new word embeddings by calculating the mean values from the original LLaMA-2 embeddings and subsequently append these new rows to the end of the original embedding matrices.
Advantages of extending vocabulary size:
* Improve the compression rate of string sequence encoding.
* Enhance the integrity of information.
* Enable encoded sequences to contain more valuable information, thereby theoretically enhancing the ability for chapter-level encoding.
Advantages of large vocabulary size under low-resource settings:
* The presence of numerous unused tokens can be attributed to the limited training dataset, where an excessive number of tokens might not have been effectively learned.
* Excessive vocabulary expansion leads to an increase in embedding-related parameters, resulting in higher memory usage, which, in turn, affects the efficiency of the training process.
To balance both sides, we finally construct our vocabulary with size 69,104. The following table below presents a comparison of various models at the 7B level.
| Model | Vocabulary Size | Compression Rate | Average Length of Samples (token-level) |
| :-----------: | :---------: | :----: | :----: |
| Colossal-LLaMA-2 | 69104 | 0.659 | 73.682 |
| LLaMA-2-7B | 32000 | 1.205 | 134.689 |
| Atom-7B | 65000 | 0.634 | 70.915 |
| Baichuan-7B | 64000 | 0.678 | 75.857 |
| Baichuan2-7B-base | 125696 | 0.570 | 63.761 |
| Chatglm2-6B | 64789 | 0.645 | 72.178 |
| InternLM-7B | 103168 | 0.566 | 63.349 |
| Qwen-7B | 151643 | 0.578 | 64.703 |
| Tigerbot-7B-base | 60515 | 0.630 | 70.515 |
| Yayi-7B-llama2 | 32005 | 1.214 | 135.689 |
| Chinese-llama-2-7b | 55296 | 0.668 | 74.690 |
| Chinese-Falcon-7B | 90046 | 0.669 | 74.858 |
| LinkSoul-Chinese-Llama-2-7b | 40076 | 0.958 | 107.089 |
| Ziya-LLaMA-13B-v1.1 | 39410 | 0.958 | 107.074 |
### Training Strategy
#### Multi-stage Training
In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages.
Therefore, we have divided the training process into three stages:
* Large-scale pre-training stage (Conducted by LLaMA-2): This initial stage is aimed at establishing the model's foundational capabilities from the ground up. It necessitates the use of a substantial dataset comprising no less than 1 trillion tokens.
* Chinese knowledge injection stage: In this stage, we introduce Chinese knowledge into the model. It requires access to a high-quality dataset rich in comprehensive knowledge relevant to the Chinese language.
* Knowledge replay stage: Knowledge is replayed through a question-answering (QA) mechanism, encompassing both the Chinese and English domains.
Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks.
The following figure illustrates the three stages for training Colossal-LLaMA-2.
<p id="Colossal-LLaMA-2-Multi-stage-training" align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/multi-stage-training.png?raw=true" width=600/>
</p>
#### Bucket-based Training
Our experiments have revealed that the distributions within the training dataset, as well as the arrangement of various topic-related data points, significantly impact the overall performance of the model, particularly in the context of continual pre-training of LLaMA-2.
In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset.
## Citations
```bibtex
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
```
```bibtex
@misc{touvron2023llama,
title={Llama 2: Open Foundation and Fine-Tuned Chat Models},
author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
year={2023},
eprint={2307.09288},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
```bibtex
@article{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
year={2023}
}
}
```

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,219 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
import torch
from datasets import dataset_dict, load_from_disk
from datasets import Dataset as HFDataset
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
import torch.nn.functional as F
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
def load_tokenized_dataset(
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
) -> Optional[DatasetType]:
"""
Load pre-tokenized dataset.
Each instance of dataset is a dictionary with
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
"""
mode_map = {"train": "train", "dev": "validation", "test": "test"}
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
if isinstance(dataset_paths, (str, os.PathLike)):
dataset_paths = [dataset_paths]
datasets = [] # `List[datasets.dataset_dict.Dataset]`
for ds_path in dataset_paths:
ds_path = os.path.abspath(ds_path)
assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
if isinstance(ds_dict, HFDataset):
datasets.append(ds_dict)
else:
if mode_map[mode] in ds_dict:
datasets.append(ds_dict[mode_map[mode]])
if len(datasets) == 0:
return None
if len(datasets) == 1:
return datasets.pop()
return ConcatDataset(datasets=datasets)
@dataclass
class DataCollatorForSupervisedDataset(object):
"""
Collate instances for supervised dataset.
Each instance is a tokenized dictionary with fields
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary.
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
for instance in instances
]
if self.tokenizer.padding_side == "right":
input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=batch_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
labels = torch.nn.utils.rnn.pad_sequence(
sequences=batch_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
# pad to max
to_pad = self.max_length - input_ids.size(1)
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
elif self.tokenizer.padding_side == "left":
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
reversed_labels = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
else:
raise RuntimeError(
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
f"but now `{self.tokenizer.padding_side}`"
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
class StatefulDistributedSampler(DistributedSampler):
"""
Stateful distributed sampler for multi-stage training.
"""
def __init__(
self,
dataset: DatasetType,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
self.start_index = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def setup_distributed_dataloader(
dataset: DatasetType,
batch_size: int = 1,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
process_group: Optional[ProcessGroup] = None,
**kwargs,
) -> DataLoader:
"""
Setup dataloader for distributed training.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset=dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
# Deterministic dataloader
def seed_worker(worker_id: int) -> None:
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
worker_init_fn=seed_worker,
**_kwargs,
)

View File

@ -0,0 +1,183 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Splicing multiple pre-tokenized sequence data points
"""
import random
import warnings
from copy import deepcopy
from datasets import dataset_dict
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize(
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following:
{"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
"""
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
"add <bos> and <eos> manually later"
)
if ignore_index is None:
ignore_index = IGNORE_INDEX
source_text = data_point["source"] # `str`
target_text = data_point["target"] # `str`
is_null_source = len(source_text) == 0
source_text = tokenizer.bos_token + source_text
target_text += tokenizer.eos_token
sequence_text = source_text + target_text
tokenized = tokenizer([source_text, sequence_text])["input_ids"]
sequence_input_ids = tokenized[1]
sequence_labels = deepcopy(sequence_input_ids)
source_length = len(tokenized[0])
if not is_null_source:
sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
# sequence truncation.
if len(sequence_input_ids) > max_length:
sequence_input_ids = sequence_input_ids[:max_length]
sequence_labels = sequence_labels[:max_length]
return dict(
input_ids=sequence_input_ids,
labels=sequence_labels,
seq_length=len(sequence_input_ids),
seq_category=data_point["category"],
)
class ClosedToConstantLengthSplicedDataset(IterableDataset):
"""
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
original independent (pre-tokenized) data points.
"""
def __init__(
self,
dataset: DSType,
tokenizer: PreTrainedTokenizer,
max_length: int = 4096,
num_packed_sequences: int = 8,
fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
input_ids_field: str = "input_ids",
labels_field: str = "labels",
infinite: bool = False,
shuffle: bool = True,
error_strict: bool = False,
) -> None:
self.tokenizer = tokenizer
self.dataset = dataset
self.max_length = max_length
self.infinite = infinite
self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
self.shuffle = shuffle
# Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
# A function that fetch sequence input_ids and labels from the original data point
if fetch_sequence_func is None:
self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
else:
self.fetch_sequence_func = fetch_sequence_func
self.input_ids_field = input_ids_field
self.labels_field = labels_field
self.error_strict = error_strict
self.current_size = 0 # `int`, current packed data size.
def __len__(self) -> int:
return len(self.dataset)
def __iter__(self) -> Iterable[Dict[str, List[int]]]:
iterator = iter(self.dataset)
more_data_points = True
while more_data_points is True:
buffer, buffer_len = [], 0
while True:
# ending condition.
if buffer_len >= self.max_buffer_size:
break
try:
# `Tuple[List[int], List[int]]`
seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
buffer_len += len(buffer[-1][self.input_ids_field])
except StopIteration:
if self.infinite is True:
iterator = iter(self.dataset)
warnings.warn("The dataset reached end and the iterator is reset to the start.")
else:
more_data_points = False
break
examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
for i, data_point in enumerate(buffer):
# TODO(2023-09-18) check errors for each unspliced tokenized data point
seq_input_ids = data_point[self.input_ids_field]
seq_labels = data_point[self.labels_field]
# Handle special case:
# If the length of an original data point (i.e., input_ids length of a data point before splicing)
# exceeds `max_length`, truncate it.
if len(seq_input_ids) > self.max_length:
truncated_seq_input_ids = seq_input_ids[: self.max_length]
truncated_label_ids = seq_labels[: self.max_length]
if set(truncated_label_ids) == {IGNORE_INDEX}:
if self.error_strict is True:
raise ValueError(
f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
f"with all label values as {IGNORE_INDEX}."
)
else:
warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
continue # Skip the current error data point.
spliced_data_point = {
self.input_ids_field: truncated_seq_input_ids,
self.labels_field: truncated_label_ids,
}
examples.append(spliced_data_point)
warnings.warn("Find a data point to be truncated.")
continue
# Pre action judgment.
if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
spliced_data_point = {
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels,
} # `Dict[str, List[int]]`
# Update.
spliced_input_ids, spliced_labels = [], []
spliced_input_ids.extend(seq_input_ids)
spliced_labels.extend(seq_labels)
examples.append(spliced_data_point)
else:
spliced_input_ids.extend(seq_input_ids)
spliced_labels.extend(seq_labels)
# For residual spliced data point at the end of the data set
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
examples.append(
{
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels
}
)
if self.shuffle:
random.shuffle(examples)
for spliced_data_point in examples:
# TODO(2023-09-18): check errors for each spliced tokenized data point.
self.current_size += 1
yield spliced_data_point

View File

@ -0,0 +1,111 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Initialize new model with updated tokenizer by calculating the mean values from original model
"""
import argparse
import numpy as np
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_model_and_tokenizer_path",
type=str,
required=True,
default=None,
help="Source path of model & tokenizer",
)
parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
args = parser.parse_args()
source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
source_tokenizer.add_bos_token = False
source_tokenizer.add_eos_token = False
if source_tokenizer.pad_token is None:
source_tokenizer.pad_token = source_tokenizer.unk_token
source_vocab = source_tokenizer.get_vocab()
target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
target_tokenizer.add_bos_token = False
target_tokenizer.add_eos_token = False
if target_tokenizer.pad_token is None:
target_tokenizer.pad_token = target_tokenizer.unk_token
target_vocab = target_tokenizer.get_vocab()
target_inverted_vocab = {v: k for k, v in target_vocab.items()}
assert len(target_vocab) > len(
source_vocab
), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
gpu_device = torch.device("cuda:0")
cpu_device = torch.device("cpu")
source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
source_model.eval()
source_model = source_model.to(gpu_device)
source_input_embeddings = source_model.get_input_embeddings()
assert isinstance(source_input_embeddings, torch.nn.Embedding)
assert source_input_embeddings.weight.shape[0] == len(source_vocab)
source_input_embeddings.eval()
source_output_embeddings = source_model.get_output_embeddings()
assert isinstance(source_output_embeddings, torch.nn.Linear)
assert source_output_embeddings.bias is None
assert source_output_embeddings.weight.shape[0] == len(source_vocab)
source_output_embeddings.eval()
input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
for i in range(len(source_vocab), len(target_vocab)):
if i % 500 == 0:
logger.info(f"processing {i}/{len(target_vocab)} target tokens")
target_token = target_inverted_vocab[i]
target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
target_to_source_input_embedding = (
source_input_embeddings.weight[target_to_source_token_ids]
.mean(dim=0)
.unsqueeze(dim=0)
.cpu()
.detach()
.numpy()
)
target_to_source_output_embedding = (
source_output_embeddings.weight[target_to_source_token_ids]
.mean(dim=0)
.unsqueeze(dim=0)
.cpu()
.detach()
.numpy()
)
input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
source_model = source_model.to(cpu_device)
assert isinstance(source_model, LlamaForCausalLM)
# expand
source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
source_model = source_model.half()
source_model.save_pretrained(save_directory=args.target_model_path)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
Initialize new tokenizer for continual pre-training
"""
import argparse
import os
import json
from typing import List, Union
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from colossalai.logging import get_dist_logger
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
logger = get_dist_logger()
def expand_vocab_tokenizer(
source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
) -> None:
"""Expand tokenizer for continue pre-training."""
if os.path.exists(target_tokenizer_dir):
raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
logger.info(source_tokenizer)
source_sp_processor = source_tokenizer.sp_model
source_spm = sp_pb2_model.ModelProto()
source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
# Add new tokens to source tokenizer.
source_spm_tokens = set([p.piece for p in source_spm.pieces])
for piece in new_tokens:
assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
if piece in source_spm_tokens:
# Skip existed token.
continue
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
source_spm.pieces.append(new_p)
logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
# Save
os.makedirs(target_tokenizer_dir)
target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
with open(file=target_tokenizer_model_path, mode="wb") as fp:
fp.write(source_spm.SerializeToString())
target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
)
parser.add_argument(
"--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
)
parser.add_argument(
"--expand_tokens_file",
type=str,
required=True,
default=None,
help="Path of the file containing tokens to be extended",
)
args = parser.parse_args()
expand_tokens = []
with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
for line in fp_reader:
item = json.loads(line)
# e.g., {"piece": "你好"}
token = item["piece"]
if token in expand_tokens:
continue
expand_tokens.append(token)
expand_tokens.sort(key=lambda t: len(t), reverse=False)
expand_vocab_tokenizer(
source_tokenizer_dir=args.source_tokenizer_dir,
target_tokenizer_dir=args.target_tokenizer_dir,
new_tokens=expand_tokens,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Helper functions for IO
"""
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
Save as JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
"""
# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)

View File

@ -0,0 +1,216 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from types import MethodType
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaAttention,
LlamaModel,
LlamaForCausalLM,
apply_rotary_pos_emb,
repeat_kv,
)
from colossalai.logging import get_dist_logger
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_kvpacked_func,
)
from flash_attn.ops.rms_norm import rms_norm
logger = get_dist_logger()
def _prepare_decoder_attention_mask(
self: LlamaModel,
attention_mask: torch.BoolTensor,
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]:
"""
Decoder attetion mask
"""
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
tensors=(
torch.full(
size=(input_shape[0], past_key_values_length),
fill_value=True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
) # (bsz, past_key_values_length + q_len)
if attention_mask is not None and torch.all(attention_mask):
return None # Faster
return attention_mask
def attention_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
"""
if output_attentions:
logger.warning(
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
"return `None` instead."
)
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
q_slicing, kv_slicing = (
dim // self.config.pretraining_tp
for dim in (
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
)
) # `Tuple[int, int]`
q_slices, k_slices, v_slices = (
proj.weight.split(slicing, dim=0)
for proj, slicing in (
(self.q_proj, q_slicing),
(self.k_proj, kv_slicing),
(self.v_proj, kv_slicing),
)
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
q, k, v = (
torch.cat(
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
dim=-1,
)
for slices in (q_slices, k_slices, v_slices)
)
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
else:
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
q, k, v = (
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
for states, num_heads in (
(q, self.num_heads),
(k, self.num_key_value_heads),
(v, self.num_key_value_heads),
)
)
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
past_kv_len = 0
if past_key_value is not None:
# if `past_key_value` is not None, `kv_len` > `q_len`.
past_kv_len = past_key_value[0].shape[-2]
kv_len += past_kv_len
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
cos, sin = self.rotary_emb(v, seq_len=kv_len)
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)
past_key_value = (k, v) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
key_padding_mask = attention_mask
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
if past_kv_len > 0:
q = torch.cat(
tensors=(
torch.full(
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
fill_value=0.0,
dtype=q.dtype,
device=q.device,
),
q,
),
dim=1,
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
if key_padding_mask is None:
# (bsz, past_kv_len + q_len, num_heads, head_dim)
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
else:
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
kv, _, cu_kv_lens, max_kv_len = unpad_input(
hidden_states=torch.stack(tensors=(k, v), dim=2),
attention_mask=key_padding_mask,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q=q,
kv=kv,
cu_seqlens_q=cu_q_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_q_len,
max_seqlen_k=max_kv_len,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
output = pad_input(
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
indices=indices,
batch=bsz,
seqlen=past_kv_len + q_len,
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
if past_kv_len > 0:
# Strip off the zero query outputs.
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
output = self.o_proj(output) # (bsz, q_len, hidden_size)
return output, None, past_key_value
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Formard function for RMS Norm
"""
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.forward = MethodType(attention_forward, module)
if isinstance(module, LlamaModel):
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
if isinstance(module, LlamaRMSNorm):
module.forward = MethodType(rms_norm_forward, module)

View File

@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from transformers.models.llama import LlamaForCausalLM
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
"""Freeze all parameters except embeddings."""
for name, params in model.named_parameters():
if "embed_tokens" not in name and "lm_head" not in name:
params.requires_grad = False
else:
params.requires_grad = True
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
for name, params in model.named_parameters():
params.requires_grad = False

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,2 @@
hostname1
hostname2

View File

@ -0,0 +1,153 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare dataset for continual pre-training
"""
import argparse
import json
import math
import os
import time
from multiprocessing import cpu_count
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
supervised_tokenize,
ClosedToConstantLengthSplicedDataset,
)
logger = get_dist_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_input_dirs",
type=str,
required=True,
default=None,
help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
)
parser.add_argument(
"--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(
"--data_jsonl_output_dir",
type=str,
default="jsonl_output",
help="Output directory of spliced dataset with jsonl format",
)
parser.add_argument(
"--data_arrow_output_dir",
type=str,
default="arrow_output",
help="Output directory of spliced dataset with arrow format",
)
parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
args = parser.parse_args()
if args.num_spliced_dataset_bins >= 100000:
raise ValueError("Too many spliced divisions, must be smaller than 100000")
assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
assert not os.path.exists(
args.data_jsonl_output_dir
), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
assert not os.path.exists(
args.data_arrow_output_dir
), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
os.makedirs(args.data_jsonl_output_dir)
os.makedirs(args.data_arrow_output_dir)
# Prepare to all input datasets
input_data_paths = []
input_data_dirs = args.data_input_dirs.split(",")
for ds_dir in input_data_dirs:
ds_dir = os.path.abspath(ds_dir)
assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}"
ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
input_data_paths.extend(ds_paths)
# Prepare to data splitting.
train_splits = []
split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
for i in range(0, 100, split_interval):
start = i
end = i + split_interval
if end > 100:
end = 100
train_splits.append(f"train[{start}%:{end}%]")
# Prepare to the tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
list_dataset = load_dataset(
path="json",
data_files=input_data_paths,
cache_dir=os.path.join(args.data_cache_dir, "raw"),
keep_in_memory=False,
split=train_splits,
num_proc=cpu_count(),
)
for index, dataset in enumerate(list_dataset):
assert isinstance(dataset, dataset_dict.Dataset)
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
dataset = dataset.map(
function=supervised_tokenize,
fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.remove_columns(column_names=["source", "target", "category"])
dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"])
spliced_dataset = ClosedToConstantLengthSplicedDataset(
dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False
)
# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)
output_name = f"part-{output_index}"
output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
st = time.time()
with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
spliced_count = 0
for spliced_data_point in spliced_dataset:
if spliced_count % 500 == 0:
logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
spliced_count += 1
fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
logger.info(
f"Current file {fp_writer.name}; "
f"Data size: {len(spliced_dataset)}; "
f"Spliced data size: {spliced_dataset.current_size}; "
f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; "
f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
)
# Save each arrow spliced dataset
output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
logger.info(f"Start to save {output_arrow_path}")
spliced_dataset = load_dataset(
path="json",
data_files=[output_jsonl_path],
cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
keep_in_memory=False,
num_proc=cpu_count(),
split="train",
)
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
if __name__ == '__main__':
main()

View File

@ -0,0 +1,15 @@
torch<2.0.0, >=1.12.1
packaging==23.1
colossalai==0.3.2
autoflake==2.2.1
black==23.9.1
transformers
tensorboard==2.14.0
six==1.16.0
datasets
ninja==1.11.1
flash-attn>=2.0.0,<=2.0.5
tqdm
sentencepiece==0.1.99
protobuf<=3.20.0

View File

@ -0,0 +1,44 @@
#!/bin/bash
# NCCL IB environment variables
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export OMP_NUM_THREADS=8
PROJECT_NAME=""
PARENT_SAVE_DIR=""
PARENT_TENSORBOARD_DIR=""
PARENT_CONFIG_FILE=""
PRETRAINED_MODEL_PATH=""
declare -a dataset=(
"PATH TO THE DATASET"
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
--pretrained $PRETRAINED_MODEL_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 400 \
--save_dir $SAVE_DIR \
--tensorboard_dir $TENSORBOARD_DIR \
--config_file $CONFIG_FILE \
--num_epochs 1 \
--micro_batch_size 8 \
--lr 1e-4 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--weight_decay 0.01 \
--warmup_steps 100 \
--use_grad_checkpoint \
--use_flash_attn \

View File

@ -0,0 +1,383 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
"""
import json
import argparse
import os
import resource
from contextlib import nullcontext
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
LowLevelZeroPlugin,
HybridParallelPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossal_llama2.dataset.loader import (
load_tokenized_dataset,
setup_distributed_dataloader,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
)
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
def get_model_numel(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f"{numel / B:.2f} B"
elif numel >= M:
return f"{numel / M:.2f} M"
elif numel >= K:
return f"{numel / K:.2f} K"
else:
return f"{numel}"
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def main() -> None:
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained",
type=str,
default=None,
help="Address of the pre-trained modeling",
)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
parser.add_argument(
"--mixed_precision",
type=str,
default="fp16",
choices=["fp16", "bf16"],
help="Mixed precision",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument(
"--use_grad_checkpoint",
action="store_true",
default=False,
help="Use gradient checkpointing",
)
parser.add_argument(
"--use_flash_attn",
action="store_true",
default=False,
help="Use flash-attention",
)
parser.add_argument(
"--freeze_non_embeds_params",
action="store_true",
default=False,
help="Freeze non embeddings parameters",
)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--zero", type=int, default=1)
args = parser.parse_args()
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# ==============================
# Initialize Tensorboard
# ==============================
if coordinator.is_master():
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=args.zero,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
coordinator.print_on_master(f"Load dataset: {args.dataset}")
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
dataloader = setup_distributed_dataloader(
dataset=dataset,
batch_size=args.micro_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
)
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx = (
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
)
with init_ctx:
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
if args.use_flash_attn:
replace_with_flash_attention(model=model)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
optimizer = HybridAdam(
model_params=filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
eta_min=0.1 * args.lr,
)
# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
if args.load_checkpoint is None:
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
booster.load_model(model, args.pretrained, strict=False)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
start_step = 0
sampler_start_idx = 0
if args.load_checkpoint is not None:
if "modeling" in args.load_checkpoint:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
booster.load_model(model, args.load_checkpoint)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.load_checkpoint,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
coordinator.print_on_master(
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
num_steps_per_epoch = len(dataloader)
# If resume training, set the sampler start index to the correct value
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
with tqdm(
iterable=enumerate(dataloader, start=start_step),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
for step, batch in pbar:
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
loss = batch_output.loss
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
all_reduce_mean(tensor=loss)
pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
if coordinator.is_master():
global_step = epoch * num_steps_per_epoch + step
writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
# Save modeling.
if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.micro_batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
# Delete CUDA cache.
# del batch, batch_labels, batch_output, loss
torch.cuda.empty_cache()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
start_step = 0
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
0.0.1

View File

@ -4,8 +4,9 @@ This directory contains the applications that are powered by Colossal-AI.
The list of applications include:
- [X] [Chatbot](./Chat/README.md)
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters
- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
- [X] [Chatbot](./Chat/README.md): Replication of ChatGPT with RLHF.
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
> Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder.