mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] support instuct training (#3216)
parent cd142fbefa
commit 4fd4bd9d9a
@@ -1,5 +1,5 @@
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
from .sft_dataset import SFTDataset
from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator

__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
@@ -1,12 +1,46 @@
from typing import Callable
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
import random
from torch.utils.data import Dataset
import torch.distributed as dist
from tqdm import tqdm
import torch

from .utils import is_rank_0
from .utils import is_rank_0, jload

import transformers
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

IGNORE_INDEX = -100
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

class SFTDataset(Dataset):
    """
@@ -38,3 +72,87 @@ class SFTDataset(Dataset):

    def __getitem__(self, idx):
        return self.prompts[idx]


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

class AlpacaDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(AlpacaDataset, self).__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logger.info("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

@dataclass
class AlpacaDataCollator(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
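For orientation, a minimal wiring sketch (not part of this commit) showing how the new AlpacaDataset and AlpacaDataCollator are meant to fit together; the tokenizer name, JSON path, and batch size are illustrative placeholders.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import AlpacaDataset, AlpacaDataCollator

# Placeholder tokenizer and data path; any tokenizer with a pad token works here.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token

dataset = AlpacaDataset(data_path="alpaca_data.json", tokenizer=tokenizer)
collator = AlpacaDataCollator(tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
# batch["labels"] mirrors batch["input_ids"], except that prompt positions are set to
# IGNORE_INDEX (-100), so the LM loss is computed only on the response tokens.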
@@ -1,5 +1,20 @@
import io
import json

import torch.distributed as dist


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
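A tiny usage sketch for the new jload helper, assuming it lives in the dataset utils module shown above; the file name is a placeholder. For Alpaca-style data the parsed JSON is a list of records with instruction/input/output fields.

from chatgpt.dataset.utils import jload

records = jload("alpaca_data.json")  # placeholder path
print(len(records), records[0]["instruction"])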
@@ -1,5 +1,6 @@
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
from .llama_lm import LlamaLM

__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
@@ -0,0 +1,38 @@
from typing import Optional

from transformers import LlamaConfig, LlamaForCausalLM

from ..base import LM


class LlamaLM(LM):
    """
    Llama language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:

        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)
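A hedged construction sketch for the new LlamaLM wrapper; the checkpoint path and LoRA rank below are placeholders, not part of this commit.

from chatgpt.models.llama import LlamaLM

# Load pretrained weights (path is illustrative) with gradient checkpointing enabled.
model = LlamaLM(pretrained="path/to/llama-7b", checkpoint=True, lora_rank=8)

# With neither a pretrained path nor a config, a randomly initialized model
# built from a default LlamaConfig is used instead.
scratch_model = LlamaLM()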
@@ -2,7 +2,6 @@ from abc import ABC
from typing import Optional
import loralib as lora
import torch
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
@@ -22,8 +21,8 @@ class SFTTrainer(ABC):
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim(Optimizer): the optimizer to use for training
        train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
        eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
        train_dataloader: the dataloader to use for training
        eval_dataloader: the dataloader to use for evaluation
        batch_size (int, defaults to 1): the batch size while training
        max_epochs (int, defaults to 2): the number of epochs to train
        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
@@ -34,8 +33,8 @@ class SFTTrainer(ABC):
                 model,
                 strategy: Strategy,
                 optim: Optimizer,
                 train_dataset: SFTDataset,
                 eval_dataset: SFTDataset,
                 train_dataloader: DataLoader,
                 eval_dataloader: DataLoader = None,
                 sampler: Optional[DistributedSampler] = None,
                 batch_size: int = 1,
                 max_epochs: int = 2,
@@ -43,13 +42,10 @@ class SFTTrainer(ABC):
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.sampler = sampler

        self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
                                           sampler=sampler, batch_size=batch_size)
        self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = strategy.setup_model(model)
        if "DDP" in str(self.strategy):
@@ -79,23 +75,25 @@ class SFTTrainer(ABC):
                    logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

            # eval
            self.model.eval()
            with torch.no_grad():
                loss_sum = 0
                num_seen = 0
                for batch in self.eval_dataloader:
                    prompt_ids = batch["input_ids"]
                    p_mask = batch["attention_mask"]
                    prompt_ids = prompt_ids.squeeze(1).cuda()
                    p_mask = p_mask.squeeze(1).cuda()
            if self.eval_dataloader is not None:
                self.model.eval()
                with torch.no_grad():
                    loss_sum = 0
                    num_seen = 0
                    for batch in self.eval_dataloader:
                        prompt_ids = batch["input_ids"]
                        p_mask = batch["attention_mask"]
                        prompt_ids = prompt_ids.squeeze(1).cuda()
                        p_mask = p_mask.squeeze(1).cuda()

                    prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                    loss = self.loss_fn(prompt_logits, prompt_ids)
                    loss_sum += loss.item()
                    num_seen += prompt_ids.size(0)
                        prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                        loss = self.loss_fn(prompt_logits, prompt_ids)
                        loss_sum += loss.item()
                        num_seen += prompt_ids.size(0)

                loss_mean = loss_sum / num_seen
                if dist.get_rank() == 0:
                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
                    loss_mean = loss_sum / num_seen
                    if dist.get_rank() == 0:
                        logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')

            epoch_bar.update()
@@ -0,0 +1,3 @@
from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding

__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
@@ -0,0 +1,74 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import transformers

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

def prepare_llama_tokenizer_and_embedding(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """prepare llama tokenizer and embedding."""

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    tokenizer.add_special_tokens(
        {
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
        }
    )

    return tokenizer


def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """

    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
            input_embeddings = model.get_input_embeddings().weight.data
            output_embeddings = model.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg
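As a rough usage sketch (the checkpoint path is a placeholder), the two helpers above are meant to be applied to a LLaMA tokenizer/model pair before training:

from transformers import AutoTokenizer, LlamaForCausalLM

from chatgpt.utils import prepare_llama_tokenizer_and_embedding

model = LlamaForCausalLM.from_pretrained("path/to/llama-7b")
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-7b", use_fast=False)

# Adds "[PAD]" if no pad token exists, resizes the embedding matrices and fills the
# new rows with the mean of the existing embeddings, then registers the default
# EOS/BOS/UNK special tokens.
tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)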
@@ -4,15 +4,18 @@ import loralib as lora
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from chatgpt.dataset import SFTDataset
from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMLM
from chatgpt.models.gpt import GPTLM
from chatgpt.models.opt import OPTLM
from chatgpt.models.llama import LlamaLM
from chatgpt.trainer import SFTTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from chatgpt.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

@@ -41,6 +44,8 @@ def train(args):
            model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'gpt2':
            model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'llama':
            model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

@@ -53,9 +58,19 @@ def train(args):
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'llama':
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain,
            padding_side="right",
            use_fast=False,
        )
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    tokenizer.pad_token = tokenizer.eos_token

    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
    else:
        tokenizer.pad_token = tokenizer.eos_token

    max_len = 512
@@ -67,11 +82,19 @@ def train(args):

    logger = get_dist_logger()

    train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
    eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
    # configure dataset
    if args.dataset == 'yizhongw/self_instruct':
        train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
        eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')

    train_dataset = SFTDataset(train_data, tokenizer, max_len)
    eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
        train_dataset = SFTDataset(train_data, tokenizer, max_len)
        eval_dataset = SFTDataset(eval_data, tokenizer, max_len)

    elif 'alpaca' in args.dataset:
        train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
        eval_dataset = None
        eval_dataset
        data_collator = AlpacaDataCollator(tokenizer=tokenizer)

    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
@@ -79,11 +102,15 @@ def train(args):
    else:
        sampler = None

    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)

    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         train_dataset=train_dataset,
                         eval_dataset=eval_dataset,
                         train_dataloader=train_dataloader,
                         eval_dataloader=eval_dataloader,
                         sampler=sampler,
                         batch_size=args.batch_size,
                         max_epochs=args.max_epochs)