support session-based training (#4313)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/2519/head
Yuanchen 2023-07-28 11:29:55 +08:00 committed by GitHub
parent ef4b99ebcd
commit 5187c96b7c
5 changed files with 299 additions and 14 deletions

View File

@ -0,0 +1,87 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    sep: str = "</s>"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=[[x, y] for x, y in self.messages],
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep
        }


conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
    sep="</s>",
)

default_conversation = conv
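
A minimal usage sketch of the template above (assuming the module is importable as `coati.dataset.conversation`; the message content is illustrative):

```python
from coati.dataset.conversation import default_conversation

# copy() first so the shared default template stays untouched
conv = default_conversation.copy()
conv.append_message(conv.roles[0], "Give three tips for staying healthy.")
conv.append_message(conv.roles[1], None)  # empty slot for the assistant's reply

print(conv.get_prompt())
# A chat between a curious human and an artificial intelligence assistant. ...
# Human: Give three tips for staying healthy.</s>Assistant:
```

Every completed message is terminated with the `</s>` separator, which yields exactly the `Human: xxx</s>Assistant: xxx</s>` layout that the SFT preprocessing below expects.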

View File

@ -15,7 +15,7 @@
import copy
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
from typing import Callable, Dict, List, Sequence, Tuple
import torch
import torch.distributed as dist
@ -25,11 +25,21 @@ from tqdm import tqdm
from colossalai.logging import get_dist_logger
from .conversation import default_conversation
from .utils import is_rank_0, jload
# The following is a template prompt for a 4-round conversation.
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.

Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""

# Note that the loss is calculated only on the assistant's answer tokens.

logger = get_dist_logger()

IGNORE_INDEX = -100
DEFAULT_EOS_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input":
        ("Below is an instruction that describes a task, paired with an input that provides further context. "
@ -107,6 +117,61 @@ def preprocess(
    return dict(input_ids=input_ids, labels=labels)


def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
                            max_length: int) -> Dict:
    """Preprocess the conversation data by tokenizing."""
    conversations = []
    intermediates = []
    for source in sources:
        header = f"{default_conversation.system}"
        conversation, intermediate = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
        intermediates.append(intermediate)

    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
    input_ids = conversations_tokenized["input_ids"]
    targets = copy.deepcopy(input_ids)

    assert len(targets) == len(intermediates)
    for target, inters in zip(targets, intermediates):
        mask = torch.zeros_like(target, dtype=torch.bool)
        for inter in inters:
            # inter = [prompt up to "Assistant: ", prompt plus the answer];
            # tokenizing both locates the answer span inside the conversation
            tokenized = _tokenize_fn(inter, tokenizer, max_length)

            # keep labels from the last token of the role prefix through the
            # end of the assistant's answer (including the EOS separator)
            start_idx = tokenized["input_ids"][0].size(0) - 1
            end_idx = tokenized["input_ids"][1].size(0)
            mask[start_idx:end_idx] = True

        # everything outside the assistant's answers is excluded from the loss
        target[~mask] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)

def _add_speaker_and_signal(header: str,
                            source: List[Dict],
                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
    END_SIGNAL = DEFAULT_EOS_TOKEN
    conversation = header
    intermediate = []
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = default_conversation.roles[1]
        else:
            from_str = 'unknown'

        value = from_str + ": " + sentence["value"] + END_SIGNAL
        if sentence["from"].lower() == "gpt":
            start = conversation + from_str + ": "
            end = conversation + value
            intermediate.append([start, end])

        if get_conversation:
            conversation += value

    return conversation, intermediate
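
# Worked example for the two helpers above: given the one-round source
#     [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello"}]
# _add_speaker_and_signal returns
#     conversation = header + "Human: Hi</s>Assistant: Hello</s>"
#     intermediate = [[header + "Human: Hi</s>Assistant: ",
#                      header + "Human: Hi</s>Assistant: Hello</s>"]]
# preprocess_conversation tokenizes both strings of the pair; the labels
# between their lengths (i.e. the answer "Hello</s>") survive, while every
# other label becomes IGNORE_INDEX, so only the assistant's reply
# contributes to the loss.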

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

@ -125,15 +190,27 @@ class SupervisedDataset(Dataset):
        list_data_dict = list_data_dict[:max_datasets_size]

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        if "conversations" not in list_data_dict[0]:
            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
            sources = [
                prompt_input.format_map(example)
                if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
            ]
            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        logger.info("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer, max_length)
            if is_rank_0():
                logger.info("Tokenizing inputs... This may take some time...")
            data_dict = preprocess(sources, targets, tokenizer, max_length)
        else:
            if is_rank_0():
                logger.info("Tokenizing inputs... This may take some time...")
            sources = [conv["conversations"] for conv in list_data_dict]
            data_dict = preprocess_conversation(sources, tokenizer, max_length)

        if is_rank_0():
            logger.info("Tokenizing finish.")

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
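
With this change the same dataset class serves both formats: files whose samples carry a `conversations` key take the new session-based branch, while everything else falls back to the original instruction-style preprocessing. A minimal driving sketch (the `coati.dataset` import path, the constructor call, and the dict-style `__getitem__` are assumptions based on the surrounding code; `/path/to/llama` and `dataset.json` are placeholders):

```python
from transformers import AutoTokenizer

from coati.dataset import SupervisedDataset  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama", use_fast=False)
tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.eos_token

dataset = SupervisedDataset("dataset.json", tokenizer, max_length=512)
sample = dataset[0]  # dict with "input_ids" and "labels" tensors
print(sample["input_ids"].shape, sample["labels"].shape)
```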

View File

@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
- [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>
### Conversation dataset generation
To further improve the model's ability to handle multi-turn conversations, the training data needs to contain multi-turn samples. However, the InstructWild and Alpaca datasets consist of single-turn conversations only, and their layout is not suited to storing multi-turn dialogue. Therefore, besides converting those datasets, we also include multi-turn conversation datasets such as ShareGPT, and transform everything into the training format supported by ColossalChat.
A sample of the conversation dataset should have the following fields:
* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (list, compulsory): Conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.
A simple example:
```json
{
"type": "instruction",
"language": "English",
"dataset": "Alpaca",
"conversations": [
{
"from": "human",
"value": "Give three tips for staying healthy."
},
{
"from": "gpt",
"value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
}
],
"id": 1
}
```
> **NOTE:** Only the `conversations` key is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.
You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat:
```
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```
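
As a quick sanity check, you can peek at the generated file (a minimal sketch; replace `dataset.json` with whatever `--save_path` points to):

```python
import json

with open("dataset.json") as f:
    samples = json.load(f)

# `conversations` is the only compulsory key; everything else is metadata
assert all("conversations" in sample for sample in samples)
print(f"{len(samples)} samples; first turn: {samples[0]['conversations'][0]}")
```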
## Stage1 - Supervised instructs tuning
Stage1 is supervised instruction fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.

View File

@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    # Datasets with the same fields as Alpaca ("instruction", "input", "output")
    # can be converted into one-round conversations.
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]
    assert len(instructions) == len(inputs) == len(outputs)

    for idx in range(len(instructions)):
        human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
        human = {"from": "human", "value": human_utterance}

        gpt_utterance = outputs[idx]
        gpt = {"from": "gpt", "value": gpt_utterance}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset
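
# For example (illustrative record), the Alpaca sample
#     {"instruction": "Give three tips for staying healthy.", "input": "", "output": "1. Eat ..."}
# is converted into the one-round conversation
#     {"type": "instruction", "language": "English", "dataset": "Alpaca",
#      "conversations": [{"from": "human", "value": "Give three tips for staying healthy."},
#                        {"from": "gpt", "value": "1. Eat ..."}]}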

def generate_sharegpt():
    # ShareGPT data requires less processing.
    conversation_dataset = []
    dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
                           data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
                           split="train")

    conversations = dataset["conversations"]
    for idx in range(len(conversations)):
        for conv in conversations[idx]:
            # We don't need the markdown and text values.
            del conv["markdown"]
            del conv["text"]

        conversation = dict(type="conversation",
                            language="Multilingual",
                            dataset="ShareGPT",
                            conversations=conversations[idx])
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default="All",
                        choices=["Alpaca", "ShareGPT", "All"],
                        help="which dataset to convert, All will combine Alpaca and ShareGPT")
    parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []
    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    # assign sequential IDs after conversion
    for idx, sample in enumerate(conversation_dataset):
        sample["id"] = idx + 1

    with open(args.save_path, mode='w') as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)

View File

@ -74,8 +74,8 @@ def train(args):
            padding_side="right",
            use_fast=False,
        )
        tokenizer.eos_token = '<\s>'
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.eos_token = '</s>'
        tokenizer.pad_token = tokenizer.eos_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
@ -153,9 +153,7 @@ def train(args):
        optim,
        num_warmup_steps=math.ceil(max_steps * 0.03),
        num_training_steps=max_steps)

    strategy_dict = strategy.prepare(
        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
    )
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))

    model = strategy_dict['model']
    optim = strategy_dict['optimizer']
    lr_scheduler = strategy_dict['lr_scheduler']