[chatgpt] support instuct training (#3216)

pull/3218/head
Fazzie-Maqianli 2023-03-23 16:46:20 +08:00 committed by GitHub
parent cd142fbefa
commit 4fd4bd9d9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 313 additions and 39 deletions

View File

@ -1,5 +1,5 @@
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
from .sft_dataset import SFTDataset
from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']

View File

@ -1,12 +1,46 @@
from typing import Callable
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
import random
from torch.utils.data import Dataset
import torch.distributed as dist
from tqdm import tqdm
import torch
from .utils import is_rank_0
from .utils import is_rank_0, jload
import transformers
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
IGNORE_INDEX = -100
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
class SFTDataset(Dataset):
"""
@ -38,3 +72,87 @@ class SFTDataset(Dataset):
def __getitem__(self, idx):
return self.prompts[idx]
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
class AlpacaDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
super(AlpacaDataset, self).__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
logger.info("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@dataclass
class AlpacaDataCollator(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)

View File

@ -1,5 +1,20 @@
import io
import json
import torch.distributed as dist
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
f.close()
return jdict

View File

@ -1,5 +1,6 @@
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
from .llama_lm import LlamaLM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']

View File

@ -0,0 +1,38 @@
from typing import Optional
from transformers import LlamaConfig, LlamaForCausalLM
from ..base import LM
class LlamaLM(LM):
"""
Llama language model.
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
else:
model = LlamaForCausalLM(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

View File

@ -2,7 +2,6 @@ from abc import ABC
from typing import Optional
import loralib as lora
import torch
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
@ -22,8 +21,8 @@ class SFTTrainer(ABC):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
@ -34,8 +33,8 @@ class SFTTrainer(ABC):
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: SFTDataset,
eval_dataset: SFTDataset,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
sampler: Optional[DistributedSampler] = None,
batch_size: int = 1,
max_epochs: int = 2,
@ -43,13 +42,10 @@ class SFTTrainer(ABC):
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.sampler = sampler
self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
sampler=sampler, batch_size=batch_size)
self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
@ -79,23 +75,25 @@ class SFTTrainer(ABC):
logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
# eval
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()
prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
loss = self.loss_fn(prompt_logits, prompt_ids)
loss_sum += loss.item()
num_seen += prompt_ids.size(0)
prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
loss = self.loss_fn(prompt_logits, prompt_ids)
loss_sum += loss.item()
num_seen += prompt_ids.size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
epoch_bar.update()

View File

@ -0,0 +1,3 @@
from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding
__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']

View File

@ -0,0 +1,74 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import transformers
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
def prepare_llama_tokenizer_and_embedding(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
"""prepare llama tokenizer and embedding.
"""
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
tokenizer.add_special_tokens(
{
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
}
)
return tokenizer
def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
if tokenizer.pad_token is None:
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg

View File

@ -4,15 +4,18 @@ import loralib as lora
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from chatgpt.dataset import SFTDataset
from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMLM
from chatgpt.models.gpt import GPTLM
from chatgpt.models.opt import OPTLM
from chatgpt.models.llama import LlamaLM
from chatgpt.trainer import SFTTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from chatgpt.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@ -41,6 +44,8 @@ def train(args):
model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'gpt2':
model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'llama':
model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
@ -53,9 +58,19 @@ def train(args):
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif args.model == 'llama':
tokenizer = AutoTokenizer.from_pretrained(
args.pretrain,
padding_side="right",
use_fast=False,
)
else:
raise ValueError(f'Unsupported model "{args.model}"')
tokenizer.pad_token = tokenizer.eos_token
if args.model == 'llama':
tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
else:
tokenizer.pad_token = tokenizer.eos_token
max_len = 512
@ -67,11 +82,19 @@ def train(args):
logger = get_dist_logger()
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
# configure dataset
if args.dataset == 'yizhongw/self_instruct':
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
train_dataset = SFTDataset(train_data, tokenizer, max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
train_dataset = SFTDataset(train_data, tokenizer, max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
elif 'alpaca' in args.dataset:
train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
eval_dataset = None
eval_dataset
data_collator = AlpacaDataCollator(tokenizer=tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
@ -79,11 +102,15 @@ def train(args):
else:
sampler = None
train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
sampler=sampler,
batch_size=args.batch_size,
max_epochs=args.max_epochs)