mirror of https://github.com/hpcaitech/ColossalAI
[Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)
* Add finetuning Colossal-Llama-2 example
* Add finetuning Colossal-Llama-2 example 2
* Add finetuning Colossal-Llama-2 example and support NEFTuning
* Add inference example and refine neftune
* Modify readme file
* Update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
parent
3dbbf83f1c
commit
b397104438
|
@ -11,7 +11,10 @@
|
||||||
- [Performance Evaluation](#performance-evaluation)
|
- [Performance Evaluation](#performance-evaluation)
|
||||||
- [Examples](#examples)
|
- [Examples](#examples)
|
||||||
- [Training Logs](#training-logs)
|
- [Training Logs](#training-logs)
|
||||||
- [Import from Transformers (Inference)](#import-from-transformers-inference)
|
- [Inference](#inference)
|
||||||
|
- [Import from HuggingFace](#import-from-huggingface)
|
||||||
|
- [Import from Modelscope](#import-from-modelscope)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
- [Usage](#usage)
|
- [Usage](#usage)
|
||||||
- [Install](#install)
|
- [Install](#install)
|
||||||
- [0. Pre-requisite](#0-pre-requisite)
|
- [0. Pre-requisite](#0-pre-requisite)
|
||||||
|
@ -21,8 +24,14 @@
|
||||||
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
|
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
|
||||||
- [2. Init Model Preparation](#2-init-model-preparation)
|
- [2. Init Model Preparation](#2-init-model-preparation)
|
||||||
- [3. Data Preparation](#3-data-preparation)
|
- [3. Data Preparation](#3-data-preparation)
|
||||||
|
- [3.1 Data for Pretraining](#31-data-for-pretraining)
|
||||||
|
- [3.2 Data for Supervised Fine-tuning](#32-data-for-supervised-fine-tuning)
|
||||||
- [4. Command Line Arguments for Training](#4-command-line-arguments-for-training)
|
- [4. Command Line Arguments for Training](#4-command-line-arguments-for-training)
|
||||||
|
- [4.1 Arguments for Pretraining](#41-arguments-for-pretraining)
|
||||||
|
- [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning)
|
||||||
- [5. Running Command](#5-running-command)
|
- [5. Running Command](#5-running-command)
|
||||||
|
- [5.1 Command for Pretraining](#51-command-for-pretraining)
|
||||||
|
- [5.2 Command for Supervised Fine-tuning](#52-command-for-supervised-fine-tuning)
|
||||||
- [Technical Insights](#technical-insights)
|
- [Technical Insights](#technical-insights)
|
||||||
- [Data](#data)
|
- [Data](#data)
|
||||||
- [Tokenizer](#tokenizer)
|
- [Tokenizer](#tokenizer)
|
||||||
|
@ -117,7 +126,8 @@ We also recorded the training logs for the experiment
|
||||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossByTokens.jpeg?raw=true" width=600/>
|
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossByTokens.jpeg?raw=true" width=600/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Import from Transformers (Inference)
|
### Inference
|
||||||
|
#### Import from HuggingFace
|
||||||
To load the Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
|
To load the Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
|
||||||
```Python
|
```Python
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
@ -135,6 +145,7 @@ pred = model.generate(**inputs,
|
||||||
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
|
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Import from Modelscope
|
||||||
You can also load our model using ModelScope with the following code:
|
You can also load our model using ModelScope with the following code:
|
||||||
```Python
|
```Python
|
||||||
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
|
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
|
||||||
|
@ -153,6 +164,30 @@ print(tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input):])
|
||||||
```
|
```
|
||||||
You can download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) or [👾Modelscope](https://modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary).
|
You can download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) or [👾Modelscope](https://modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary).
|
||||||
|
|
||||||
|
#### Quick Start
|
||||||
|
You can run [`inference_example.py`](inference_example.py) to quickly start inference with our base model by loading the model weights from HuggingFace.
|
||||||
|
|
||||||
|
Command to run the script:
|
||||||
|
```bash
|
||||||
|
python inference_example.py \
|
||||||
|
--model_path "<HF_REPO_NAME_OR_LOCAL_PATH_TO_MODEL>" \
|
||||||
|
--device "cuda:0" \
|
||||||
|
--max_new_tokens 512 \
|
||||||
|
--do_sample True \
|
||||||
|
--temperature 0.3 \
|
||||||
|
--top_k 50 \
|
||||||
|
--top_p 0.95 \
|
||||||
|
--input_txt "YOUR_PROMPT_OR_QUESTION"
|
||||||
|
```
|
||||||
|
Here are the details of the CLI arguments:
|
||||||
|
* Model path: `--model_path`. HF repo name or local path of the model.
|
||||||
|
* Device: `--device`. Set the device.
|
||||||
|
* Max new tokens: `--max_new_tokens`. Set the maximum number of tokens to generate, ignoring the number of tokens in the prompt.
|
||||||
|
* Do sample: `--do_sample`. Set whether or not to use sampling.
|
||||||
|
* Temperature: `--temperature`. Set temperature value.
|
||||||
|
* Top_k: `--top_k`. Set top_k value for top-k-filtering.
|
||||||
|
* Top_p: `--top_p`. Set top_p value for generation.
|
||||||
|
* Input_txt: `--input_txt`. The prompt string input to the model.
|
||||||
## Usage
|
## Usage
|
||||||
### Install
|
### Install
|
||||||
|
|
||||||
|
@ -218,6 +253,8 @@ Here is details about CLI arguments:
|
||||||
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
|
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
|
||||||
|
|
||||||
#### 3. Data Preparation
|
#### 3. Data Preparation
|
||||||
|
|
||||||
|
##### 3.1 Data for Pretraining
|
||||||
Raw data should be in `jsonl` format. Each data point should have the following fields (a minimal illustrative example follows the field list below):
|
Raw data should be in `jsonl` format. Each data point should have the following fields (a minimal illustrative example follows the field list below):
|
||||||
* `source` (str, compulsory): This part is ignored when calculating the loss; it can be left empty.
|
* `source` (str, compulsory): This part is ignored when calculating the loss; it can be left empty.
|
||||||
* `target` (str, compulsory): The loss is calculated on this part.
|
* `target` (str, compulsory): The loss is calculated on this part.
|
||||||
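For illustration, here is a minimal sketch of producing such a `jsonl` file in Python (the file name and sample content are hypothetical):

```Python
import json

# Two hypothetical pretraining data points following the `source`/`target` schema described above.
samples = [
    {"source": "", "target": "Colossal-AI is an open-source deep learning system for large-scale model training."},
    {"source": "Question: What is continual pre-training?", "target": "It further pre-trains an existing base model on additional corpora."},
]

with open("pretrain_part_0.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```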
|
@ -250,7 +287,31 @@ Here is details about CLI arguments:
|
||||||
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
|
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
|
||||||
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
|
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
|
||||||
|
|
||||||
|
##### 3.2 Data for Supervised Fine-tuning
|
||||||
|
We prepare data for supervised fine-tuning in a similar way. The main difference lies in the data format. Each data point should have the following field:
|
||||||
|
* `messages` (list, compulsory): This part consists of a conversation between a human and an assistant. The length of `messages` can vary, and only content from the `assistant` is used for calculating the loss.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```JSON
|
||||||
|
{"messages": [{"from": "human", "content": "What are the three primary colors?"}, {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."}]}
|
||||||
|
{"messages": [{"from": "human", "content": "解释个人电脑和服务器之间的区别。"}, {"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
The command to convert a `jsonl` dataset to arrow format is similar to the one in [3.1 Data for Pretraining](#31-data-for-pretraining). Note that `prepare_sft_dataset.py` does not concatenate different data samples. You can inspect the converted output as sketched after the command below.
|
||||||
|
```bash
|
||||||
|
python prepare_sft_dataset.py \
|
||||||
|
--data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
|
||||||
|
--tokenizer_dir "<TOKENIZER_DIR>" \
|
||||||
|
--data_cache_dir "jsonl_to_arrow_cache" \
|
||||||
|
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
|
||||||
|
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
|
||||||
|
--max_length 4096 \
|
||||||
|
--num_spliced_dataset_bins 10
|
||||||
|
```
|
||||||
|
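After conversion, you can sanity-check the arrow output with the 🤗 `datasets` library. A minimal sketch (the shard path follows the script's `part-XXXXX` naming and is illustrative):

```Python
from datasets import load_from_disk

# Load one converted shard written by prepare_sft_dataset.py (path is illustrative).
ds = load_from_disk("spliced_tokenized_output_arrow/part-00000")
print(ds)                    # columns include input_ids, labels, seq_length and seq_category
print(ds[0]["seq_length"])   # tokenized length of the first conversation
```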
|
||||||
#### 4. Command Line Arguments for Training
|
#### 4. Command Line Arguments for Training
|
||||||
|
|
||||||
|
##### 4.1 Arguments for Pretraining
|
||||||
You can use `colossalai run` to launch multi-nodes training:
|
You can use `colossalai run` to launch multi-nodes training:
|
||||||
```bash
|
```bash
|
||||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
||||||
|
@ -288,7 +349,16 @@ Here is details about CLI arguments:
|
||||||
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
|
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
|
||||||
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
|
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
|
||||||
|
|
||||||
|
##### 4.2 Arguments for Supervised Fine-tuning
|
||||||
|
We add support for gradient accumulation and NEFTune in supervised fine-tuning, so there are two additional arguments on top of those listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
|
||||||
|
|
||||||
|
Here are the details of the additional CLI arguments:
|
||||||
|
* Accumulation steps: `--accumulation_steps`. The default value is `8`, meaning gradients are accumulated over 8 micro-batches before each optimizer step (see the worked example after this list).
|
||||||
|
* NEFTune: `--use_neft`. The default value is `False`. Enabling it can help improve the performance of chat models.
|
||||||
|
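For example, assuming a single node with 8 GPUs and the settings used in the example script (`--micro_batch_size 8` and `--accumulation_steps 8`), gradients are accumulated over 8 micro-batches before each optimizer step, giving an effective global batch size of 8 × 8 × 8 = 512.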
|
||||||
#### 5. Running Command
|
#### 5. Running Command
|
||||||
|
|
||||||
|
##### 5.1 Command for Pretraining
|
||||||
An [example bash script](train.example.sh) is also provided for the experiment. Here are the steps to run the experiment:
|
An [example bash script](train.example.sh) is also provided for the experiment. Here are the steps to run the experiment:
|
||||||
* Create your own hostfile: `cp hostfile.example hostfile`.
|
* Create your own hostfile: `cp hostfile.example hostfile`.
|
||||||
* Create your own bash: `cp train.example.sh train.sh`.
|
* Create your own bash: `cp train.example.sh train.sh`.
|
||||||
|
@ -310,6 +380,10 @@ declare -a dataset=(
|
||||||
"<DIR_2>/part-00000"
|
"<DIR_2>/part-00000"
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
##### 5.2 Command for Supervised Fine-tuning
|
||||||
|
An [example bash script](train_sft.example.sh) is provided. The only difference from the pretraining command is the two additional arguments (`--accumulation_steps` and `--use_neft`) in the script. You can refer to [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning) for more details.
|
||||||
|
|
||||||
## Technical Insights
|
## Technical Insights
|
||||||
To enhance LLaMA-2's ability to understand and generate Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continued pre-training of the LLaMA-2 model on both Chinese and English corpora. The overall pipeline can be described as follows:
|
To enhance LLaMA-2's ability to understand and generate Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continued pre-training of the LLaMA-2 model on both Chinese and English corpora. The overall pipeline can be described as follows:
|
||||||
|
|
||||||
|
@ -416,3 +490,11 @@ Applying the above process to perform knowledge transfer in any field allows for
|
||||||
year={2023}
|
year={2023}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
```bibtex
|
||||||
|
@article{jain2023neftune,
|
||||||
|
title={NEFTune: Noisy Embeddings Improve Instruction Finetuning},
|
||||||
|
author={Jain, Neel and Chiang, Ping-yeh and Wen, Yuxin and Kirchenbauer, John and Chu, Hong-Min and Somepalli, Gowthami and Bartoldson, Brian R and Kailkhura, Bhavya and Schwarzschild, Avi and Saha, Aniruddha and others},
|
||||||
|
journal={arXiv preprint arXiv:2310.05914},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright 2023 lm-sys@FastChat
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class SeparatorStyle(Enum):
|
||||||
|
ADD_BOS_EOS_TOKEN = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Conversation:
|
||||||
|
system: str
|
||||||
|
roles: List[str]
|
||||||
|
messages: List[List[str]]
|
||||||
|
offset: int
|
||||||
|
sep_style: SeparatorStyle
|
||||||
|
seps: List[str]
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.messages = []
|
||||||
|
|
||||||
|
def get_prompt(self, length: int = None):
|
||||||
|
if length is None:
|
||||||
|
length = len(self.messages)
|
||||||
|
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
|
||||||
|
ret = self.system
|
||||||
|
for role, message in self.messages[0:length]:
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + self.seps[0] + message + self.seps[1]
|
||||||
|
else:
|
||||||
|
ret += role + ": " + self.seps[0]
|
||||||
|
return ret
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||||
|
|
||||||
|
def save_prompt(self):
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
|
||||||
|
ret = self.system
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + self.seps[0] + message + self.seps[1] + "\n"
|
||||||
|
else:
|
||||||
|
ret += role + ": " + self.seps[0]
|
||||||
|
return ret
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||||
|
|
||||||
|
def append_message(self, role, message):
|
||||||
|
self.messages.append([role, message])
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return Conversation(
|
||||||
|
system=self.system,
|
||||||
|
roles=self.roles,
|
||||||
|
messages=[[x, y] for x, y in self.messages],
|
||||||
|
offset=self.offset,
|
||||||
|
sep_style=self.sep_style,
|
||||||
|
seps=self.seps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def dict(self):
|
||||||
|
return {
|
||||||
|
"system": self.system,
|
||||||
|
"roles": self.roles,
|
||||||
|
"messages": self.messages,
|
||||||
|
"offset": self.offset,
|
||||||
|
"seps": self.seps,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
system="A chat between a curious human and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||||
|
roles=("Human", "Assistant"),
|
||||||
|
messages=[],
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
|
||||||
|
seps=["<s>", "</s>"],
|
||||||
|
)
|
||||||
|
|
||||||
|
default_conversation = conv
|
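# Illustrative usage of the default conversation template above (the messages are hypothetical):
#
#   conv = default_conversation.copy()
#   conv.append_message(conv.roles[0], "What are the three primary colors?")
#   conv.append_message(conv.roles[1], "The three primary colors are red, blue, and yellow.")
#   print(conv.get_prompt())
#
# yields a prompt of the form:
#   <system prompt>Human: <s>What are the three primary colors?</s>Assistant: <s>The three primary colors ...</s>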
|
@ -4,22 +4,29 @@
|
||||||
Splicing multiple pre-tokenized sequence data points
|
Splicing multiple pre-tokenized sequence data points
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import bisect
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datasets import dataset_dict
|
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
|
|
||||||
|
|
||||||
|
from datasets import dataset_dict
|
||||||
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
|
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
|
||||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
|
from .conversation import Conversation, default_conversation
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||||
|
|
||||||
|
|
||||||
def supervised_tokenize(
|
def supervised_tokenize_pretrain(
|
||||||
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
|
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
|
||||||
) -> Dict[str, Union[int, str, List[int]]]:
|
) -> Dict[str, Union[int, str, List[int]]]:
|
||||||
"""
|
"""
|
||||||
|
@ -62,6 +69,121 @@ def supervised_tokenize(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def supervised_tokenize_sft(
|
||||||
|
data_point: Dict[str, str],
|
||||||
|
tokenizer: LlamaTokenizer,
|
||||||
|
conversation_template: Conversation = default_conversation,
|
||||||
|
ignore_index: int = None,
|
||||||
|
max_length: int = 4096,
|
||||||
|
) -> Dict[str, Union[int, str, List[int]]]:
|
||||||
|
"""
|
||||||
|
A tokenization function to tokenize an original supervised data point of the following form:
|
||||||
|
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
|
||||||
|
"""
|
||||||
|
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
|
||||||
|
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
|
||||||
|
"add <bos> and <eos> manually later"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
|
||||||
|
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
|
||||||
|
|
||||||
|
if ignore_index is None:
|
||||||
|
ignore_index = IGNORE_INDEX
|
||||||
|
|
||||||
|
messages = data_point["messages"]
|
||||||
|
template = deepcopy(conversation_template)
|
||||||
|
template.messages = []
|
||||||
|
|
||||||
|
for mess in messages:
|
||||||
|
from_str = mess["from"]
|
||||||
|
if from_str.lower() == "human":
|
||||||
|
from_str = template.roles[0]
|
||||||
|
elif from_str.lower() == "assistant":
|
||||||
|
from_str = template.roles[1]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||||
|
|
||||||
|
template.append_message(from_str, mess["content"])
|
||||||
|
|
||||||
|
if len(template.messages) % 2 != 0:
|
||||||
|
template.messages = template.messages[0:-1]
|
||||||
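# A trailing human message without an assistant reply is dropped above so that turns stay paired.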
|
|
||||||
|
# `target_turn_index` is the number of leading turns whose tokenized prompt length still fits within `max_length - 1`.
|
||||||
|
turns = [i for i in range(1, len(messages) // 2 + 1)]
|
||||||
|
target_turn_index = bisect.bisect_right(
|
||||||
|
turns,
|
||||||
|
max_length - 1,
|
||||||
|
key=lambda x: len(tokenizer([template.get_prompt(2 * x)], add_special_tokens=False)["input_ids"][0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# The tokenized length for first turn already exceeds `max_length - 1`.
|
||||||
|
if target_turn_index - 1 < 0:
|
||||||
|
return dict(
|
||||||
|
input_ids=None,
|
||||||
|
labels=None,
|
||||||
|
inputs_decode=None,
|
||||||
|
labels_decode=None,
|
||||||
|
seq_length=None,
|
||||||
|
seq_category=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_turn = turns[target_turn_index - 1]
|
||||||
|
prompt = template.get_prompt(2 * target_turn)
|
||||||
|
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
|
||||||
|
|
||||||
|
template.messages = template.messages[0 : 2 * target_turn]
|
||||||
|
|
||||||
|
starts = []
|
||||||
|
ends = []
|
||||||
|
gpt_bos = False if template.messages[0][0] == template.roles[0] else True
|
||||||
|
gpt_eos = False if template.messages[0][0] == template.roles[0] else True
|
||||||
|
|
||||||
|
for i, token_id in enumerate(tokenized):
|
||||||
|
if token_id == tokenizer.bos_token_id:
|
||||||
|
if gpt_bos:
|
||||||
|
starts.append(i)
|
||||||
|
gpt_bos = not gpt_bos
|
||||||
|
elif token_id == tokenizer.eos_token_id:
|
||||||
|
if gpt_eos:
|
||||||
|
ends.append(i)
|
||||||
|
gpt_eos = not gpt_eos
|
||||||
|
|
||||||
|
if len(starts) != target_turn or len(ends) != target_turn:
|
||||||
|
logger.info(
|
||||||
|
"Please check whether the tokenizer add additional `bos_token` and `eos_token`.\n\nOr the original message contains `bos_token` or `eos_token`."
|
||||||
|
)
|
||||||
|
return dict(
|
||||||
|
input_ids=None,
|
||||||
|
labels=None,
|
||||||
|
inputs_decode=None,
|
||||||
|
labels_decode=None,
|
||||||
|
seq_length=None,
|
||||||
|
seq_category=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||||
|
labels = [ignore_index] * len(tokenized)
|
||||||
|
for start, end in zip(starts, ends):
|
||||||
|
labels[start + 1 : end + 2] = tokenized[start + 1 : end + 2]
|
||||||
|
|
||||||
|
labels_decode = deepcopy(labels)
|
||||||
|
for i, z in enumerate(labels_decode):
|
||||||
|
if z == ignore_index:
|
||||||
|
labels_decode[i] = tokenizer.unk_token_id
|
||||||
|
|
||||||
|
# `inputs_decode` and `labels_decode` can be used to check whether the tokenization is correct.
|
||||||
|
return dict(
|
||||||
|
input_ids=tokenized,
|
||||||
|
labels=labels,
|
||||||
|
inputs_decode=tokenizer.decode(tokenized),
|
||||||
|
labels_decode=tokenizer.decode(labels_decode),
|
||||||
|
seq_length=len(tokenized),
|
||||||
|
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||||
|
)
|
||||||
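# Illustration of the output: for {"messages": [{"from": "human", ...}, {"from": "assistant", ...}]} the assembled
# prompt has the form `<system prompt>Human: <s>question</s>Assistant: <s>answer</s>`, and only the assistant spans
# (from their opening <s> through the closing </s>) keep their token ids in `labels`; every other position is set to
# `ignore_index`, so the loss is computed on assistant content only.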
|
|
||||||
|
|
||||||
class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
|
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
|
||||||
|
@ -169,12 +291,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
||||||
spliced_labels.extend(seq_labels)
|
spliced_labels.extend(seq_labels)
|
||||||
# For residual spliced data point at the end of the data set
|
# For residual spliced data point at the end of the data set
|
||||||
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
|
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
|
||||||
examples.append(
|
examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
|
||||||
{
|
|
||||||
self.input_ids_field: spliced_input_ids,
|
|
||||||
self.labels_field: spliced_labels
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
random.shuffle(examples)
|
random.shuffle(examples)
|
||||||
for spliced_data_point in examples:
|
for spliced_data_point in examples:
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
# Copyright 2023 The Hugging Face team
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap(model):
|
||||||
|
return model.unwrap().module
|
||||||
|
|
||||||
|
|
||||||
|
def neftune_post_forward_hook(module, input, output):
|
||||||
|
"""
|
||||||
|
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
|
||||||
|
layers. This method is slightly adapted from the original source code that can be found here:
|
||||||
|
https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
|
||||||
|
```python
|
||||||
|
model = ...
|
||||||
|
model.embed_tokens.neftune_noise_alpha = 0.1
|
||||||
|
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
|
||||||
|
```
|
||||||
|
Args:
|
||||||
|
module (`torch.nn.Module`):
|
||||||
|
The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
|
||||||
|
the desired noise alpha value.
|
||||||
|
input (`torch.Tensor`):
|
||||||
|
The input tensor to the model.
|
||||||
|
output (`torch.Tensor`):
|
||||||
|
The output tensor of the model (i.e. the embeddings).
|
||||||
|
"""
|
||||||
|
if module.training:
|
||||||
|
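# Noise magnitude follows the NEFTune paper: alpha / sqrt(seq_len * hidden_dim).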
dims = torch.tensor(output.size(1) * output.size(2))
|
||||||
|
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
|
||||||
|
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def activate_neftune(model, neftune_noise_alpha=0.1):
|
||||||
|
r"""
|
||||||
|
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
|
||||||
|
https://arxiv.org/abs/2310.05914
|
||||||
|
"""
|
||||||
|
embeddings = unwrap(model).get_input_embeddings()
|
||||||
|
|
||||||
|
embeddings.neftune_noise_alpha = neftune_noise_alpha
|
||||||
|
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
|
||||||
|
neftune_hook_handle = hook_handle
|
||||||
|
|
||||||
|
return model, neftune_hook_handle
|
||||||
|
|
||||||
|
|
||||||
|
def deactivate_neftune(model, neftune_hook_handle):
|
||||||
|
"""
|
||||||
|
Deactivates the neftune method. Make sure to call `activate_neftune` first.
|
||||||
|
"""
|
||||||
|
embeddings = unwrap(model).get_input_embeddings()
|
||||||
|
|
||||||
|
neftune_hook_handle.remove()
|
||||||
|
del embeddings.neftune_noise_alpha
|
|
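# Usage sketch (mirrors how train_sft.py in this PR uses these helpers):
#   model, handle = activate_neftune(model, neftune_noise_alpha=0.1)
#   ...  # supervised fine-tuning steps
#   deactivate_neftune(model, handle)  # remove the hook before saving a checkpoint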
@ -0,0 +1,57 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path, device="cuda", **kwargs):
|
||||||
|
logger.info(
|
||||||
|
"Please check whether the tokenizer and model weights are properly stored in the same folder."
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
except OSError:
|
||||||
|
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(args):
|
||||||
|
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
|
||||||
|
|
||||||
|
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
|
||||||
|
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
|
||||||
|
|
||||||
|
inputs = tokenizer(input_txt, return_tensors='pt').to(args.device)
|
||||||
|
output = model.generate(**inputs,
|
||||||
|
max_new_tokens=args.max_new_tokens,
|
||||||
|
do_sample=args.do_sample,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_k=args.top_k,
|
||||||
|
top_p=args.top_p,
|
||||||
|
num_return_sequences=1)
|
||||||
|
response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
|
||||||
|
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
|
||||||
|
parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
|
||||||
|
parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
|
||||||
|
parser.add_argument('--max_new_tokens', type=int, default=512, help="Set the maximum number of tokens to generate, ignoring the number of tokens in the prompt")
|
||||||
|
parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
|
||||||
|
parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
|
||||||
|
parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
|
||||||
|
parser.add_argument('--top_p', type=float, default=0.95, help="Set top_p value for generation")
|
||||||
|
parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
|
||||||
|
args = parser.parse_args()
|
||||||
|
generate(args)
|
|
@ -11,14 +11,14 @@ import os
|
||||||
import time
|
import time
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
|
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
|
||||||
|
ClosedToConstantLengthSplicedDataset,
|
||||||
|
supervised_tokenize_pretrain,
|
||||||
|
)
|
||||||
from datasets import dataset_dict, load_dataset
|
from datasets import dataset_dict, load_dataset
|
||||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||||
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
|
|
||||||
supervised_tokenize,
|
|
||||||
ClosedToConstantLengthSplicedDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ def main():
|
||||||
assert isinstance(dataset, dataset_dict.Dataset)
|
assert isinstance(dataset, dataset_dict.Dataset)
|
||||||
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
|
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
function=supervised_tokenize,
|
function=supervised_tokenize_pretrain,
|
||||||
fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
|
fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
|
||||||
keep_in_memory=False,
|
keep_in_memory=False,
|
||||||
num_proc=min(len(dataset), cpu_count()),
|
num_proc=min(len(dataset), cpu_count()),
|
||||||
|
@ -149,5 +149,5 @@ def main():
|
||||||
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
|
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -0,0 +1,147 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Prepare sft dataset for fine-tuning
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
|
from colossal_llama2.dataset.conversation import default_conversation
|
||||||
|
from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
|
||||||
|
from datasets import dataset_dict, load_dataset
|
||||||
|
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||||
|
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_input_dirs",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
default=None,
|
||||||
|
help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
|
||||||
|
)
|
||||||
|
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_jsonl_output_dir",
|
||||||
|
type=str,
|
||||||
|
default="jsonl_output",
|
||||||
|
help="Output directory of spliced dataset with jsonl format",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_arrow_output_dir",
|
||||||
|
type=str,
|
||||||
|
default="arrow_output",
|
||||||
|
help="Output directory of spliced dataset with arrow format",
|
||||||
|
)
|
||||||
|
parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
|
||||||
|
parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.num_spliced_dataset_bins >= 100000:
|
||||||
|
raise ValueError("Too many spliced divisions, must be smaller than 100000")
|
||||||
|
|
||||||
|
assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
|
||||||
|
assert not os.path.exists(
|
||||||
|
args.data_jsonl_output_dir
|
||||||
|
), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
|
||||||
|
assert not os.path.exists(
|
||||||
|
args.data_arrow_output_dir
|
||||||
|
), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
|
||||||
|
os.makedirs(args.data_jsonl_output_dir)
|
||||||
|
os.makedirs(args.data_arrow_output_dir)
|
||||||
|
|
||||||
|
# Prepare all input data paths.
|
||||||
|
input_data_paths = []
|
||||||
|
input_data_dirs = args.data_input_dirs.split(",")
|
||||||
|
for ds_dir in input_data_dirs:
|
||||||
|
ds_dir = os.path.abspath(ds_dir)
|
||||||
|
assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}"
|
||||||
|
ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
|
||||||
|
ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
|
||||||
|
input_data_paths.extend(ds_paths)
|
||||||
|
|
||||||
|
# Prepare the data splits.
|
||||||
|
train_splits = []
|
||||||
|
split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
|
||||||
|
for i in range(0, 100, split_interval):
|
||||||
|
start = i
|
||||||
|
end = i + split_interval
|
||||||
|
if end > 100:
|
||||||
|
end = 100
|
||||||
|
train_splits.append(f"train[{start}%:{end}%]")
|
||||||
|
|
||||||
|
# Prepare the tokenizer.
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
|
||||||
|
tokenizer.add_bos_token = False
|
||||||
|
tokenizer.add_eos_token = False
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.unk_token
|
||||||
|
|
||||||
|
list_dataset = load_dataset(
|
||||||
|
path="json",
|
||||||
|
data_files=input_data_paths,
|
||||||
|
cache_dir=os.path.join(args.data_cache_dir, "raw"),
|
||||||
|
keep_in_memory=False,
|
||||||
|
split=train_splits,
|
||||||
|
num_proc=cpu_count(),
|
||||||
|
)
|
||||||
|
for index, dataset in enumerate(list_dataset):
|
||||||
|
assert isinstance(dataset, dataset_dict.Dataset)
|
||||||
|
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
|
||||||
|
dataset = dataset.map(
|
||||||
|
function=supervised_tokenize_sft,
|
||||||
|
fn_kwargs={
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"conversation_template": default_conversation,
|
||||||
|
"max_length": args.max_length,
|
||||||
|
},
|
||||||
|
keep_in_memory=False,
|
||||||
|
num_proc=min(len(dataset), cpu_count()),
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = dataset.filter(lambda data: data["labels"] is not None)
|
||||||
|
dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
|
||||||
|
|
||||||
|
# We don't concatenate data samples here.
|
||||||
|
spliced_dataset = dataset
|
||||||
|
# Save each jsonl spliced dataset.
|
||||||
|
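# Zero-pad the shard index to five digits, e.g. `part-00003`.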
output_index = "0" * (5 - len(str(index))) + str(index)
|
||||||
|
output_name = f"part-{output_index}"
|
||||||
|
output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
|
||||||
|
# st = time.time()
|
||||||
|
with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
|
||||||
|
spliced_count = 0
|
||||||
|
for spliced_data_point in spliced_dataset:
|
||||||
|
if spliced_count % 500 == 0:
|
||||||
|
logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
|
||||||
|
spliced_count += 1
|
||||||
|
fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
# Save each arrow spliced dataset
|
||||||
|
output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
|
||||||
|
logger.info(f"Start to save {output_arrow_path}")
|
||||||
|
spliced_dataset = load_dataset(
|
||||||
|
path="json",
|
||||||
|
data_files=[output_jsonl_path],
|
||||||
|
cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
|
||||||
|
keep_in_memory=False,
|
||||||
|
num_proc=cpu_count(),
|
||||||
|
split="train",
|
||||||
|
)
|
||||||
|
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,46 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# NCCL IB environment variables
|
||||||
|
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
|
||||||
|
export NCCL_IB_DISABLE=0
|
||||||
|
export NCCL_SOCKET_IFNAME=eth0
|
||||||
|
export NCCL_IB_GID_INDEX=3
|
||||||
|
export NCCL_IB_TIMEOUT=23
|
||||||
|
export NCCL_IB_RETRY_CNT=7
|
||||||
|
export OMP_NUM_THREADS=8
|
||||||
|
|
||||||
|
PROJECT_NAME=""
|
||||||
|
PARENT_SAVE_DIR=""
|
||||||
|
PARENT_TENSORBOARD_DIR=""
|
||||||
|
PARENT_CONFIG_FILE=""
|
||||||
|
PRETRAINED_MODEL_PATH=""
|
||||||
|
|
||||||
|
declare -a dataset=(
|
||||||
|
"PATH TO THE DATASET"
|
||||||
|
)
|
||||||
|
|
||||||
|
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||||
|
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||||
|
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||||
|
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
|
||||||
|
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||||
|
|
||||||
|
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
|
||||||
|
--pretrained $PRETRAINED_MODEL_PATH \
|
||||||
|
--dataset ${dataset[@]} \
|
||||||
|
--plugin "zero2" \
|
||||||
|
--save_interval 400 \
|
||||||
|
--save_dir $SAVE_DIR \
|
||||||
|
--tensorboard_dir $TENSORBOARD_DIR \
|
||||||
|
--config_file $CONFIG_FILE \
|
||||||
|
--num_epochs 1 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--micro_batch_size 8 \
|
||||||
|
--lr 5e-5 \
|
||||||
|
--mixed_precision "bf16" \
|
||||||
|
--grad_clip 1.0 \
|
||||||
|
--weight_decay 0.01 \
|
||||||
|
--warmup_steps 100 \
|
||||||
|
--use_grad_checkpoint \
|
||||||
|
--use_flash_attn \
|
||||||
|
--use_neft
|
|
@ -0,0 +1,403 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import resource
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossal_llama2.dataset.loader import (
|
||||||
|
DataCollatorForSupervisedDataset,
|
||||||
|
StatefulDistributedSampler,
|
||||||
|
load_tokenized_dataset,
|
||||||
|
setup_distributed_dataloader,
|
||||||
|
)
|
||||||
|
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||||
|
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||||
|
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||||
|
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_numel(model: torch.nn.Module) -> int:
|
||||||
|
return sum(p.numel() for p in model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
def format_numel_str(numel: int) -> str:
|
||||||
|
B = 1024**3
|
||||||
|
M = 1024**2
|
||||||
|
K = 1024
|
||||||
|
if numel >= B:
|
||||||
|
return f"{numel / B:.2f} B"
|
||||||
|
elif numel >= M:
|
||||||
|
return f"{numel / M:.2f} M"
|
||||||
|
elif numel >= K:
|
||||||
|
return f"{numel / K:.2f} K"
|
||||||
|
else:
|
||||||
|
return f"{numel}"
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||||
|
tensor.div_(dist.get_world_size())
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
# ==============================
|
||||||
|
# Parse Arguments
|
||||||
|
# ==============================
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Address of the pre-trained modeling",
|
||||||
|
)
|
||||||
|
parser.add_argument("--dataset", nargs="+", default=[])
|
||||||
|
parser.add_argument(
|
||||||
|
"--plugin",
|
||||||
|
type=str,
|
||||||
|
default="gemini",
|
||||||
|
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||||
|
help="Choose which plugin to use",
|
||||||
|
)
|
||||||
|
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||||
|
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||||
|
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
|
||||||
|
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||||
|
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||||
|
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mixed_precision",
|
||||||
|
type=str,
|
||||||
|
default="fp16",
|
||||||
|
choices=["fp16", "bf16"],
|
||||||
|
help="Mixed precision",
|
||||||
|
)
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||||
|
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_grad_checkpoint",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use gradient checkpointing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_flash_attn",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use flash-attention",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_neft",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use NEFTune",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze_non_embeds_params",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Freeze non embeddings parameters",
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp", type=int, default=1)
|
||||||
|
parser.add_argument("--zero", type=int, default=1)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.config_file, "w") as f:
|
||||||
|
json.dump(args.__dict__, f, indent=4)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Distributed Training
|
||||||
|
# ==============================
|
||||||
|
colossalai.launch_from_torch({})
|
||||||
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Tensorboard
|
||||||
|
# ==============================
|
||||||
|
if coordinator.is_master():
|
||||||
|
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||||
|
writer = SummaryWriter(args.tensorboard_dir)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Booster
|
||||||
|
# ==============================
|
||||||
|
if args.plugin == "gemini":
|
||||||
|
plugin = GeminiPlugin(
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
initial_scale=2**16,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
)
|
||||||
|
elif args.plugin == "gemini_auto":
|
||||||
|
plugin = GeminiPlugin(
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
placement_policy="auto",
|
||||||
|
initial_scale=2**16,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
)
|
||||||
|
elif args.plugin == "zero2":
|
||||||
|
plugin = LowLevelZeroPlugin(
|
||||||
|
stage=2,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
initial_scale=2**16,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
)
|
||||||
|
elif args.plugin == "zero2_cpu":
|
||||||
|
plugin = LowLevelZeroPlugin(
|
||||||
|
stage=2,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
initial_scale=2**16,
|
||||||
|
cpu_offload=True,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
)
|
||||||
|
elif args.plugin == "3d":
|
||||||
|
plugin = HybridParallelPlugin(
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=1,
|
||||||
|
zero_stage=args.zero,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
# ======================================================
|
||||||
|
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||||
|
# ======================================================
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
tokenizer.add_bos_token = False
|
||||||
|
tokenizer.add_eos_token = False
|
||||||
|
|
||||||
|
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||||
|
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
||||||
|
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
||||||
|
|
||||||
|
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||||
|
|
||||||
|
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||||
|
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||||
|
dataloader = setup_distributed_dataloader(
|
||||||
|
dataset=dataset,
|
||||||
|
batch_size=args.micro_batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True,
|
||||||
|
collate_fn=data_collator,
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ======================================================
|
||||||
|
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||||
|
# ======================================================
|
||||||
|
init_ctx = (
|
||||||
|
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||||
|
)
|
||||||
|
with init_ctx:
|
||||||
|
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||||
|
# Freeze part of parameters.
|
||||||
|
if args.freeze_non_embeds_params:
|
||||||
|
freeze_non_embeds_parameters(model=model)
|
||||||
|
|
||||||
|
if args.use_grad_checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
|
if args.use_flash_attn:
|
||||||
|
replace_with_flash_attention(model=model)
|
||||||
|
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||||
|
|
||||||
|
model_numel = get_model_numel(model)
|
||||||
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||||
|
|
||||||
|
optimizer = HybridAdam(
|
||||||
|
model_params=filter(lambda p: p.requires_grad, model.parameters())
|
||||||
|
if args.freeze_non_embeds_params
|
||||||
|
else model.parameters(),
|
||||||
|
lr=args.lr,
|
||||||
|
betas=(0.9, 0.95),
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
adamw_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.warmup_steps is None:
|
||||||
|
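# Default warmup: roughly 2.5% of the total number of optimizer update steps.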
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||||
|
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||||
|
|
||||||
|
lr_scheduler = CosineAnnealingWarmupLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||||
|
warmup_steps=args.warmup_steps,
|
||||||
|
eta_min=0.1 * args.lr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention will be disabled because it does NOT support fp32.
|
||||||
|
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||||
|
torch.set_default_dtype(default_dtype)
|
||||||
|
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
dataloader=dataloader,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.float)
|
||||||
|
|
||||||
|
if args.load_checkpoint is None:
|
||||||
|
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
||||||
|
booster.load_model(model, args.pretrained, strict=False)
|
||||||
|
|
||||||
|
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
start_step = 0
|
||||||
|
sampler_start_idx = 0
|
||||||
|
if args.load_checkpoint is not None:
|
||||||
|
if "modeling" in args.load_checkpoint:
|
||||||
|
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
|
||||||
|
booster.load_model(model, args.load_checkpoint)
|
||||||
|
else:
|
||||||
|
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
|
||||||
|
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||||
|
load_dir=args.load_checkpoint,
|
||||||
|
booster=booster,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||||
|
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Activate NEFTune.")
|
||||||
|
model, handle = activate_neftune(model)
|
||||||
|
|
||||||
|
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||||
|
# If resume training, set the sampler start index to the correct value
|
||||||
|
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
||||||
|
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, args.num_epochs):
|
||||||
|
dataloader.sampler.set_epoch(epoch=epoch)
|
||||||
|
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
|
||||||
|
total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
|
||||||
|
for step, batch in enumerate(dataloader):
|
||||||
|
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||||
|
|
||||||
|
batch_output = model(**batch)
|
||||||
|
|
||||||
|
loss = batch_output.loss / args.accumulation_steps
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
booster.backward(loss=loss, optimizer=optimizer)
|
||||||
|
|
||||||
|
if (step + 1) % args.accumulation_steps == 0:
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
all_reduce_mean(tensor=total_loss)
|
||||||
|
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||||
|
if coordinator.is_master():
|
||||||
|
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||||
|
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||||
|
writer.add_scalar(
|
||||||
|
tag="Learning Rate",
|
||||||
|
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||||
|
global_step=global_step,
|
||||||
|
)
|
||||||
|
total_loss.fill_(0.0)
|
||||||
|
pbar.update()
|
||||||
|
# Save the model checkpoint with running states.
|
||||||
|
|
||||||
|
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||||
|
step + 1
|
||||||
|
) == len(dataloader):
|
||||||
|
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||||
|
deactivate_neftune(model, handle)
|
||||||
|
|
||||||
|
save_checkpoint(
|
||||||
|
save_dir=args.save_dir,
|
||||||
|
booster=booster,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
epoch=epoch,
|
||||||
|
step=step + 1,
|
||||||
|
batch_size=args.micro_batch_size,
|
||||||
|
coordinator=coordinator,
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Activate NEFTune.")
|
||||||
|
model, handle = activate_neftune(model)
|
||||||
|
|
||||||
|
# Delete CUDA cache.
|
||||||
|
# del batch, batch_labels, batch_output, loss
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Subsequent epochs are not resumed mid-epoch, so reset the sampler start index and the start step.
|
||||||
|
dataloader.sampler.set_start_index(start_index=0)
|
||||||
|
start_step = 0
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Deactivate NEFTune.")
|
||||||
|
deactivate_neftune(model, handle)
|
||||||
|
|
||||||
|
# Final save.
|
||||||
|
coordinator.print_on_master("Start saving final model checkpoint")
|
||||||
|
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||||
|
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||||
|
|
||||||
|
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|