mirror of https://github.com/hpcaitech/ColossalAI
support session-based training (#4313)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>

parent ef4b99ebcd
commit 5187c96b7c
@@ -0,0 +1,87 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    sep: str = "</s>"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=[[x, y] for x, y in self.messages],
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep
        }


conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
    sep="</s>",
)

default_conversation = conv
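For orientation, here is a minimal sketch of how this template is used downstream (it assumes the file lands at `coati/dataset/conversation.py`, as the relative import in a later hunk suggests; the sample message is invented):

```python
from coati.dataset.conversation import default_conversation  # assumed path

# Copy the shared template so the module-level default stays empty.
conv = default_conversation.copy()
conv.append_message(conv.roles[0], "Give three tips for staying healthy.")
conv.append_message(conv.roles[1], None)  # empty slot: prompt ends with "Assistant: "

# Each completed turn is closed by the EOS separator "</s>"; the trailing
# "Assistant: " cues the model to generate the next reply.
print(conv.get_prompt())
# <system text>Human: Give three tips for staying healthy.</s>Assistant:
```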
@@ -15,7 +15,7 @@
 import copy
 import random
 from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Callable, Dict, List, Sequence, Tuple
 
 import torch
 import torch.distributed as dist
@@ -25,11 +25,21 @@ from tqdm import tqdm
 
 from colossalai.logging import get_dist_logger
 
+from .conversation import default_conversation
 from .utils import is_rank_0, jload
 
+# The following is a template prompt for a 4-round conversation.
+"""
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
+"""
+# Please note that the loss is calculated only on the assistant's answer tokens.
+
 logger = get_dist_logger()
 
 IGNORE_INDEX = -100
+DEFAULT_EOS_TOKEN = "</s>"
 PROMPT_DICT = {
     "prompt_input":
         ("Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -107,6 +117,61 @@ def preprocess(
     return dict(input_ids=input_ids, labels=labels)
 
 
+def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
+                            max_length: int) -> Dict:
+    """Preprocess the conversation data by tokenizing."""
+    conversations = []
+    intermediates = []
+    for source in sources:
+        header = f"{default_conversation.system}"
+        conversation, intermediate = _add_speaker_and_signal(header, source)
+        conversations.append(conversation)
+        intermediates.append(intermediate)
+
+    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
+    input_ids = conversations_tokenized["input_ids"]
+    targets = copy.deepcopy(input_ids)
+
+    assert len(targets) == len(intermediates)
+    for target, inters in zip(targets, intermediates):
+        mask = torch.zeros_like(target, dtype=torch.bool)
+        for inter in inters:
+            tokenized = _tokenize_fn(inter, tokenizer, max_length)
+
+            start_idx = tokenized["input_ids"][0].size(0) - 1
+            end_idx = tokenized["input_ids"][1].size(0)
+
+            mask[start_idx:end_idx] = True
+        target[~mask] = IGNORE_INDEX
+
+    return dict(input_ids=input_ids, labels=targets)
+
+
+def _add_speaker_and_signal(header: str,
+                            source: List[Dict],
+                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
+    END_SIGNAL = DEFAULT_EOS_TOKEN
+    conversation = header
+    intermediate = []
+    for sentence in source:
+        from_str = sentence["from"]
+        if from_str.lower() == "human":
+            from_str = default_conversation.roles[0]
+        elif from_str.lower() == "gpt":
+            from_str = default_conversation.roles[1]
+        else:
+            from_str = 'unknown'
+
+        value = from_str + ": " + sentence["value"] + END_SIGNAL
+        if sentence["from"].lower() == "gpt":
+            start = conversation + from_str + ": "
+            end = conversation + value
+            intermediate.append([start, end])
+        if get_conversation:
+            conversation += value
+    return conversation, intermediate
+
+
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
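To make the masking logic concrete, a short walk-through of `_add_speaker_and_signal` (a sketch that assumes the function above is in scope; the turns are invented):

```python
header = "A chat between a curious human and an artificial intelligence assistant.\n\n"
source = [
    {"from": "human", "value": "Give three tips for staying healthy."},
    {"from": "gpt", "value": "Eat well, exercise, and sleep enough."},
]

conversation, intermediate = _add_speaker_and_signal(header, source)
# conversation == header
#     + "Human: Give three tips for staying healthy.</s>"
#     + "Assistant: Eat well, exercise, and sleep enough.</s>"
#
# intermediate holds one [start, end] pair per assistant turn:
#     start = prefix up to and including "Assistant: "
#     end   = start + the answer + "</s>"
# preprocess_conversation tokenizes each pair to get the answer's token span,
# then sets every label outside those spans to IGNORE_INDEX (-100), so the
# loss is computed only on the assistant's answer tokens.
```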
@@ -125,15 +190,27 @@ class SupervisedDataset(Dataset):
             list_data_dict = list_data_dict[:max_datasets_size]
 
         logger.info("Formatting inputs...")
-        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
-        sources = [
-            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
-            for example in list_data_dict
-        ]
-        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+        if "conversations" not in list_data_dict[0]:
+            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+            sources = [
+                prompt_input.format_map(example)
+                if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
+            ]
+            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
 
-        logger.info("Tokenizing inputs... This may take some time...")
-        data_dict = preprocess(sources, targets, tokenizer, max_length)
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+
+            data_dict = preprocess(sources, targets, tokenizer, max_length)
+        else:
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+
+            sources = [conv["conversations"] for conv in list_data_dict]
+            data_dict = preprocess_conversation(sources, tokenizer, max_length)
+
+        if is_rank_0():
+            logger.info("Tokenizing finish.")
 
         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]
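A hedged usage sketch of the updated dataset class (the constructor argument names are inferred from this hunk and nearby code, so treat them as assumptions; paths are placeholders):

```python
from transformers import AutoTokenizer

from coati.dataset import SupervisedDataset  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama", use_fast=False)
tokenizer.eos_token = '</s>'

# Samples carrying a "conversations" key take the new preprocess_conversation
# branch; plain Alpaca-style samples still go through PROMPT_DICT + preprocess.
dataset = SupervisedDataset(data_path="dataset.json",
                            tokenizer=tokenizer,
                            max_length=512)
```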
@@ -6,6 +6,7 @@
 - [Table of Contents](#table-of-contents)
 - [Install requirements](#install-requirements)
 - [Supervised datasets collection](#supervised-datasets-collection)
+  - [Conversation dataset generation](#conversation-dataset-generation)
 - [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
   - [Arg List](#arg-list)
 - [Stage2 - Training reward model](#stage2---training-reward-model)
@@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>

### Conversation dataset generation

To further improve the model's ability to handle multi-turn conversations, the dataset needs to include multi-turn samples. However, the InstructWild and Alpaca datasets currently contain only single-turn conversations, and their layout is not suited to storing multi-turn exchanges. Therefore, besides converting those two datasets, we also include genuinely multi-turn datasets such as ShareGPT, and we transform all of them into the training format supported by ColossalChat.

A sample of the conversation dataset should have the following fields:

* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (list, compulsory): Conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.

A simple example:

```json
{
    "type": "instruction",
    "language": "English",
    "dataset": "Alpaca",
    "conversations": [
        {
            "from": "human",
            "value": "Give three tips for staying healthy."
        },
        {
            "from": "gpt",
            "value": "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
        }
    ],
    "id": 1
}
```

> **NOTE:** Only the `conversations` key is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.

You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat, for example:

```
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```
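After generation, a quick sanity check of the output file can look like this (a sketch; `dataset.json` stands for whatever `--save_path` you chose):

```python
import json

with open("dataset.json") as f:
    samples = json.load(f)

# Only `conversations` is compulsory, so check it is present and non-empty.
assert all(s.get("conversations") for s in samples)
roles = {turn["from"] for s in samples for turn in s["conversations"]}
print(len(samples), "samples; roles:", roles)
```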
## Stage1 - Supervised instructs tuning

Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
@@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    # Any dataset with the same fields as Alpaca ("instruction", "input", "output")
    # can be converted into one-round conversations.
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]

    assert len(instructions) == len(inputs) == len(outputs)

    for idx in range(len(instructions)):
        human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
        human = {"from": "human", "value": human_utterance}

        gpt_utterance = outputs[idx]
        gpt = {"from": "gpt", "value": gpt_utterance}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset


def generate_sharegpt():
    # ShareGPT data requires less processing.
    conversation_dataset = []
    dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
                           data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
                           split="train")

    conversations = dataset["conversations"]

    for idx in range(len(conversations)):
        for conv in conversations[idx]:
            # We don't need the markdown and text values.
            del conv["markdown"]
            del conv["text"]

        conversation = dict(type="conversation",
                            language="Multilingual",
                            dataset="ShareGPT",
                            conversations=conversations[idx])
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default="All",
                        choices=["Alpaca", "ShareGPT", "All"],
                        help="which dataset to convert; 'All' combines Alpaca and ShareGPT")
    parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []

    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    for idx, sample in enumerate(conversation_dataset):
        sample["id"] = idx + 1

    with open(args.save_path, mode='w') as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
@@ -74,8 +74,8 @@ def train(args):
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = '<\s>'
-        tokenizer.pad_token = tokenizer.unk_token
+        tokenizer.eos_token = '</s>'
+        tokenizer.pad_token = tokenizer.eos_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
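The first change fixes a real bug: `'<\s>'` is a literal backslash-s string that matches no vocabulary token, while `'</s>'` is Llama's actual EOS token; padding with EOS rather than UNK also matches the `</s>`-separated conversation format above. A minimal sketch of the corrected setup (the checkpoint path is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama",
                                          padding_side="right",
                                          use_fast=False)
tokenizer.eos_token = '</s>'               # was '<\s>', which no tokenizer knows
tokenizer.pad_token = tokenizer.eos_token  # pad with EOS instead of UNK
```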
@@ -153,9 +153,7 @@ def train(args):
         optim,
         num_warmup_steps=math.ceil(max_steps * 0.03),
         num_training_steps=max_steps)
-    strategy_dict = strategy.prepare(
-        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-    )
+    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
     model = strategy_dict['model']
     optim = strategy_dict['optimizer']
     lr_scheduler = strategy_dict['lr_scheduler']