mirror of https://github.com/hpcaitech/ColossalAI
initial commit: add colossal llama 2 (#4784)
parent
4146f1c0ce
commit
74aa7d964a
|
@ -0,0 +1,377 @@
|
|||
<div align="center">
|
||||
<h1>
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/colossalllam2.jpg?raw=true" width=800/>
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
## Table of Contents
|
||||
- [News](#news)
|
||||
- [Colossal-LLaMA-2-7B](#colossal-llama-2-7b)
|
||||
- [Performance Evaluation](#performance-evaluation)
|
||||
- [Examples](#examples)
|
||||
- [Training Logs](#training-logs)
|
||||
- [Import from Transformers](#import-from-transformers)
|
||||
- [Usage](#usage)
|
||||
- [Install](#install)
|
||||
- [How to run](#how-to-run)
|
||||
- [Technical Insight](#technical-insights)
|
||||
- [Data](#data)
|
||||
- [Tokenizer](#tokenizer)
|
||||
- [Training Strategy](#training-strategy)
|
||||
- [Citations](#citations)
|
||||
|
||||
## News
|
||||
* [2023/09] 🔥 We released **Colossal-LLaMA-2-7B-base** based on LLaMA-2. [Download weights](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base).
|
||||
|
||||
## Colossal-LLaMA-2-7B
|
||||
The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, derived from LLaMA-2, has undergone continual pre-training on approximately 8.5 billion tokens over 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and the [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models for specific domain knowledge or tasks.
|
||||
|
||||
Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others.
|
||||
|
||||
### Performance Evaluation
|
||||
We conducted comprehensive evaluations on 4 datasets and compared our Colossal-LLaMA-2-7B-base model with various other models.

* We use 5-shot for MMLU and calculate scores based on the logits of the first predicted token.
* We use 5-shot for CMMLU and calculate scores based on the logits of the first predicted token.
* We use 5-shot for AGIEval and only calculate scores for 4-choice questions, using a combined metric of exact match and the logits of the first predicted token. If either the exact match or the first-token logits is correct, the model gets the score.
* We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of the first predicted token.
* The generation config for all datasets is greedy search.
* We also provide C-Eval scores from its latest leaderboard or the official repository of the model.
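For reference, scoring a 4-choice question from the logits of the first predicted token can be sketched as follows. This is only an illustrative snippet, not the actual evaluation code (which will be released with ColossalEval); the prompt construction and option tokens are simplifying assumptions.

```python
# Illustrative sketch: score a 4-choice question from the logits of the first predicted token.
import torch

def score_choice_question(model, tokenizer, prompt: str, answer: str) -> bool:
    choices = ["A", "B", "C", "D"]
    # Token id of each option letter (assumes each letter maps to a single token).
    choice_ids = [tokenizer(c, add_special_tokens=False)["input_ids"][-1] for c in choices]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        next_token_logits = model(**inputs).logits[0, -1]  # logits of the first predicted token
    prediction = choices[int(torch.argmax(next_token_logits[choice_ids]))]
    return prediction == answer
```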
|
||||
|
||||
| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
|
||||
| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: |
|
||||
| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
|
||||
| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
|
||||
| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
|
||||
| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
|
||||
| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
|
||||
| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
|
||||
| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
|
||||
| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
|
||||
| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
|
||||
| | | | | | | | | |
|
||||
| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
|
||||
| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
|
||||
| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
|
||||
| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
|
||||
| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
|
||||
| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
|
||||
| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
|
||||
| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
|
||||
| | | | | | | | | |
|
||||
| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
|
||||
|
||||
> The scores in parentheses correspond to the scores in the official repository of the model.
>
> We use zero-shot for ChatGLM models.
>
> Qwen-7B is currently inaccessible on Hugging Face; we use the latest version available before it was made inaccessible. For the MMLU dataset only, the prompt is "xxx Answer:" (with the space after ":" removed), and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models; for example, the logits over " A" can be `-inf`, in which case the softmax is exactly `0`.
>
> For other models and datasets, we calculate logits over "A", "B", "C" and "D".
|
||||
|
||||
❗️ For more details on the evaluation methods and reproduction of the results, please refer to [TODO: ColossalEval]().
|
||||
|
||||
### Examples
|
||||
| Question Type | Question | <center>Colossal-LLaMA-2-7b-base</center> |
|
||||
| :------: | :-----------------------: | :----------- |
|
||||
| Knowledge Quiz<br>(Tourism) | 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 |
|
||||
| Knowledge Quiz<br>(Computer Science) | 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 |
|
||||
| Knowledge Quiz<br>(Literature) | 古代文学与现代文学有哪些区别和共通之处? | 古代文学与现代文学有三点区别,现代文学更加通俗易懂,具有可读性,而古代文学相对来说更加注重形式,缺乏可读性。现代文学和古代文学,同属于中华文学大花园中的一枝花蕾。但是它们由于历史的原因,有着不同的特点,具体表现在,现代文学有更丰富的描写方法,更强烈的主观色彩和更突出的抒情色彩,而古代文学的主要特征是在创作风格上的多样化。 |
|
||||
| Emotional Analysis | 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 |
|
||||
| Translation | 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. |
|
||||
| Logical Reasoning | All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:<br><br>1. Determine the weight of 1 book.<br>2. Multiply the weight of 1 book by the number of books.<br><br>Step 1: Determine the weight of 1 book.<br>Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.<br>5kg ÷ 10 = 0.5kg<br><br>Step 2: Multiply the weight of 1 book by the number of books.<br>To find the weight of 2 books, we multiply the weight of 1 book by 2.<br>0.5kg × 2 = 1kg<br><br>So, the weight of 2 books is 1kg. |
|
||||
| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. |
|
||||
| Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." |
|
||||
|
||||
❗️ For more examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md).
|
||||
|
||||
### Training Logs
|
||||
We also recorded the training logs for the experiment:
|
||||
|
||||
<p id="Colossal-LLaMA-2-Multi-stage-training" align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossBySteps.jpeg?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
<p id="Colossal-LLaMA-2-Multi-stage-training" align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossByTokens.jpeg?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### Import from Transformers
|
||||
To load the Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
|
||||
```Python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", device_map="auto", trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", trust_remote_code=True)
|
||||
input = "离离原上草,"
|
||||
inputs = tokenizer(input, return_tensors='pt')
|
||||
inputs = inputs.to('cuda:0')
|
||||
pred = model.generate(**inputs,
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
num_return_sequences=1)
|
||||
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
|
||||
```
|
||||
|
||||
You can also download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base).
|
||||
|
||||
## Usage
|
||||
### Install
|
||||
|
||||
#### 0. Pre-requisite
|
||||
1. This experiment was performed on 8 computing nodes with 64 A800 GPUs in total for LLaMA-2-7B (**about 1000 USD cost**). The nodes are connected with RDMA, and GPUs within one node are fully connected with NVLink. The script was tested with CUDA 11.7; CUDA 11.7 or higher is required. You can also complete the run in about 5 days on an 8×A100/A800 server.

2. PyTorch. The PyTorch version should be greater than 1.12.1 and less than 2.0.0.
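A quick, optional sanity check for the PyTorch requirement (illustrative; it assumes the widely available `packaging` package is installed):

```python
# Optional: verify the installed PyTorch version is within the supported range.
import torch
from packaging import version

v = version.parse(torch.__version__.split("+")[0])
assert version.parse("1.12.1") < v < version.parse("2.0.0"), f"Unsupported PyTorch version: {torch.__version__}"
```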
|
||||
|
||||
|
||||
#### 1. Install required packages
|
||||
```bash
|
||||
cd Colossal-LLaMA-2
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
#### 2. Install `xentropy`, `layer_norm` and `rotary`
|
||||
```bash
git clone git@github.com:Dao-AILab/flash-attention.git
cd flash-attention
# At the root folder of flash-attention
cd csrc/xentropy && pip install . && cd ../..
# At the root folder of flash-attention
cd csrc/layer_norm && pip install . && cd ../..
# At the root folder of flash-attention
cd csrc/rotary && pip install . && cd ../..
```
|
||||
|
||||
### How to run
|
||||
|
||||
#### 1. Init Tokenizer Preparation
|
||||
Initialize the new tokenizer with additional Chinese tokens. The additional Chinese tokens are stored in `jsonl` format as follows:
|
||||
```json
|
||||
{"piece": "你好"}
|
||||
{"piece": "人工智能"}
|
||||
```
|
||||
Command to initialize new tokenizer:
|
||||
```bash
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
|
||||
python colossal_llama2/tokenizer/init_tokenizer.py \
|
||||
--source_tokenizer_dir "<SOURCE_TOKENIZER_DIR>" \
|
||||
--target_tokenizer_dir "<TARGET_TOKENIZER_DIR>" \
|
||||
--expand_tokens_file "<NEW_TOKENS_FILE>.jsonl"
|
||||
```
|
||||
Here are the details of the CLI arguments:

* Source tokenizer directory: `--source_tokenizer_dir`. Directory of the source tokenizer. It should contain at least three files: `special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`.
* Target tokenizer directory: `--target_tokenizer_dir`. Directory of the target tokenizer.
* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer.
|
||||
|
||||
#### 2. Init Model Preparation
|
||||
Initialize the new model checkpoint by calculating mean values from the original model checkpoint for the newly added token embeddings.
|
||||
Command to initialize new model checkpoint:
|
||||
```bash
|
||||
python colossal_llama2/model/init_model.py \
|
||||
--source_model_and_tokenizer_path "<SOURCE_MODEL_AND_TOKENIZER_DIR>" \
|
||||
--target_tokenizer_path "<TARGET_TOKENIZER_DIR>" \
|
||||
--target_model_path "<TARGET_MODEL_DIR>"
|
||||
```
|
||||
"<TARGET_MODEL_DIR>" can be the same as "<TARGET_TOKENIZER_DIR>".
|
||||
|
||||
Here are the details of the CLI arguments:

* Source model and tokenizer path: `--source_model_and_tokenizer_path`. The source folder contains both the model and the tokenizer, for example, a LLaMA-2 model in Hugging Face format.
* Target tokenizer path: `--target_tokenizer_path`. Path to the new tokenizer folder generated in the previous step.
* Target model path: `--target_model_path`. Path to save the new model in Hugging Face format.
|
||||
|
||||
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
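For example, a minimal sketch of this copy step (the placeholder paths follow the commands above):

```python
# Copy the new tokenizer files into the new model folder.
import shutil
from pathlib import Path

tokenizer_dir = Path("<TARGET_TOKENIZER_DIR>")
model_dir = Path("<TARGET_MODEL_DIR>")
for name in ("special_tokens_map.json", "tokenizer.model", "tokenizer_config.json"):
    shutil.copy(tokenizer_dir / name, model_dir / name)
```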
|
||||
|
||||
#### 3. Data Preparation
|
||||
Raw data should be formatted in `jsonl` format. Each data point should have the following fields:

* `source` (str, compulsory): This part is ignored when calculating loss. It can be empty.
* `target` (str, compulsory): Loss will be calculated on this part.
* `category` (str, compulsory): Tags for each data point.
|
||||
|
||||
Examples:
|
||||
```JSON
|
||||
{"source": "", "target": "Lionel Andrés Messi(Spanish pronunciation: [ljoˈnel anˈdɾes ˈmesi] (i); born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.", "category": "sports"}
|
||||
{"source": "猜谜语:一身卷卷细毛,吃的青青野草,过了数九寒冬,无私献出白毛。(打一动物)", "target": "白羊", "category": "riddle"}
|
||||
```
|
||||
You can customize the category tags or use `unknown` as the category.
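As a simplified illustration of how `source` and `target` are used during tokenization, the `source` part is masked out of the loss with `-100` (the actual `supervised_tokenize` function included in this commit tokenizes the source and the full sequence together so that token boundaries stay consistent):

```python
# Simplified sketch: only `target` tokens contribute to the loss.
IGNORE_INDEX = -100  # label positions with this value are ignored by the loss

def build_example(source: str, target: str, tokenizer) -> dict:
    source_ids = tokenizer(tokenizer.bos_token + source, add_special_tokens=False)["input_ids"]
    target_ids = tokenizer(target + tokenizer.eos_token, add_special_tokens=False)["input_ids"]
    input_ids = source_ids + target_ids
    # If `source` is empty (plain pre-training text), the loss covers the whole sequence.
    labels = ([IGNORE_INDEX] * len(source_ids) + target_ids) if source else list(input_ids)
    return {"input_ids": input_ids, "labels": labels}
```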
|
||||
|
||||
Command to convert jsonl dataset to arrow format:
|
||||
```bash
|
||||
python prepare_pretrain_dataset.py \
|
||||
--data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
|
||||
--tokenizer_dir "<TOKENIZER_DIR>" \
|
||||
--data_cache_dir "jsonl_to_arrow_cache" \
|
||||
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
|
||||
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
|
||||
--max_length 4096 \
|
||||
--num_spliced_dataset_bins 10
|
||||
```
|
||||
Here are the details of the CLI arguments:

* Source data directory: `data_input_dirs`. Each `<JSONL_DIR>` can contain multiple files in `jsonl` format.
* Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format.
* Data cache directory: `data_cache_dir`. Directory to store the Hugging Face data cache. By default, a `cache` folder is created locally.
* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store the converted dataset in jsonl format.
* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store the converted dataset in arrow format, which can be used for training directly.
* Max length: `max_length`. Max length of spliced samples. The default value is 4096.
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
|
||||
|
||||
#### 4. Command Line Arguments for Training
|
||||
You can use `colossalai run` to launch multi-node training:
|
||||
```bash
|
||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
||||
pretrain.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
Here is a sample hostfile:
|
||||
```bash
|
||||
hostname1
|
||||
hostname2
|
||||
hostname3
|
||||
hostname4
|
||||
```
|
||||
Make sure the master node can access all nodes (including itself) via passwordless SSH.
|
||||
|
||||
Here are the details of the CLI arguments:

* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. A saved checkpoint contains the states for `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded, without any other states, to support multi-stage training.
* Save interval: `--save_interval`. The interval (in steps) for saving checkpoints. The default value is 1000.
* Checkpoint directory: `--save_dir`. The directory path to save checkpoints and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`.
* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
* Configuration file: `--config_file`. The path to save the configuration file.
* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
* Learning rate: `--lr`. The default value is 3e-4.
* Max length: `--max_length`. Max context length. The default value is 4096.
* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated from a warmup ratio of 0.025.
* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. It is recommended to enable this option when training with a large batch size.
* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This helps accelerate training while saving memory. We recommend always using flash attention.
* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending the vocabulary size.
* Tensor parallelism size: `--tp`. TP size for 3D parallelism. The default value is 1.
* Zero stage: `--zero`. ZeRO stage for 3D parallelism. The default value is 1.
|
||||
|
||||
#### 5. Running Command
|
||||
An [example bash script](train.example.sh) is also provided for the experiment. Here are the steps to run the experiment:

* Create your own hostfile: `cp hostfile.example hostfile`.
* Create your own bash script: `cp train.example.sh train.sh`.
* Add your real host IPs or host names to the `hostfile`.
* Update the global variables and parameters in your `train.sh`.
* Run the experiment with `bash train.sh`.

Here are the details of the global variables for each experiment:
|
||||
* `PROJECT_NAME`: Project name for each experiment.
|
||||
* `PARENT_SAVE_DIR`: Parent folder to save model checkpoint.
|
||||
* `PARENT_TENSORBOARD_DIR`: Parent folder to save tensorboard logs.
|
||||
* `PARENT_CONFIG_FILE`: Parent folder to save configuration for each experiment.
|
||||
* `PRETRAINED_MODEL_PATH`: Path to the local pre-trained model checkpoint.
|
||||
* `dataset`: Paths to all prepared data. Typically, it is a list of subfolders within the output path of the data preparation step (`--data_arrow_output_dir`); if there are multiple subfolders, please list them all, e.g.,
|
||||
```bash
|
||||
declare -a dataset=(
|
||||
"<DIR_1>/part-00000"
|
||||
"<DIR_1>/part-00001"
|
||||
"<DIR_2>/part-00000"
|
||||
)
|
||||
```
|
||||
## Technical Insights
|
||||
To enhance LLaMA-2's capabilities for understanding and generating Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continuing the pre-training of the LLaMA-2 model on both Chinese and English corpora. The overall pipeline can be described as follows:
|
||||
|
||||
<p id="Colossal-LLaMA-2-pipeline" align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/Colossal-LLaMA-2-pipeline.jpeg?raw=true" width=800/>
|
||||
</p>
|
||||
|
||||
### Data
|
||||
Large language models such as LLaMA-2 have undergone training using a heterogeneous blend of high-quality datasets, yielding promising outcomes. Enhancing LLaMA-2's performance for the Chinese corpus, while preserving its proficiency in English, critically hinges on two pivotal factors: the composition of the dataset, which encompasses both English and Chinese content, and the quality of each constituent dataset.
|
||||
|
||||
The following figure shows the data processing pipeline conducted for Colossal-LLaMA-2.
|
||||
<p id="Colossal-LLaMA-2-data-processing-pipeline" align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/data_processing_pipeline.jpeg?raw=true" width=800/>
|
||||
</p>
|
||||
|
||||
❗️**Important**: We will open-source our data-processing toolkit soon, stay tuned!
|
||||
|
||||
### Tokenizer
|
||||
The original LLaMA-2 vocabulary comprises fewer than a thousand Chinese characters and thus proves inadequate for encoding comprehensive Chinese text effectively. Additionally, the use of byte tokens makes it difficult for transformer encoders to capture the semantic nuances of Chinese characters.
|
||||
|
||||
To address the above issues, we extend the LLaMA-2 vocabulary from 32,000 to 69,104. To adapt the LLaMA-2 model to the Colossal-LLaMA-2 tokenizer, we initialize the new word embeddings by calculating the mean values from the original LLaMA-2 embeddings and then append these new rows to the end of the original embedding matrices.
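A condensed sketch of this mean-value initialization, where `model`, `source_tokenizer` and `target_tokenizer` are assumed to be already loaded (the full implementation is `colossal_llama2/model/init_model.py` in this commit, which applies the same averaging to the output embeddings as well):

```python
# Condensed sketch: initialize embeddings of newly added tokens with the mean of the
# source-model embeddings of their constituent pieces.
import torch

embed = model.get_input_embeddings().weight.data  # (len(source_tokenizer), hidden_size)
new_rows = []
for token_id in range(len(source_tokenizer), len(target_tokenizer)):
    token = target_tokenizer.convert_ids_to_tokens(token_id)
    piece_ids = source_tokenizer(token, add_special_tokens=False)["input_ids"]
    new_rows.append(embed[piece_ids].mean(dim=0, keepdim=True))
model.resize_token_embeddings(len(target_tokenizer))
model.get_input_embeddings().weight.data[len(source_tokenizer):] = torch.cat(new_rows, dim=0)
```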
|
||||
|
||||
Advantages of extending vocabulary size:
|
||||
* Improve the compression rate of string sequence encoding.
|
||||
* Enhance the integrity of information.
|
||||
* Enable encoded sequences to contain more valuable information, thereby theoretically enhancing the ability for chapter-level encoding.
|
||||
|
||||
Disadvantages of large vocabulary size under low-resource settings:
|
||||
* The presence of numerous unused tokens can be attributed to the limited training dataset, where an excessive number of tokens might not have been effectively learned.
|
||||
* Excessive vocabulary expansion leads to an increase in embedding-related parameters, resulting in higher memory usage, which, in turn, affects the efficiency of the training process.
|
||||
|
||||
To balance both sides, we finally construct our vocabulary with a size of 69,104. The following table presents a comparison of various models at the 7B level.
|
||||
|
||||
| Model | Vocabulary Size | Compression Rate | Average Length of Samples (token-level) |
|
||||
| :-----------: | :---------: | :----: | :----: |
|
||||
| Colossal-LLaMA-2 | 69104 | 0.659 | 73.682 |
|
||||
| LLaMA-2-7B | 32000 | 1.205 | 134.689 |
|
||||
| Atom-7B | 65000 | 0.634 | 70.915 |
|
||||
| Baichuan-7B | 64000 | 0.678 | 75.857 |
|
||||
| Baichuan2-7B-base | 125696 | 0.570 | 63.761 |
|
||||
| Chatglm2-6B | 64789 | 0.645 | 72.178 |
|
||||
| InternLM-7B | 103168 | 0.566 | 63.349 |
|
||||
| Qwen-7B | 151643 | 0.578 | 64.703 |
|
||||
| Tigerbot-7B-base | 60515 | 0.630 | 70.515 |
|
||||
| Yayi-7B-llama2 | 32005 | 1.214 | 135.689 |
|
||||
| Chinese-llama-2-7b | 55296 | 0.668 | 74.690 |
|
||||
| Chinese-Falcon-7B | 90046 | 0.669 | 74.858 |
|
||||
| LinkSoul-Chinese-Llama-2-7b | 40076 | 0.958 | 107.089 |
|
||||
| Ziya-LLaMA-13B-v1.1 | 39410 | 0.958 | 107.074 |
|
||||
|
||||
|
||||
### Training Strategy
|
||||
#### Multi-stage Training
|
||||
In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages.
|
||||
|
||||
Therefore, we have divided the training process into three stages:
|
||||
* Large-scale pre-training stage (already conducted for LLaMA-2): This initial stage aims to establish the model's foundational capabilities from the ground up. It requires a substantial dataset comprising no less than 1 trillion tokens.
|
||||
* Chinese knowledge injection stage: In this stage, we introduce Chinese knowledge into the model. It requires access to a high-quality dataset rich in comprehensive knowledge relevant to the Chinese language.
|
||||
* Knowledge replay stage: Knowledge is replayed through a question-answering (QA) mechanism, encompassing both the Chinese and English domains.
|
||||
|
||||
Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks.
|
||||
|
||||
The following figure illustrates the three stages for training Colossal-LLaMA-2.
|
||||
|
||||
<p id="Colossal-LLaMA-2-Multi-stage-training" align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/multi-stage-training.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
#### Bucket-based Training
|
||||
Our experiments have revealed that the distributions within the training dataset, as well as the arrangement of various topic-related data points, significantly impact the overall performance of the model, particularly in the context of continual pre-training of LLaMA-2.
|
||||
|
||||
In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset.
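A toy illustration of this idea (names and shapes are illustrative, not the actual implementation):

```python
# Toy sketch: split each sub-dataset into `num_bins` bins and build bucket i
# from bin i of every sub-dataset, so every bucket has a similar topic mix.
from typing import List, Sequence

def make_buckets(sub_datasets: Sequence[Sequence], num_bins: int) -> List[list]:
    buckets = [[] for _ in range(num_bins)]
    for ds in sub_datasets:
        bin_size = (len(ds) + num_bins - 1) // num_bins  # ceiling division
        for i in range(num_bins):
            buckets[i].extend(ds[i * bin_size:(i + 1) * bin_size])
    return buckets
```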
|
||||
|
||||
## Citations
|
||||
```bibtex
|
||||
@article{bian2021colossal,
|
||||
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
|
||||
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
|
||||
journal={arXiv preprint arXiv:2110.14883},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
```bibtex
|
||||
@misc{touvron2023llama,
|
||||
title={Llama 2: Open Foundation and Fine-Tuned Chat Models},
|
||||
author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
|
||||
year={2023},
|
||||
eprint={2307.09288},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
```bibtex
|
||||
@article{dao2023flashattention2,
|
||||
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
|
||||
author={Dao, Tri},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
|
@ -0,0 +1,219 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
|
||||
|
||||
import torch
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from datasets import Dataset as HFDataset
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import torch.nn.functional as F
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def load_tokenized_dataset(
|
||||
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
|
||||
) -> Optional[DatasetType]:
|
||||
"""
|
||||
Load pre-tokenized dataset.
|
||||
Each instance of dataset is a dictionary with
|
||||
`{'input_ids': List[int], 'labels': List[int], 'sequence': str}` format.
|
||||
"""
|
||||
mode_map = {"train": "train", "dev": "validation", "test": "test"}
|
||||
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
|
||||
|
||||
if isinstance(dataset_paths, (str, os.PathLike)):
|
||||
dataset_paths = [dataset_paths]
|
||||
|
||||
datasets = [] # `List[datasets.dataset_dict.Dataset]`
|
||||
for ds_path in dataset_paths:
|
||||
ds_path = os.path.abspath(ds_path)
|
||||
assert os.path.exists(ds_path), f"File path {ds_path} does not exist."
|
||||
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
|
||||
if isinstance(ds_dict, HFDataset):
|
||||
datasets.append(ds_dict)
|
||||
else:
|
||||
if mode_map[mode] in ds_dict:
|
||||
datasets.append(ds_dict[mode_map[mode]])
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
if len(datasets) == 1:
|
||||
return datasets.pop()
|
||||
return ConcatDataset(datasets=datasets)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSupervisedDataset(object):
|
||||
"""
|
||||
Collate instances for supervised dataset.
|
||||
Each instance is a tokenized dictionary with fields
|
||||
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
max_length: int = 4096
|
||||
ignore_index: int = -100
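# Note: -100 matches the default ignore_index of torch.nn.CrossEntropyLoss, so these label positions are excluded from the loss.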
|
||||
|
||||
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
instances (`Sequence[Dict[str, List[int]]]`):
|
||||
Mini-batch samples, each sample is stored in an individual dictionary.
|
||||
|
||||
Returns:
|
||||
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
|
||||
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
|
||||
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
|
||||
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
|
||||
"""
|
||||
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
|
||||
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
|
||||
f"but now `{self.tokenizer.pad_token_id}`"
|
||||
)
|
||||
|
||||
# `List[torch.Tensor]`
|
||||
batch_input_ids = [
|
||||
torch.LongTensor(instance["input_ids"][: self.max_length])
|
||||
if len(instance["input_ids"]) > self.max_length
|
||||
else torch.LongTensor(instance["input_ids"])
|
||||
for instance in instances
|
||||
]
|
||||
batch_labels = [
|
||||
torch.LongTensor(instance["labels"][: self.max_length])
|
||||
if len(instance["labels"]) > self.max_length
|
||||
else torch.LongTensor(instance["labels"])
|
||||
for instance in instances
|
||||
]
|
||||
|
||||
if self.tokenizer.padding_side == "right":
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=batch_input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
) # (bsz, max_len)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=batch_labels,
|
||||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
elif self.tokenizer.padding_side == "left":
|
||||
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
||||
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=reversed_input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
) # (bsz, max_len)
|
||||
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
|
||||
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
|
||||
reversed_labels = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=reversed_labels,
|
||||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
|
||||
f"but now `{self.tokenizer.padding_side}`"
|
||||
)
|
||||
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
|
||||
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Stateful distributed sampler for multi-stage training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DatasetType,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dataset=dataset,
|
||||
num_replicas=num_replicas,
|
||||
rank=rank,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
self.start_index = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index :]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def set_start_index(self, start_index: int) -> None:
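# Set the number of samples to skip at the start of iteration, e.g. when resuming training from an intermediate checkpoint.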
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def setup_distributed_dataloader(
|
||||
dataset: DatasetType,
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
**kwargs,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Setup dataloader for distributed training.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset=dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id: int) -> None:
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=drop_last,
|
||||
worker_init_fn=seed_worker,
|
||||
**_kwargs,
|
||||
)
|
|
@ -0,0 +1,183 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Splicing multiple pre-tokenized sequence data points
|
||||
"""
|
||||
|
||||
import random
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from datasets import dataset_dict
|
||||
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
|
||||
|
||||
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
|
||||
|
||||
def supervised_tokenize(
|
||||
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
A tokenization function to tokenize an original pretraining data point as follows:
|
||||
{"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
|
||||
"""
|
||||
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
|
||||
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
|
||||
"add <bos> and <eos> manually later"
|
||||
)
|
||||
if ignore_index is None:
|
||||
ignore_index = IGNORE_INDEX
|
||||
|
||||
source_text = data_point["source"] # `str`
|
||||
target_text = data_point["target"] # `str`
|
||||
is_null_source = len(source_text) == 0
|
||||
|
||||
source_text = tokenizer.bos_token + source_text
|
||||
target_text += tokenizer.eos_token
|
||||
sequence_text = source_text + target_text
|
||||
|
||||
tokenized = tokenizer([source_text, sequence_text])["input_ids"]
|
||||
sequence_input_ids = tokenized[1]
|
||||
sequence_labels = deepcopy(sequence_input_ids)
|
||||
|
||||
source_length = len(tokenized[0])
|
||||
if not is_null_source:
|
||||
sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
|
||||
|
||||
# sequence truncation.
|
||||
if len(sequence_input_ids) > max_length:
|
||||
sequence_input_ids = sequence_input_ids[:max_length]
|
||||
sequence_labels = sequence_labels[:max_length]
|
||||
|
||||
return dict(
|
||||
input_ids=sequence_input_ids,
|
||||
labels=sequence_labels,
|
||||
seq_length=len(sequence_input_ids),
|
||||
seq_category=data_point["category"],
|
||||
)
|
||||
|
||||
|
||||
class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
||||
"""
|
||||
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
|
||||
original independent (pre-tokenized) data points.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DSType,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int = 4096,
|
||||
num_packed_sequences: int = 8,
|
||||
fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
|
||||
input_ids_field: str = "input_ids",
|
||||
labels_field: str = "labels",
|
||||
infinite: bool = False,
|
||||
shuffle: bool = True,
|
||||
error_strict: bool = False,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = dataset
|
||||
self.max_length = max_length
|
||||
self.infinite = infinite
|
||||
self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
|
||||
self.shuffle = shuffle
|
||||
|
||||
# Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
|
||||
# A function that fetch sequence input_ids and labels from the original data point
|
||||
if fetch_sequence_func is None:
|
||||
self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
|
||||
else:
|
||||
self.fetch_sequence_func = fetch_sequence_func
|
||||
self.input_ids_field = input_ids_field
|
||||
self.labels_field = labels_field
|
||||
|
||||
self.error_strict = error_strict
|
||||
self.current_size = 0 # `int`, current packed data size.
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __iter__(self) -> Iterable[Dict[str, List[int]]]:
|
||||
iterator = iter(self.dataset)
|
||||
more_data_points = True
|
||||
while more_data_points is True:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
# ending condition.
|
||||
if buffer_len >= self.max_buffer_size:
|
||||
break
|
||||
try:
|
||||
# `Tuple[List[int], List[int]]`
|
||||
seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
|
||||
buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
|
||||
buffer_len += len(buffer[-1][self.input_ids_field])
|
||||
except StopIteration:
|
||||
if self.infinite is True:
|
||||
iterator = iter(self.dataset)
|
||||
warnings.warn("The dataset reached end and the iterator is reset to the start.")
|
||||
else:
|
||||
more_data_points = False
|
||||
break
|
||||
examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
|
||||
spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
|
||||
for i, data_point in enumerate(buffer):
|
||||
# TODO(2023-09-18) check errors for each unspliced tokenized data point
|
||||
seq_input_ids = data_point[self.input_ids_field]
|
||||
seq_labels = data_point[self.labels_field]
|
||||
# Handle special case:
|
||||
# If the length of an original data point (i.e., input_ids length of a data point before splicing)
|
||||
# exceeds `max_length`, truncate it.
|
||||
if len(seq_input_ids) > self.max_length:
|
||||
truncated_seq_input_ids = seq_input_ids[: self.max_length]
|
||||
truncated_label_ids = seq_labels[: self.max_length]
|
||||
if set(truncated_label_ids) == {IGNORE_INDEX}:
|
||||
if self.error_strict is True:
|
||||
raise ValueError(
|
||||
f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
|
||||
f"with all label values as {IGNORE_INDEX}."
|
||||
)
|
||||
else:
|
||||
warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
|
||||
continue # Skip the current error data point.
|
||||
spliced_data_point = {
|
||||
self.input_ids_field: truncated_seq_input_ids,
|
||||
self.labels_field: truncated_label_ids,
|
||||
}
|
||||
examples.append(spliced_data_point)
|
||||
warnings.warn("Find a data point to be truncated.")
|
||||
continue
|
||||
|
||||
# Pre action judgment.
|
||||
if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
|
||||
spliced_data_point = {
|
||||
self.input_ids_field: spliced_input_ids,
|
||||
self.labels_field: spliced_labels,
|
||||
} # `Dict[str, List[int]]`
|
||||
# Update.
|
||||
spliced_input_ids, spliced_labels = [], []
|
||||
spliced_input_ids.extend(seq_input_ids)
|
||||
spliced_labels.extend(seq_labels)
|
||||
examples.append(spliced_data_point)
|
||||
else:
|
||||
spliced_input_ids.extend(seq_input_ids)
|
||||
spliced_labels.extend(seq_labels)
|
||||
# For residual spliced data point at the end of the data set
|
||||
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
|
||||
examples.append(
|
||||
{
|
||||
self.input_ids_field: spliced_input_ids,
|
||||
self.labels_field: spliced_labels
|
||||
}
|
||||
)
|
||||
if self.shuffle:
|
||||
random.shuffle(examples)
|
||||
for spliced_data_point in examples:
|
||||
# TODO(2023-09-18): check errors for each spliced tokenized data point.
|
||||
self.current_size += 1
|
||||
yield spliced_data_point
|
|
@ -0,0 +1,111 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Initialize new model with updated tokenizer by calculating the mean values from original model
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source_model_and_tokenizer_path",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Source path of model & tokenizer",
|
||||
)
|
||||
parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
|
||||
parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
|
||||
args = parser.parse_args()
|
||||
|
||||
source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
|
||||
source_tokenizer.add_bos_token = False
|
||||
source_tokenizer.add_eos_token = False
|
||||
if source_tokenizer.pad_token is None:
|
||||
source_tokenizer.pad_token = source_tokenizer.unk_token
|
||||
source_vocab = source_tokenizer.get_vocab()
|
||||
|
||||
target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
|
||||
target_tokenizer.add_bos_token = False
|
||||
target_tokenizer.add_eos_token = False
|
||||
if target_tokenizer.pad_token is None:
|
||||
target_tokenizer.pad_token = target_tokenizer.unk_token
|
||||
target_vocab = target_tokenizer.get_vocab()
|
||||
target_inverted_vocab = {v: k for k, v in target_vocab.items()}
|
||||
|
||||
assert len(target_vocab) > len(
|
||||
source_vocab
|
||||
), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
|
||||
|
||||
gpu_device = torch.device("cuda:0")
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
|
||||
source_model.eval()
|
||||
source_model = source_model.to(gpu_device)
|
||||
|
||||
source_input_embeddings = source_model.get_input_embeddings()
|
||||
assert isinstance(source_input_embeddings, torch.nn.Embedding)
|
||||
assert source_input_embeddings.weight.shape[0] == len(source_vocab)
|
||||
source_input_embeddings.eval()
|
||||
|
||||
source_output_embeddings = source_model.get_output_embeddings()
|
||||
assert isinstance(source_output_embeddings, torch.nn.Linear)
|
||||
assert source_output_embeddings.bias is None
|
||||
assert source_output_embeddings.weight.shape[0] == len(source_vocab)
|
||||
source_output_embeddings.eval()
|
||||
|
||||
input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
|
||||
output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
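# For each newly added token, tokenize its surface form with the source tokenizer and average the corresponding source embedding rows.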
|
||||
for i in range(len(source_vocab), len(target_vocab)):
|
||||
if i % 500 == 0:
|
||||
logger.info(f"processing {i}/{len(target_vocab)} target tokens")
|
||||
target_token = target_inverted_vocab[i]
|
||||
target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
|
||||
target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
|
||||
|
||||
target_to_source_input_embedding = (
|
||||
source_input_embeddings.weight[target_to_source_token_ids]
|
||||
.mean(dim=0)
|
||||
.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
)
|
||||
target_to_source_output_embedding = (
|
||||
source_output_embeddings.weight[target_to_source_token_ids]
|
||||
.mean(dim=0)
|
||||
.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
|
||||
output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
|
||||
|
||||
source_model = source_model.to(cpu_device)
|
||||
assert isinstance(source_model, LlamaForCausalLM)
|
||||
|
||||
# expand
|
||||
source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
|
||||
source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
|
||||
source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
|
||||
|
||||
source_model = source_model.half()
|
||||
source_model.save_pretrained(save_directory=args.target_model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,98 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Initialize new tokenizer for continual pre-training
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def expand_vocab_tokenizer(
|
||||
source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
|
||||
) -> None:
|
||||
"""Expand tokenizer for continue pre-training."""
|
||||
if os.path.exists(target_tokenizer_dir):
|
||||
raise RuntimeError(f"Found existing directory {target_tokenizer_dir}")
|
||||
|
||||
source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
|
||||
logger.info(source_tokenizer)
|
||||
source_sp_processor = source_tokenizer.sp_model
|
||||
source_spm = sp_pb2_model.ModelProto()
|
||||
source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
|
||||
|
||||
logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
|
||||
|
||||
# Add new tokens to source tokenizer.
|
||||
source_spm_tokens = set([p.piece for p in source_spm.pieces])
|
||||
for piece in new_tokens:
|
||||
assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
|
||||
if piece in source_spm_tokens:
|
||||
# Skip existing tokens.
|
||||
continue
|
||||
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
||||
new_p.piece = piece
|
||||
new_p.score = 0
|
||||
source_spm.pieces.append(new_p)
|
||||
logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
|
||||
|
||||
# Save
|
||||
os.makedirs(target_tokenizer_dir)
|
||||
target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
|
||||
with open(file=target_tokenizer_model_path, mode="wb") as fp:
|
||||
fp.write(source_spm.SerializeToString())
|
||||
|
||||
target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
|
||||
target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
|
||||
logger.info(f"Successfully saved expanded tokenizer to {target_tokenizer_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expand_tokens_file",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Path of the file containing tokens to be extended",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
expand_tokens = []
|
||||
with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
|
||||
for line in fp_reader:
|
||||
item = json.loads(line)
|
||||
# e.g., {"piece": "你好"}
|
||||
token = item["piece"]
|
||||
if token in expand_tokens:
|
||||
continue
|
||||
expand_tokens.append(token)
|
||||
expand_tokens.sort(key=lambda t: len(t), reverse=False)
|
||||
|
||||
expand_vocab_tokenizer(
|
||||
source_tokenizer_dir=args.source_tokenizer_dir,
|
||||
target_tokenizer_dir=args.target_tokenizer_dir,
|
||||
new_tokens=expand_tokens,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Helper functions for IO
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
|
||||
"""
|
||||
Save as JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="w", encoding="utf-8") as fp:
|
||||
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
save_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
) -> None:
|
||||
"""
|
||||
Save model checkpoint, optimizer, LR scheduler and intermediate running states.
|
||||
"""
|
||||
|
||||
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
|
||||
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
|
||||
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
load_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Load model checkpoint, optimizer, LR scheduler and intermediate running states.
|
||||
"""
|
||||
|
||||
# Update booster params states.
|
||||
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
|
||||
return (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
|
@ -0,0 +1,216 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaRMSNorm,
|
||||
LlamaAttention,
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
)
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attention mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
),
|
||||
q,
|
||||
),
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Formard function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
|
@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from transformers.models.llama import LlamaForCausalLM


def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Freeze all parameters except embeddings."""
    for name, params in model.named_parameters():
        if "embed_tokens" not in name and "lm_head" not in name:
            params.requires_grad = False
        else:
            params.requires_grad = True


def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Make all parameters trainable again."""
    for name, params in model.named_parameters():
        params.requires_grad = True
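As a small, self-contained sanity check of the freezing helper above (a sketch assuming `transformers` is installed and the module is importable as `colossal_llama2.utils.froze`, matching the import used by `train.py` below; the tiny config values are arbitrary and chosen only to keep the model small):

from transformers import LlamaConfig, LlamaForCausalLM

from colossal_llama2.utils.froze import freeze_non_embeds_parameters

# Tiny, arbitrary LLaMA config so the check runs quickly on CPU.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config)
freeze_non_embeds_parameters(model)

trainable = [name for name, param in model.named_parameters() if param.requires_grad]
# Only the input embedding and the LM head should remain trainable.
assert all("embed_tokens" in name or "lm_head" in name for name in trainable)
print(f"{len(trainable)} trainable parameter tensors: {trainable}")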
File diff suppressed because one or more lines are too long
@ -0,0 +1,2 @@
hostname1
hostname2
@ -0,0 +1,153 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare dataset for continual pre-training
"""

import argparse
import json
import math
import os
import time
from multiprocessing import cpu_count

from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer

from colossalai.logging import get_dist_logger
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
    supervised_tokenize,
    ClosedToConstantLengthSplicedDataset,
)

logger = get_dist_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_input_dirs",
        type=str,
        required=True,
        default=None,
        help="Comma (i.e., ',') separated list of all data directories containing `.jsonl` data files.",
    )
    parser.add_argument(
        "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
    )
    parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
    parser.add_argument(
        "--data_jsonl_output_dir",
        type=str,
        default="jsonl_output",
        help="Output directory of spliced dataset with jsonl format",
    )
    parser.add_argument(
        "--data_arrow_output_dir",
        type=str,
        default="arrow_output",
        help="Output directory of spliced dataset with arrow format",
    )
    parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
    parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
    args = parser.parse_args()

    if args.num_spliced_dataset_bins >= 100000:
        raise ValueError("Too many spliced divisions, must be smaller than 100000")

    assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
    assert not os.path.exists(
        args.data_jsonl_output_dir
    ), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
    assert not os.path.exists(
        args.data_arrow_output_dir
    ), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
    os.makedirs(args.data_jsonl_output_dir)
    os.makedirs(args.data_arrow_output_dir)

    # Prepare all input data paths.
    input_data_paths = []
    input_data_dirs = args.data_input_dirs.split(",")
    for ds_dir in input_data_dirs:
        ds_dir = os.path.abspath(ds_dir)
        assert os.path.exists(ds_dir), f"Cannot find data dir {ds_dir}"
        ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
        ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
        input_data_paths.extend(ds_paths)

    # Prepare the data splits.
    train_splits = []
    split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
    for i in range(0, 100, split_interval):
        start = i
        end = i + split_interval
        if end > 100:
            end = 100
        train_splits.append(f"train[{start}%:{end}%]")

    # Prepare the tokenizer.
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    list_dataset = load_dataset(
        path="json",
        data_files=input_data_paths,
        cache_dir=os.path.join(args.data_cache_dir, "raw"),
        keep_in_memory=False,
        split=train_splits,
        num_proc=cpu_count(),
    )
    for index, dataset in enumerate(list_dataset):
        assert isinstance(dataset, dataset_dict.Dataset)
        logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
        dataset = dataset.map(
            function=supervised_tokenize,
            fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
            keep_in_memory=False,
            num_proc=min(len(dataset), cpu_count()),
        )
        dataset = dataset.remove_columns(column_names=["source", "target", "category"])
        dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
        dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"])
        spliced_dataset = ClosedToConstantLengthSplicedDataset(
            dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False
        )
        # Save each spliced dataset in jsonl format.
        output_index = "0" * (5 - len(str(index))) + str(index)
        output_name = f"part-{output_index}"
        output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
        st = time.time()
        with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
            spliced_count = 0
            for spliced_data_point in spliced_dataset:
                if spliced_count % 500 == 0:
                    logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
                spliced_count += 1
                fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
            logger.info(
                f"Current file {fp_writer.name}; "
                f"Data size: {len(spliced_dataset)}; "
                f"Spliced data size: {spliced_dataset.current_size}; "
                f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; "
                f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
            )

        # Save each spliced dataset in arrow format.
        output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
        logger.info(f"Start to save {output_arrow_path}")
        spliced_dataset = load_dataset(
            path="json",
            data_files=[output_jsonl_path],
            cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
            keep_in_memory=False,
            num_proc=cpu_count(),
            split="train",
        )
        spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))


if __name__ == "__main__":
    main()
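As a quick standalone illustration of the binning logic in the script above: with the default of 10 bins, the split expressions passed to `load_dataset` come out as follows (computed here only for illustration).

import math

num_spliced_dataset_bins = 10  # the script's default
split_interval = math.ceil(100 / num_spliced_dataset_bins)
train_splits = [f"train[{i}%:{min(i + split_interval, 100)}%]" for i in range(0, 100, split_interval)]
print(train_splits)
# ['train[0%:10%]', 'train[10%:20%]', ..., 'train[80%:90%]', 'train[90%:100%]']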
@ -0,0 +1,15 @@
torch<2.0.0, >=1.12.1
packaging==23.1
colossalai==0.3.2
autoflake==2.2.1
black==23.9.1
transformers
tensorboard==2.14.0
six==1.16.0
datasets
ninja==1.11.1
flash-attn>=2.0.0,<=2.0.5
tqdm
sentencepiece==0.1.99
protobuf<=3.20.0
@ -0,0 +1,44 @@
#!/bin/bash

# NCCL IB environment variables
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export OMP_NUM_THREADS=8

# Fill in the project name and paths below before launching.
PROJECT_NAME=""
PARENT_SAVE_DIR=""
PARENT_TENSORBOARD_DIR=""
PARENT_CONFIG_FILE=""
PRETRAINED_MODEL_PATH=""

declare -a dataset=(
    "PATH TO THE DATASET"
)

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
    --pretrained "$PRETRAINED_MODEL_PATH" \
    --dataset "${dataset[@]}" \
    --plugin "zero2" \
    --save_interval 400 \
    --save_dir "$SAVE_DIR" \
    --tensorboard_dir "$TENSORBOARD_DIR" \
    --config_file "$CONFIG_FILE" \
    --num_epochs 1 \
    --micro_batch_size 8 \
    --lr 1e-4 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --weight_decay 0.01 \
    --warmup_steps 100 \
    --use_grad_checkpoint \
    --use_flash_attn
@ -0,0 +1,383 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
"""

import json
import argparse
import os
import resource
from contextlib import nullcontext
from tqdm import tqdm

import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import (
    GeminiPlugin,
    LowLevelZeroPlugin,
    HybridParallelPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from colossal_llama2.dataset.loader import (
    load_tokenized_dataset,
    setup_distributed_dataloader,
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
)

from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.froze import freeze_non_embeds_parameters


def get_model_numel(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f"{numel / K:.2f} K"
    else:
        return f"{numel}"


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor


def main() -> None:
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained",
        type=str,
        default=None,
        help="Path to the pre-trained model",
    )
    parser.add_argument("--dataset", nargs="+", default=[])
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
    parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["fp16", "bf16"],
        help="Mixed precision",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument(
        "--use_grad_checkpoint",
        action="store_true",
        default=False,
        help="Use gradient checkpointing",
    )
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        default=False,
        help="Use flash-attention",
    )
    parser.add_argument(
        "--freeze_non_embeds_params",
        action="store_true",
        default=False,
        help="Freeze non-embedding parameters",
    )
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--zero", type=int, default=1)
    args = parser.parse_args()

    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Tensorboard
    # ==============================
    if coordinator.is_master():
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=1,
            zero_stage=args.zero,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Tokenizer, Dataset, Collator and Dataloader
    # ======================================================
    tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
    coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
    coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")

    coordinator.print_on_master(f"Load dataset: {args.dataset}")

    dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
    dataloader = setup_distributed_dataloader(
        dataset=dataset,
        batch_size=args.micro_batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
    )
    coordinator.print_on_master(
        f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    init_ctx = (
        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    )
    with init_ctx:
        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
        # Freeze part of the parameters.
        if args.freeze_non_embeds_params:
            freeze_non_embeds_parameters(model=model)

    if args.use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    if args.use_flash_attn:
        replace_with_flash_attention(model=model)
        coordinator.print_on_master(msg="Flash-attention enabled successfully")

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(
        model_params=filter(lambda p: p.requires_grad, model.parameters())
        if args.freeze_non_embeds_params
        else model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Flash attention will be disabled because it does NOT support fp32.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

    torch.set_default_dtype(torch.float)

    if args.load_checkpoint is None:
        coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
        booster.load_model(model, args.pretrained, strict=False)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load_checkpoint is not None:
        if "modeling" in args.load_checkpoint:
            coordinator.print_on_master(f"Continue pre-training from checkpoint {args.load_checkpoint}")
            booster.load_model(model, args.load_checkpoint)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.load_checkpoint,
                booster=booster,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
            coordinator.print_on_master(
                f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    num_steps_per_epoch = len(dataloader)
    # If resuming training, set the sampler start index to the correct value.
    assert isinstance(dataloader.sampler, StatefulDistributedSampler)
    dataloader.sampler.set_start_index(start_index=sampler_start_idx)

    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch=epoch)
        with tqdm(
            iterable=enumerate(dataloader, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step, batch in pbar:
                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}

                batch_output = model(**batch)

                loss = batch_output.loss

                booster.backward(loss=loss, optimizer=optimizer)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                all_reduce_mean(tensor=loss)
                pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
                if coordinator.is_master():
                    global_step = epoch * num_steps_per_epoch + step
                    writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
                    writer.add_scalar(
                        tag="Learning Rate",
                        scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step,
                    )
                # Save the model checkpoint together with running states.
                if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
                    coordinator.print_on_master("\nStart saving model checkpoint with running states")
                    save_checkpoint(
                        save_dir=args.save_dir,
                        booster=booster,
                        model=model,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        epoch=epoch,
                        step=step + 1,
                        batch_size=args.micro_batch_size,
                        coordinator=coordinator,
                    )
                    coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
                    )

                # Release cached CUDA memory.
                # del batch, batch_labels, batch_output, loss
                torch.cuda.empty_cache()

        # The subsequent epochs are not resumed from a checkpoint, so reset the sampler start index and start step.
        dataloader.sampler.set_start_index(start_index=0)
        start_step = 0

    # Final save.
    coordinator.print_on_master("Start saving final model checkpoint")
    booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(
        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
    )

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()
@ -0,0 +1 @@
0.0.1
@ -4,8 +4,9 @@ This directory contains the applications that are powered by Colossal-AI.

The list of applications include:

- [X] [Chatbot](./Chat/README.md)
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters
- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
- [X] [Chatbot](./Chat/README.md): Replication of ChatGPT with RLHF.
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.

> Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder.
