support session-based training (#4313)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Yuanchen 2023-07-28 11:29:55 +08:00 committed by GitHub
parent ef4b99ebcd
commit 5187c96b7c
5 changed files with 299 additions and 14 deletions


@@ -0,0 +1,87 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    sep: str = "</s>"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=[[x, y] for x, y in self.messages],
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep
        }


conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
    sep="</s>",
)

default_conversation = conv
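
For illustration only (not part of the patch), here is a quick usage sketch of the `Conversation` template defined above; the import path and the example messages are assumptions:

```python
# Minimal usage sketch of the Conversation template (illustrative only).
from coati.dataset.conversation import default_conversation  # import path is an assumption

conv = default_conversation.copy()
conv.append_message(conv.roles[0], "Give three tips for staying healthy.")
conv.append_message(conv.roles[1], "Eat well, exercise regularly, and get enough sleep.")

# get_prompt() appends each turn after the system prompt, separated by "</s>":
print(conv.get_prompt())
# A chat between a curious human and an artificial intelligence assistant. The assistant
# gives helpful, detailed, and polite answers to the human's questions.
#
# Human: Give three tips for staying healthy.</s>Assistant: Eat well, exercise regularly, and get enough sleep.</s>
```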


@@ -15,7 +15,7 @@
 import copy
 import random
 from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Callable, Dict, List, Sequence, Tuple

 import torch
 import torch.distributed as dist
@@ -25,11 +25,21 @@ from tqdm import tqdm
 from colossalai.logging import get_dist_logger

+from .conversation import default_conversation
 from .utils import is_rank_0, jload

+# The following is a template prompt for a 4-round conversation.
+"""
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
+"""
+# Please note that we only calculate loss on the assistant's answer tokens.
+
 logger = get_dist_logger()

 IGNORE_INDEX = -100
+DEFAULT_EOS_TOKEN = "</s>"
 PROMPT_DICT = {
     "prompt_input":
         ("Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -107,6 +117,61 @@ def preprocess(
     return dict(input_ids=input_ids, labels=labels)


+def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
+                            max_length: int) -> Dict:
+    """Preprocess the conversation data by tokenizing."""
+    conversations = []
+    intermediates = []
+    for source in sources:
+        header = f"{default_conversation.system}"
+        conversation, intermediate = _add_speaker_and_signal(header, source)
+        conversations.append(conversation)
+        intermediates.append(intermediate)
+
+    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
+    input_ids = conversations_tokenized["input_ids"]
+    targets = copy.deepcopy(input_ids)
+
+    assert len(targets) == len(intermediates)
+    for target, inters in zip(targets, intermediates):
+        mask = torch.zeros_like(target, dtype=torch.bool)
+        for inter in inters:
+            tokenized = _tokenize_fn(inter, tokenizer, max_length)
+
+            start_idx = tokenized["input_ids"][0].size(0) - 1
+            end_idx = tokenized["input_ids"][1].size(0)
+
+            mask[start_idx:end_idx] = True
+        target[~mask] = IGNORE_INDEX
+
+    return dict(input_ids=input_ids, labels=targets)
+
+
+def _add_speaker_and_signal(header: str,
+                            source: List[Dict],
+                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
+    END_SIGNAL = DEFAULT_EOS_TOKEN
+    conversation = header
+    intermediate = []
+    for sentence in source:
+        from_str = sentence["from"]
+        if from_str.lower() == "human":
+            from_str = default_conversation.roles[0]
+        elif from_str.lower() == "gpt":
+            from_str = default_conversation.roles[1]
+        else:
+            from_str = 'unknown'
+
+        value = from_str + ": " + sentence["value"] + END_SIGNAL
+        if sentence["from"].lower() == "gpt":
+            start = conversation + from_str + ": "
+            end = conversation + value
+            intermediate.append([start, end])
+        if get_conversation:
+            conversation += value
+
+    return conversation, intermediate
+
+
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
@@ -125,15 +190,27 @@ class SupervisedDataset(Dataset):
             list_data_dict = list_data_dict[:max_datasets_size]

         logger.info("Formatting inputs...")
-        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
-        sources = [
-            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
-            for example in list_data_dict
-        ]
-        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
-
-        logger.info("Tokenizing inputs... This may take some time...")
-        data_dict = preprocess(sources, targets, tokenizer, max_length)
+        if "conversations" not in list_data_dict[0]:
+            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+            sources = [
+                prompt_input.format_map(example)
+                if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
+            ]
+            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+            data_dict = preprocess(sources, targets, tokenizer, max_length)
+        else:
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+            sources = [conv["conversations"] for conv in list_data_dict]
+            data_dict = preprocess_conversation(sources, tokenizer, max_length)
+
+        if is_rank_0():
+            logger.info("Tokenizing finish.")

         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]
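
To make the label masking in `preprocess_conversation` concrete, here is a small standalone sketch (not part of the patch) of the same idea: tokens that are not part of an assistant reply are set to `IGNORE_INDEX`, so only the assistant's answer tokens contribute to the loss. The token indices here are made up.

```python
# Illustrative sketch of the masking idea used by preprocess_conversation.
import torch

IGNORE_INDEX = -100


def mask_non_assistant(target: torch.Tensor, spans):
    """`spans` holds (start, end) token index pairs covering assistant replies."""
    mask = torch.zeros_like(target, dtype=torch.bool)
    for start, end in spans:
        mask[start:end] = True
    target = target.clone()
    target[~mask] = IGNORE_INDEX  # everything outside assistant replies is ignored by the loss
    return target


# Toy example: 10 "tokens", the assistant reply occupies positions 6..9.
labels = torch.arange(10)
print(mask_non_assistant(labels, [(6, 10)]))
# tensor([-100, -100, -100, -100, -100, -100,    6,    7,    8,    9])
```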


@@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
- [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
@@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>
### Conversation dataset generation

To further improve the model's ability to handle multi-turn conversations, the training data needs to include multi-turn samples. However, the InstructWild and Alpaca datasets contain only single-turn conversations, and their layout is not suited to storing multi-turn dialogue. Therefore, besides converting those datasets, we also include multi-turn conversation datasets such as ShareGPT and transform them into the training format supported by ColossalChat.

A sample in the conversation dataset should have the following fields:

* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (list, compulsory): The conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.

A simple example:
```json
{
    "type": "instruction",
    "language": "English",
    "dataset": "Alpaca",
    "conversations": [
        {
            "from": "human",
            "value": "Give three tips for staying healthy."
        },
        {
            "from": "gpt",
            "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
        }
    ],
    "id": 1
}
```
> **NOTE:** Only the `conversations` key is compulsory for training; the other keys serve as metadata. The number of turns in `conversations` can vary from sample to sample.

You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset in the format supported by ColossalChat, for example:

```
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```
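
If you want to convert your own single-turn data into this format, the following minimal sketch shows one way to do it. The raw field names `instruction` and `response` are placeholders for your own schema, not something the script expects:

```python
import json

# Hypothetical raw records; replace the field names with those of your own data.
raw_records = [
    {"instruction": "Give three tips for staying healthy.",
     "response": "Eat well, exercise regularly, and get enough sleep."},
]

conversation_dataset = []
for idx, record in enumerate(raw_records, start=1):
    conversation_dataset.append({
        "type": "instruction",      # optional metadata
        "language": "English",      # optional metadata
        "dataset": "Custom",        # optional metadata
        "conversations": [          # compulsory: alternating human/gpt turns
            {"from": "human", "value": record["instruction"]},
            {"from": "gpt", "value": record["response"]},
        ],
        "id": idx,                  # optional
    })

with open("custom_conversation_dataset.json", "w", encoding="utf-8") as f:
    json.dump(conversation_dataset, f, indent=4, ensure_ascii=False)
```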
## Stage1 - Supervised instructs tuning

Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.


@@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    # Datasets with the same format as Alpaca ("instruction", "input", "output") can be
    # converted into one-round conversations.
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]

    assert len(instructions) == len(inputs) == len(outputs)

    for idx in range(len(instructions)):
        human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
        human = {"from": "human", "value": human_utterance}

        gpt_utterance = outputs[idx]
        gpt = {"from": "gpt", "value": gpt_utterance}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset


def generate_sharegpt():
    # ShareGPT data requires less processing.
    conversation_dataset = []
    dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
                           data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
                           split="train")

    conversations = dataset["conversations"]

    for idx in range(len(conversations)):
        for conv in conversations[idx]:
            # We don't need the markdown and text fields.
            del conv["markdown"]
            del conv["text"]

        conversation = dict(type="conversation",
                            language="Multilingual",
                            dataset="ShareGPT",
                            conversations=conversations[idx])
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default="All",
                        choices=["Alpaca", "ShareGPT", "All"],
                        help="which dataset to convert, All will combine Alpaca and ShareGPT")
    parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []
    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    for idx, sample in enumerate(conversation_dataset):
        sample["id"] = idx + 1

    with open(args.save_path, mode='w') as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)


@@ -74,8 +74,8 @@ def train(args):
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = '<\s>'
-        tokenizer.pad_token = tokenizer.unk_token
+        tokenizer.eos_token = '</s>'
+        tokenizer.pad_token = tokenizer.eos_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -153,9 +153,7 @@ def train(args):
         optim,
         num_warmup_steps=math.ceil(max_steps * 0.03),
         num_training_steps=max_steps)
-    strategy_dict = strategy.prepare(
-        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-    )
+    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
     model = strategy_dict['model']
     optim = strategy_dict['optimizer']
     lr_scheduler = strategy_dict['lr_scheduler']