[Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Yuanchen 2023-12-07 14:02:03 +08:00 committed by GitHub
parent 3dbbf83f1c
commit b397104438
9 changed files with 1036 additions and 19 deletions

View File

@ -11,7 +11,10 @@
- [Performance Evaluation](#performance-evaluation)
- [Examples](#examples)
- [Training Logs](#training-logs)
- [Import from Transformers (Inference)](#import-from-transformers-inference)
- [Inference](#inference)
- [Import from HuggingFace](#import-from-huggingface)
- [Import from Modelscope](#import-from-modelscope)
- [Quick Start](#quick-start)
- [Usage](#usage)
- [Install](#install)
- [0. Pre-requisite](#0-pre-requisite)
@ -21,8 +24,14 @@
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
- [2. Init Model Preparation](#2-init-model-preparation)
- [3. Data Preparation](#3-data-preparation)
- [3.1 Data for Pretraining](#31-data-for-pretraining)
- [3.2 Data for Supervised Fine-tuning](#32-data-for-supervised-fine-tuning)
- [4. Command Line Arguments for Training](#4-command-line-arguments-for-training)
- [4.1 Arguments for Pretraining](#41-arguments-for-pretraining)
- [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning)
- [5. Running Command](#5-running-command)
- [5.1 Command for Pretraining](#51-command-for-pretraining)
- [5.2 Command for Supervised Fine-tuning](#52-command-for-supervised-fine-tuning)
- [Technical Insights](#technical-insights)
- [Data](#data)
- [Tokenizer](#tokenizer)
@ -117,7 +126,8 @@ We also recorded the training logs for the experiment
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossByTokens.jpeg?raw=true" width=600/>
</p>
### Import from Transformers (Inference)
### Inference
#### Import from HuggingFace
To load the Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
```Python
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -135,6 +145,7 @@ pred = model.generate(**inputs,
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
```
#### Import from Modelscope
You can also load our model using ModelScope with the following code:
```Python
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
@ -153,6 +164,30 @@ print(tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input):])
```
You can download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) or [👾Modelscope](https://modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary).
#### Quick Start
You can run [`inference_example.py`](inference_example.py) to quickly run inference with our base model by loading the model weights from HuggingFace.
Command to run the script:
```bash
python inference_example.py \
--model_path "<HF_REPO_NAME_OR_LOCAL_PATH_TO_MODEL>" \
--device "cuda:0" \
--max_new_tokens 512 \
--do_sample True \
--temperature 0.3 \
--top_k 50 \
--top_p 0.95 \
--input_txt "YOUR_PROMPT_OR_QUESTION"
```
Here are the details of the CLI arguments:
* Model path: `--model_path`. HF repo name or local path of the model.
* Device: `--device`. Set the device to run inference on.
* Max new tokens: `--max_new_tokens`. Set the maximum number of tokens to generate, ignoring the number of tokens in the prompt.
* Do sample: `--do_sample`. Set whether or not to use sampling.
* Temperature: `--temperature`. Set the temperature value for sampling.
* Top_k: `--top_k`. Set the top_k value for top-k filtering.
* Top_p: `--top_p`. Set the top_p value for generation.
* Input_txt: `--input_txt`. The prompt string fed to the model.
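For example, the following command runs inference with the released base model using the script's default sampling settings and default prompt:
```bash
python inference_example.py \
--model_path "hpcai-tech/Colossal-LLaMA-2-7b-base" \
--device "cuda:0" \
--max_new_tokens 512 \
--do_sample True \
--temperature 0.3 \
--top_k 50 \
--top_p 0.95 \
--input_txt "明月松间照,"
```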
## Usage
### Install
@ -218,6 +253,8 @@ Here is details about CLI arguments:
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
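For example, assuming `<NEW_TOKENIZER_DIR>` and `<NEW_MODEL_DIR>` are placeholders for your expanded tokenizer directory and the initialized model folder:
```bash
cp <NEW_TOKENIZER_DIR>/special_tokens_map.json <NEW_TOKENIZER_DIR>/tokenizer.model <NEW_TOKENIZER_DIR>/tokenizer_config.json <NEW_MODEL_DIR>/
```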
#### 3. Data Preparation
##### 3.1 Data for Pretraining
Raw data should be in `jsonl` format. Each data point should have the following fields:
* `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty.
* `target` (str, compulsory): Loss will be calculated.
@ -250,7 +287,31 @@ Here is details about CLI arguments:
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
##### 3.2 Data for Supervised Fine-tuning
We prepare data for supervised fine-tuning in a similar way. The main difference lies in the data format. Each data point should have the following field:
* `messages` (list, compulsory): This field contains a conversation between a human and an assistant. The length of `messages` can vary, and only the content from the `assistant` turns is used for calculating loss.
Examples:
```JSON
{"messages": [{"from": "human", "content": "What are the three primary colors?"}, {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."}]}
{"messages": [{"from": "human", "content": "解释个人电脑和服务器之间的区别。"}, {"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}]}
```
The command to convert the jsonl dataset to arrow format is similar to the one in [3.1 Data for Pretraining](#31-data-for-pretraining). In `prepare_sft_dataset.py`, we don't concatenate different data samples.
```bash
python prepare_sft_dataset.py \
--data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
--tokenizer_dir "<TOKENIZER_DIR>" \
--data_cache_dir "jsonl_to_arrow_cache" \
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
--max_length 4096 \
--num_spliced_dataset_bins 10
```
#### 4. Command Line Arguments for Training
##### 4.1 Arguments for Pretraining
You can use `colossalai run` to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
@ -288,7 +349,16 @@ Here is details about CLI arguments:
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
##### 4.2 Arguments for Supervised Fine-tuning
We add support for gradient accumulation and NEFTune for supervised fine-tuning, so there are two additional arguments on top of those listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
Here are the details of the additional CLI arguments:
* Accumulation steps: `--accumulation_steps`. Number of gradient accumulation steps. The default value is `8`.
* NEFTune: `--use_neft`. The default value is `False`. Enabling NEFTune can help improve the performance of chat models; a brief sketch of how it works is shown below.
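NEFTune ([Jain et al., 2023](https://arxiv.org/abs/2310.05914)) injects uniform random noise into the embedding outputs during training only. The patch used by this example (`neftune_patch.py`) boils down to the following simplified sketch; the function name and shapes here are illustrative, not the actual API:
```Python
import math

import torch


def add_neftune_noise(embeddings: torch.Tensor, noise_alpha: float = 0.1) -> torch.Tensor:
    # embeddings: (batch, seq_len, hidden_dim) output of the embedding layer, training mode only
    seq_len, hidden_dim = embeddings.size(1), embeddings.size(2)
    # Noise magnitude scales as alpha / sqrt(seq_len * hidden_dim), as in the NEFTune paper
    mag_norm = noise_alpha / math.sqrt(seq_len * hidden_dim)
    return embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
```
When `--use_neft` is passed, `train_sft.py` registers this noise as a forward hook on the model's input embedding layer via `activate_neftune`, and the hook is removed before every checkpoint save and at the end of training.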
#### 5. Running Command
##### 5.1 Command for Pretraining
An [example bash script](train.example.sh) is also provided for the experiment. Here are the steps to run the experiment:
* Create your own hostfile: `cp hostfile.example hostfile`.
* Create your own bash: `cp train.example.sh train.sh`.
@ -310,6 +380,10 @@ declare -a dataset=(
"<DIR_2>/part-00000"
)
```
##### 5.2 Command for Supervised Fine-tuning
An [example bash script](train_sft.example.sh) is provided. The only difference from the pretraining command is the two additional arguments (`--accumulation_steps` and `--use_neft`) in the script, as shown below. You can refer to [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning) for more details.
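For reference, here is an abridged version of the launch command from [`train_sft.example.sh`](train_sft.example.sh); paths and hyperparameters are placeholders you should adapt to your setup:
```bash
colossalai run --nproc_per_node 8 --hostfile hostfile train_sft.py \
--pretrained $PRETRAINED_MODEL_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_dir $SAVE_DIR \
--num_epochs 1 \
--accumulation_steps 8 \
--micro_batch_size 8 \
--lr 5e-5 \
--mixed_precision "bf16" \
--use_grad_checkpoint \
--use_flash_attn \
--use_neft
```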
## Technical Insights
In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continuing the pre-training of the LLaMA-2 model on both Chinese and English corpora. The overall pipeline can be described as follows:
@ -416,3 +490,11 @@ Applying the above process to perform knowledge transfer in any field allows for
year={2023}
}
```
```bibtex
@article{jain2023neftune,
title={NEFTune: Noisy Embeddings Improve Instruction Finetuning},
author={Jain, Neel and Chiang, Ping-yeh and Wen, Yuxin and Kirchenbauer, John and Chu, Hong-Min and Somepalli, Gowthami and Bartoldson, Brian R and Kailkhura, Bhavya and Schwarzschild, Avi and Saha, Aniruddha and others},
journal={arXiv preprint arXiv:2310.05914},
year={2023}
}
```

View File

@ -0,0 +1,96 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from enum import Enum, auto
from typing import List
class SeparatorStyle(Enum):
ADD_BOS_EOS_TOKEN = auto()
@dataclasses.dataclass
class Conversation:
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle
seps: List[str]
def clear(self):
self.messages = []
def get_prompt(self, length: int = None):
if length is None:
length = len(self.messages)
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
ret = self.system
for role, message in self.messages[0:length]:
if message:
ret += role + ": " + self.seps[0] + message + self.seps[1]
else:
ret += role + ": " + self.seps[0]
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def save_prompt(self):
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
ret = self.system
for role, message in self.messages:
if message:
ret += role + ": " + self.seps[0] + message + self.seps[1] + "\n"
else:
ret += role + ": " + self.seps[0]
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
seps=self.seps,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"seps": self.seps,
}
conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<s>", "</s>"],
)
default_conversation = conv
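# Illustrative note: with the default template above, `get_prompt()` renders each turn as
#   "Human: <s>What are the three primary colors?</s>Assistant: <s>The three primary colors are ...</s>"
# appended to the system prompt. The <s>/</s> separators are what `supervised_tokenize_sft`
# later relies on to locate assistant spans.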

View File

@ -4,22 +4,29 @@
Splicing multiple pre-tokenized sequence data points
"""
import bisect
import random
import warnings
from copy import deepcopy
from datasets import dataset_dict
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .conversation import Conversation, default_conversation
logger = get_dist_logger()
IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize(
def supervised_tokenize_pretrain(
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
) -> Dict[str, Union[int, str, List[int]]]:
"""
@ -62,6 +69,121 @@ def supervised_tokenize(
)
def supervised_tokenize_sft(
data_point: Dict[str, str],
tokenizer: LlamaTokenizer,
conversation_template: Conversation = default_conversation,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original supervised data point as following:
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
"""
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
"add <bos> and <eos> manually later"
)
assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = template.roles[0]
elif from_str.lower() == "assistant":
from_str = template.roles[1]
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
if len(template.messages) % 2 != 0:
template.messages = template.messages[0:-1]
# `target_turn_index` is the number of leading turns whose rendered prompt still fits within `max_length - 1` tokens.
turns = [i for i in range(1, len(messages) // 2 + 1)]
target_turn_index = bisect.bisect_right(
turns,
max_length - 1,
key=lambda x: len(tokenizer([template.get_prompt(2 * x)], add_special_tokens=False)["input_ids"][0]),
)
# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
return dict(
input_ids=None,
labels=None,
inputs_decode=None,
labels_decode=None,
seq_length=None,
seq_category=None,
)
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
template.messages = template.messages[0 : 2 * target_turn]
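# Locate assistant responses by scanning the tokenized prompt for <bos>/<eos> pairs.
# Roles alternate (human, assistant, ...), so every second <bos>/<eos> pair wraps an
# assistant turn; only tokens inside those spans are kept as training labels below.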
starts = []
ends = []
gpt_bos = False if template.messages[0][0] == template.roles[0] else True
gpt_eos = False if template.messages[0][0] == template.roles[0] else True
for i, token_id in enumerate(tokenized):
if token_id == tokenizer.bos_token_id:
if gpt_bos:
starts.append(i)
gpt_bos = not gpt_bos
elif token_id == tokenizer.eos_token_id:
if gpt_eos:
ends.append(i)
gpt_eos = not gpt_eos
if len(starts) != target_turn or len(ends) != target_turn:
logger.info(
"Please check whether the tokenizer add additional `bos_token` and `eos_token`.\n\nOr the original message contains `bos_token` or `eos_token`."
)
return dict(
input_ids=None,
labels=None,
inputs_decode=None,
labels_decode=None,
seq_length=None,
seq_category=None,
)
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
labels[start + 1 : end + 2] = tokenized[start + 1 : end + 2]
labels_decode = deepcopy(labels)
for i, z in enumerate(labels_decode):
if z == ignore_index:
labels_decode[i] = tokenizer.unk_token_id
# `inputs_decode` and `labels_decode` can be used to check whether the tokenization is correct.
return dict(
input_ids=tokenized,
labels=labels,
inputs_decode=tokenizer.decode(tokenized),
labels_decode=tokenizer.decode(labels_decode),
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
class ClosedToConstantLengthSplicedDataset(IterableDataset):
"""
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
@ -169,12 +291,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
spliced_labels.extend(seq_labels)
# For residual spliced data point at the end of the data set
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
examples.append(
{
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels
}
)
examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
if self.shuffle:
random.shuffle(examples)
for spliced_data_point in examples:

View File

@ -0,0 +1,69 @@
# Copyright 2023 The Hugging Face team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def unwrap(model):
return model.unwrap().module
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
layers. This method is slightly adapted from the original source code that can be found here:
https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
def activate_neftune(model, neftune_noise_alpha=0.1):
r"""
Activates NEFTune as presented in the reference code: https://github.com/neelsjain/NEFTune and the paper:
https://arxiv.org/abs/2310.05914
"""
embeddings = unwrap(model).get_input_embeddings()
embeddings.neftune_noise_alpha = neftune_noise_alpha
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
neftune_hook_handle = hook_handle
return model, neftune_hook_handle
def deactivate_neftune(model, neftune_hook_handle):
"""
Deactivates the NEFTune method. Make sure to call `activate_neftune` first.
"""
embeddings = unwrap(model).get_input_embeddings()
neftune_hook_handle.remove()
del embeddings.neftune_noise_alpha
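# Typical usage (see train_sft.py in this example):
#   model, handle = activate_neftune(model)  # after booster.boost(...), when --use_neft is set
#   ...training loop...
#   deactivate_neftune(model, handle)  # before saving checkpoints and at the end of training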

View File

@ -0,0 +1,57 @@
import argparse
import os
import torch
from colossalai.logging import get_dist_logger
from transformers import AutoTokenizer, AutoModelForCausalLM
logger = get_dist_logger()
def load_model(model_path, device="cuda", **kwargs):
logger.info(
"Please check whether the tokenizer and model weights are properly stored in the same folder."
)
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
model.to(device)
try:
tokenizer = AutoTokenizer.from_pretrained(model_path)
except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
return model, tokenizer
@torch.inference_mode()
def generate(args):
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
output = model.generate(**inputs,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
num_return_sequences=1)
response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
parser.add_argument('--max_new_tokens', type=int, default=512, help="Set the maximum number of tokens to generate, ignoring the number of tokens in the prompt")
parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
parser.add_argument('--top_p', type=float, default=0.95, help="Set top_p value for generation")
parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
args = parser.parse_args()
generate(args)

View File

@ -11,14 +11,14 @@ import os
import time
from multiprocessing import cpu_count
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
ClosedToConstantLengthSplicedDataset,
supervised_tokenize_pretrain,
)
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
supervised_tokenize,
ClosedToConstantLengthSplicedDataset,
)
logger = get_dist_logger()
@ -104,7 +104,7 @@ def main():
assert isinstance(dataset, dataset_dict.Dataset)
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
dataset = dataset.map(
function=supervised_tokenize,
function=supervised_tokenize_pretrain,
fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
@ -149,5 +149,5 @@ def main():
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -0,0 +1,147 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare sft dataset for fine-tuning
"""
import argparse
import json
import math
import os
from multiprocessing import cpu_count
from colossal_llama2.dataset.conversation import default_conversation
from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_input_dirs",
type=str,
required=True,
default=None,
help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
)
parser.add_argument(
"--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(
"--data_jsonl_output_dir",
type=str,
default="jsonl_output",
help="Output directory of spliced dataset with jsonl format",
)
parser.add_argument(
"--data_arrow_output_dir",
type=str,
default="arrow_output",
help="Output directory of spliced dataset with arrow format",
)
parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
args = parser.parse_args()
if args.num_spliced_dataset_bins >= 100000:
raise ValueError("Too many spliced divisions, must be smaller than 100000")
assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
assert not os.path.exists(
args.data_jsonl_output_dir
), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
assert not os.path.exists(
args.data_arrow_output_dir
), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
os.makedirs(args.data_jsonl_output_dir)
os.makedirs(args.data_arrow_output_dir)
# Prepare all input datasets
input_data_paths = []
input_data_dirs = args.data_input_dirs.split(",")
for ds_dir in input_data_dirs:
ds_dir = os.path.abspath(ds_dir)
assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}"
ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
input_data_paths.extend(ds_paths)
# Prepare data splitting.
train_splits = []
split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
for i in range(0, 100, split_interval):
start = i
end = i + split_interval
if end > 100:
end = 100
train_splits.append(f"train[{start}%:{end}%]")
# Prepare the tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
list_dataset = load_dataset(
path="json",
data_files=input_data_paths,
cache_dir=os.path.join(args.data_cache_dir, "raw"),
keep_in_memory=False,
split=train_splits,
num_proc=cpu_count(),
)
for index, dataset in enumerate(list_dataset):
assert isinstance(dataset, dataset_dict.Dataset)
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
dataset = dataset.map(
function=supervised_tokenize_sft,
fn_kwargs={
"tokenizer": tokenizer,
"conversation_template": default_conversation,
"max_length": args.max_length,
},
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.filter(lambda data: data["labels"] is not None)
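# Sort by category and sequence length; keeping samples of similar length adjacent helps reduce padding when batching.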
dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
# We don't concatenate data samples here.
spliced_dataset = dataset
# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)
output_name = f"part-{output_index}"
output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
# st = time.time()
with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
spliced_count = 0
for spliced_data_point in spliced_dataset:
if spliced_count % 500 == 0:
logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
spliced_count += 1
fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
# Save each arrow spliced dataset
output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
logger.info(f"Start to save {output_arrow_path}")
spliced_dataset = load_dataset(
path="json",
data_files=[output_jsonl_path],
cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
keep_in_memory=False,
num_proc=cpu_count(),
split="train",
)
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,46 @@
#!/bin/bash
# NCCL IB environment variables
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=23
export NCCL_IB_RETRY_CNT=7
export OMP_NUM_THREADS=8
PROJECT_NAME=""
PARENT_SAVE_DIR=""
PARENT_TENSORBOARD_DIR=""
PARENT_CONFIG_FILE=""
PRETRAINED_MODEL_PATH=""
declare -a dataset=(
"PATH TO THE DATASET"
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
--pretrained $PRETRAINED_MODEL_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 400 \
--save_dir $SAVE_DIR \
--tensorboard_dir $TENSORBOARD_DIR \
--config_file $CONFIG_FILE \
--num_epochs 1 \
--accumulation_steps 8 \
--micro_batch_size 8 \
--lr 5e-5 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--weight_decay 0.01 \
--warmup_steps 100 \
--use_grad_checkpoint \
--use_flash_attn \
--use_neft

View File

@ -0,0 +1,403 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
"""
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
import torch.distributed as dist
from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def get_model_numel(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f"{numel / B:.2f} B"
elif numel >= M:
return f"{numel / M:.2f} M"
elif numel >= K:
return f"{numel / K:.2f} K"
else:
return f"{numel}"
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def main() -> None:
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained",
type=str,
default=None,
help="Address of the pre-trained modeling",
)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
parser.add_argument(
"--mixed_precision",
type=str,
default="fp16",
choices=["fp16", "bf16"],
help="Mixed precision",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument(
"--use_grad_checkpoint",
action="store_true",
default=False,
help="Use gradient checkpointing",
)
parser.add_argument(
"--use_flash_attn",
action="store_true",
default=False,
help="Use flash-attention",
)
parser.add_argument(
"--use_neft",
action="store_true",
default=False,
help="Use NEFTune",
)
parser.add_argument(
"--freeze_non_embeds_params",
action="store_true",
default=False,
help="Freeze non embeddings parameters",
)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--zero", type=int, default=1)
args = parser.parse_args()
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# ==============================
# Initialize Tensorboard
# ==============================
if coordinator.is_master():
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=args.zero,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
coordinator.print_on_master(f"Load dataset: {args.dataset}")
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
dataloader = setup_distributed_dataloader(
dataset=dataset,
batch_size=args.micro_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
)
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx = (
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
)
with init_ctx:
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
if args.use_flash_attn:
replace_with_flash_attention(model=model)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
optimizer = HybridAdam(
model_params=filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
if args.warmup_steps is None:
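# By default, warm up for 2.5% of the total number of optimizer steps.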
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
if args.load_checkpoint is None:
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
booster.load_model(model, args.pretrained, strict=False)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
start_step = 0
sampler_start_idx = 0
if args.load_checkpoint is not None:
if "modeling" in args.load_checkpoint:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
booster.load_model(model, args.load_checkpoint)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.load_checkpoint,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
coordinator.print_on_master(
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
# If resume training, set the sampler start index to the correct value
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
for step, batch in enumerate(dataloader):
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
loss = batch_output.loss / args.accumulation_steps
total_loss += loss.item()
booster.backward(loss=loss, optimizer=optimizer)
if (step + 1) % args.accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
all_reduce_mean(tensor=total_loss)
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
if coordinator.is_master():
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
total_loss.fill_(0.0)
pbar.update()
# Save modeling.
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
step + 1
) == len(dataloader):
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.micro_batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
# Delete CUDA cache.
# del batch, batch_labels, batch_output, loss
torch.cuda.empty_cache()
# Subsequent epochs are not resumed mid-epoch, so reset the sampler start index and start step.
dataloader.sampler.set_start_index(start_index=0)
start_step = 0
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune.")
deactivate_neftune(model, handle)
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
main()