mirror of https://github.com/hpcaitech/ColossalAI
support session-based training (#4313)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
parent ef4b99ebcd
commit 5187c96b7c
@@ -0,0 +1,87 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    sep: str = "</s>"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=[[x, y] for x, y in self.messages],
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep
        }


conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
    sep="</s>",
)

default_conversation = conv
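A quick sketch (not part of the diff) of how this template builds a session prompt; it assumes the new file is importable as `coati.dataset.conversation`:

```python
# Sketch: render a multi-turn prompt with the default conversation template.
from coati.dataset.conversation import default_conversation

conv = default_conversation.copy()
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], "The capital of France is Paris.")
conv.append_message(conv.roles[0], "And of Italy?")
conv.append_message(conv.roles[1], None)  # empty slot: prompt ends with "Assistant: "

# system text + "Human: ...</s>Assistant: ...</s>Human: ...</s>Assistant: "
print(conv.get_prompt())
```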
@@ -15,7 +15,7 @@
 import copy
 import random
 from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Callable, Dict, List, Sequence, Tuple

 import torch
 import torch.distributed as dist
@@ -25,11 +25,21 @@ from tqdm import tqdm

from colossalai.logging import get_dist_logger

from .conversation import default_conversation
from .utils import is_rank_0, jload

# The following is a template prompt for a 4-round conversation.
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.

Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""
# Please note that we only calculate loss on assistant's answer tokens.

logger = get_dist_logger()

IGNORE_INDEX = -100
DEFAULT_EOS_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input":
        ("Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -107,6 +117,61 @@ def preprocess(
    return dict(input_ids=input_ids, labels=labels)


def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
                            max_length: int) -> Dict:
    """Preprocess the conversation data by tokenizing."""
    conversations = []
    intermediates = []
    for source in sources:
        header = f"{default_conversation.system}"
        conversation, intermediate = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
        intermediates.append(intermediate)

    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
    input_ids = conversations_tokenized["input_ids"]
    targets = copy.deepcopy(input_ids)

    assert len(targets) == len(intermediates)
    for target, inters in zip(targets, intermediates):
        mask = torch.zeros_like(target, dtype=torch.bool)
        for inter in inters:
            tokenized = _tokenize_fn(inter, tokenizer, max_length)

            start_idx = tokenized["input_ids"][0].size(0) - 1
            end_idx = tokenized["input_ids"][1].size(0)

            mask[start_idx:end_idx] = True
        target[~mask] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def _add_speaker_and_signal(header: str,
                            source: List[Dict],
                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
    END_SIGNAL = DEFAULT_EOS_TOKEN
    conversation = header
    intermediate = []
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = default_conversation.roles[1]
        else:
            from_str = 'unknown'

        value = from_str + ": " + sentence["value"] + END_SIGNAL
        if sentence["from"].lower() == "gpt":
            start = conversation + from_str + ": "
            end = conversation + value
            intermediate.append([start, end])
        if get_conversation:
            conversation += value
    return conversation, intermediate


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

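The span arithmetic above can be hard to follow; here is a minimal sketch (with made-up lengths, not taken from the diff) of how the mask keeps loss only on assistant tokens:

```python
# Sketch: for each assistant turn, _add_speaker_and_signal records two prefixes
# of the conversation string: up to "Assistant: " (start) and including the
# full reply (end). Their tokenized lengths bound the answer's token span.
import torch

IGNORE_INDEX = -100
target = torch.arange(10)    # stand-in for the conversation's token ids
start_len, end_len = 4, 8    # stand-in tokenized prefix lengths

mask = torch.zeros_like(target, dtype=torch.bool)
mask[start_len - 1:end_len] = True   # the -1 mirrors start_idx above
labels = target.clone()
labels[~mask] = IGNORE_INDEX         # only assistant tokens keep real labels
```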
@@ -125,15 +190,27 @@
         list_data_dict = list_data_dict[:max_datasets_size]

         logger.info("Formatting inputs...")
-        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
-        sources = [
-            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
-            for example in list_data_dict
-        ]
-        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
-
-        logger.info("Tokenizing inputs... This may take some time...")
-        data_dict = preprocess(sources, targets, tokenizer, max_length)
+        if "conversations" not in list_data_dict[0]:
+            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+            sources = [
+                prompt_input.format_map(example)
+                if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
+            ]
+            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+
+            data_dict = preprocess(sources, targets, tokenizer, max_length)
+        else:
+            if is_rank_0():
+                logger.info("Tokenizing inputs... This may take some time...")
+
+            sources = [conv["conversations"] for conv in list_data_dict]
+            data_dict = preprocess_conversation(sources, tokenizer, max_length)
+
+        if is_rank_0():
+            logger.info("Tokenizing finish.")

         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]
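The branching above keys purely on the record schema; a tiny sketch (illustrative records, not from the diff) of the same dispatch rule:

```python
# Sketch: records with a "conversations" key take the session path,
# Alpaca-style records take the instruction path.
alpaca_record = {"instruction": "Give three tips for staying healthy.", "input": "", "output": "..."}
session_record = {"conversations": [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]}

for record in (alpaca_record, session_record):
    path = "preprocess" if "conversations" not in record else "preprocess_conversation"
    print(path)  # prints "preprocess", then "preprocess_conversation"
```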
@@ -6,6 +6,7 @@
 - [Table of Contents](#table-of-contents)
 - [Install requirements](#install-requirements)
 - [Supervised datasets collection](#supervised-datasets-collection)
+  - [Conversation dataset generation](#conversation-dataset-generation)
 - [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
   - [Arg List](#arg-list)
 - [Stage2 - Training reward model](#stage2---training-reward-model)
@@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>

### Conversation dataset generation

To further improve the model's ability to handle multi-turn conversations, we need to include multi-turn samples in the dataset. However, the samples in the InstructWild and Alpaca datasets consist only of single-turn conversations, and their organization is not suitable for storing multi-turn ones. In addition to converting those datasets, we also need to include multi-turn conversation datasets such as ShareGPT and transform them into the training format supported by ColossalChat.

A sample of the conversation dataset should have the following fields:

* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (list, compulsory): Conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.

A simple example:
```json
{
    "type": "instruction",
    "language": "English",
    "dataset": "Alpaca",
    "conversations": [
        {
            "from": "human",
            "value": "Give three tips for staying healthy."
        },
        {
            "from": "gpt",
            "value": "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
        }
    ],
    "id": 1
}
```

> **NOTE:** Only the key `conversations` is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.

You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat.

You can use the following command to generate a conversation dataset:
```
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```

## Stage1 - Supervised instructs tuning

Stage1 is supervised instruction fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
@@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    # Datasets in the same format as Alpaca ("instruction", "input", "output") can be converted into one-round conversations.
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]

    assert len(instructions) == len(inputs) == len(outputs)

    for idx in range(len(instructions)):
        human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
        human = {"from": "human", "value": human_utterance}

        gpt_utterance = outputs[idx]
        gpt = {"from": "gpt", "value": gpt_utterance}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset


def generate_sharegpt():
    # ShareGPT data requires less processing.
    conversation_dataset = []
    dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
                           data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
                           split="train")

    conversations = dataset["conversations"]

    for idx in range(len(conversations)):
        for conv in conversations[idx]:
            # The "markdown" and "text" fields are not needed for training.
            del conv["markdown"]
            del conv["text"]

        conversation = dict(type="conversation",
                            language="Multilingual",
                            dataset="ShareGPT",
                            conversations=conversations[idx])
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default="All",
                        choices=["Alpaca", "ShareGPT", "All"],
                        help="which dataset to convert, All will combine Alpaca and ShareGPT")
    parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []

    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    for idx, sample in enumerate(conversation_dataset):
        sample["id"] = idx + 1

    with open(args.save_path, mode='w') as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
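A small sanity check for the generated file might look like this (a sketch, assuming the default `dataset.json` save path):

```python
# Sketch: verify every generated sample carries the compulsory
# "conversations" key and only human/gpt speakers.
import json

with open("dataset.json") as f:   # default --save_path above
    samples = json.load(f)

for sample in samples:
    assert "conversations" in sample, f"sample {sample.get('id')} lacks conversations"
    for turn in sample["conversations"]:
        assert turn["from"] in ("human", "gpt")
print(f"{len(samples)} samples look well-formed")
```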
@@ -74,8 +74,8 @@ def train(args):
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = '<\s>'
-        tokenizer.pad_token = tokenizer.unk_token
+        tokenizer.eos_token = '</s>'
+        tokenizer.pad_token = tokenizer.eos_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
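The old value `'<\s>'` was a typo for the LLaMA end-of-sequence token `'</s>'`, and padding with the UNK token did not match the `</s>` separator used in the conversation prompts. A sketch of the effect, assuming a LLaMA-style slow tokenizer (the checkpoint name here is only an example):

```python
# Sketch: after the fix, the pad id aliases the EOS id, matching the "</s>"
# separator that terminates every turn in the conversation template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False)
tokenizer.eos_token = '</s>'
tokenizer.pad_token = tokenizer.eos_token
assert tokenizer.pad_token_id == tokenizer.eos_token_id
```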
@@ -153,9 +153,7 @@ def train(args):
         optim,
         num_warmup_steps=math.ceil(max_steps * 0.03),
         num_training_steps=max_steps)
-    strategy_dict = strategy.prepare(
-        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-    )
+    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
     model = strategy_dict['model']
     optim = strategy_dict['optimizer']
     lr_scheduler = strategy_dict['lr_scheduler']