From 4fd4bd9d9a88bde184d347a4b283b117e5025630 Mon Sep 17 00:00:00 2001
From: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Date: Thu, 23 Mar 2023 16:46:20 +0800
Subject: [PATCH] [chatgpt] support instruct training (#3216)

---
 .../ChatGPT/chatgpt/dataset/__init__.py       |   4 +-
 .../ChatGPT/chatgpt/dataset/sft_dataset.py    | 122 +++++++++++++++++-
 applications/ChatGPT/chatgpt/dataset/utils.py |  15 +++
 .../ChatGPT/chatgpt/models/llama/__init__.py  |   3 +-
 .../ChatGPT/chatgpt/models/llama/llama_lm.py  |  38 ++++++
 applications/ChatGPT/chatgpt/trainer/sft.py   |  50 ++++---
 .../ChatGPT/chatgpt/utils/__init__.py         |   3 +
 .../ChatGPT/chatgpt/utils/tokenizer_utils.py  |  74 +++++++++++
 applications/ChatGPT/examples/train_sft.py    |  43 ++++--
 9 files changed, 313 insertions(+), 39 deletions(-)
 create mode 100644 applications/ChatGPT/chatgpt/models/llama/llama_lm.py
 create mode 100644 applications/ChatGPT/chatgpt/utils/__init__.py
 create mode 100644 applications/ChatGPT/chatgpt/utils/tokenizer_utils.py

diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py
index 78fd2c070..df484f46d 100644
--- a/applications/ChatGPT/chatgpt/dataset/__init__.py
+++ b/applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -1,5 +1,5 @@
 from .reward_dataset import RmStaticDataset, HhRlhfDataset
 from .utils import is_rank_0
-from .sft_dataset import SFTDataset
+from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
 
-__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
+__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
diff --git a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
index 53ad20507..67e1b761c 100644
--- a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
+++ b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
@@ -1,12 +1,46 @@
-from typing import Callable
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Sequence
 import random
 
 from torch.utils.data import Dataset
 import torch.distributed as dist
 from tqdm import tqdm
 import torch
-from .utils import is_rank_0
+from .utils import is_rank_0, jload
+
+import transformers
+from colossalai.logging import get_dist_logger
+logger = get_dist_logger()
+
+IGNORE_INDEX = -100
+PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
" + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), +} class SFTDataset(Dataset): """ @@ -38,3 +72,87 @@ class SFTDataset(Dataset): def __getitem__(self, idx): return self.prompts[idx] + + +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + +def preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] + input_ids = examples_tokenized["input_ids"] + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + +class AlpacaDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): + super(AlpacaDataset, self).__init__() + logger.info("Loading data...") + list_data_dict = jload(data_path) + + logger.info("Formatting inputs...") + prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] + sources = [ + prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) + for example in list_data_dict + ] + targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] + + logger.info("Tokenizing inputs... 
+        data_dict = preprocess(sources, targets, tokenizer)
+
+        self.input_ids = data_dict["input_ids"]
+        self.labels = data_dict["labels"]
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+@dataclass
+class AlpacaDataCollator(object):
+    """Collate examples for supervised fine-tuning."""
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+        )
+        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+        return dict(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+        )
diff --git a/applications/ChatGPT/chatgpt/dataset/utils.py b/applications/ChatGPT/chatgpt/dataset/utils.py
index 6c9f7f085..0e88cc8c3 100644
--- a/applications/ChatGPT/chatgpt/dataset/utils.py
+++ b/applications/ChatGPT/chatgpt/dataset/utils.py
@@ -1,5 +1,20 @@
+import io
+import json
+
 import torch.distributed as dist
 
 
 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+def _make_r_io_base(f, mode: str):
+    if not isinstance(f, io.IOBase):
+        f = open(f, mode=mode)
+    return f
+
+def jload(f, mode="r"):
+    """Load a .json file into a dictionary."""
+    f = _make_r_io_base(f, mode)
+    jdict = json.load(f)
+    f.close()
+    return jdict
\ No newline at end of file
diff --git a/applications/ChatGPT/chatgpt/models/llama/__init__.py b/applications/ChatGPT/chatgpt/models/llama/__init__.py
index 9b2a024af..3edb51e14 100644
--- a/applications/ChatGPT/chatgpt/models/llama/__init__.py
+++ b/applications/ChatGPT/chatgpt/models/llama/__init__.py
@@ -1,5 +1,6 @@
 from .llama_actor import LlamaActor
 from .llama_critic import LlamaCritic
 from .llama_rm import LlamaRM
+from .llama_lm import LlamaLM
 
-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_lm.py b/applications/ChatGPT/chatgpt/models/llama/llama_lm.py
new file mode 100644
index 000000000..c63077b1a
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_lm.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from ..base import LM
+
+
+class LlamaLM(LM):
+    """
+    Llama language model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+ """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + + if pretrained is not None: + model = LlamaForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = LlamaForCausalLM(config) + else: + model = LlamaForCausalLM(LlamaConfig()) + + if checkpoint: + model.gradient_checkpointing_enable() + + super().__init__(model, lora_rank, lora_train_bias) + diff --git a/applications/ChatGPT/chatgpt/trainer/sft.py b/applications/ChatGPT/chatgpt/trainer/sft.py index e3913d46b..dd5cd35f5 100644 --- a/applications/ChatGPT/chatgpt/trainer/sft.py +++ b/applications/ChatGPT/chatgpt/trainer/sft.py @@ -2,7 +2,6 @@ from abc import ABC from typing import Optional import loralib as lora import torch -from chatgpt.dataset import SFTDataset from chatgpt.models.loss import GPTLMLoss from torch.optim import Adam, Optimizer from torch.utils.data import DataLoader @@ -22,8 +21,8 @@ class SFTTrainer(ABC): model (torch.nn.Module): the model to train strategy (Strategy): the strategy to use for training optim(Optimizer): the optimizer to use for training - train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training - eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation + train_dataloader: the dataloader to use for training + eval_dataloader: the dataloader to use for evaluation batch_size (int, defaults to 1): the batch size while training max_epochs (int, defaults to 2): the number of epochs to train optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer @@ -34,8 +33,8 @@ class SFTTrainer(ABC): model, strategy: Strategy, optim: Optimizer, - train_dataset: SFTDataset, - eval_dataset: SFTDataset, + train_dataloader: DataLoader, + eval_dataloader: DataLoader = None, sampler: Optional[DistributedSampler] = None, batch_size: int = 1, max_epochs: int = 2, @@ -43,13 +42,10 @@ class SFTTrainer(ABC): super().__init__() self.strategy = strategy self.epochs = max_epochs - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset self.sampler = sampler - self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None), - sampler=sampler, batch_size=batch_size) - self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size) + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader self.model = strategy.setup_model(model) if "DDP" in str(self.strategy): @@ -79,23 +75,25 @@ class SFTTrainer(ABC): logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}') # eval - self.model.eval() - with torch.no_grad(): - loss_sum = 0 - num_seen = 0 - for batch in self.eval_dataloader: - prompt_ids = batch["input_ids"] - p_mask = batch["attention_mask"] - prompt_ids = prompt_ids.squeeze(1).cuda() - p_mask = p_mask.squeeze(1).cuda() + if self.eval_dataloader is not None: + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + num_seen = 0 + for batch in self.eval_dataloader: + prompt_ids = batch["input_ids"] + p_mask = batch["attention_mask"] + prompt_ids = prompt_ids.squeeze(1).cuda() + p_mask = p_mask.squeeze(1).cuda() - prompt_logits = self.model(prompt_ids, attention_mask=p_mask) - loss = self.loss_fn(prompt_logits, prompt_ids) - loss_sum += loss.item() - num_seen += prompt_ids.size(0) + prompt_logits = self.model(prompt_ids, attention_mask=p_mask) + loss = 
+                        loss_sum += loss.item()
+                        num_seen += prompt_ids.size(0)
 
-                loss_mean = loss_sum / num_seen
-                if dist.get_rank() == 0:
-                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+                    loss_mean = loss_sum / num_seen
+                    if dist.get_rank() == 0:
+                        logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+            epoch_bar.update()
diff --git a/applications/ChatGPT/chatgpt/utils/__init__.py b/applications/ChatGPT/chatgpt/utils/__init__.py
new file mode 100644
index 000000000..8f526d7ef
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/utils/__init__.py
@@ -0,0 +1,3 @@
+from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding
+
+__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
diff --git a/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py b/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py
new file mode 100644
index 000000000..8699bf64c
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py
@@ -0,0 +1,74 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+import transformers
+
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
+
+def prepare_llama_tokenizer_and_embedding(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
+):
+    """Prepare the LLaMA tokenizer and embedding."""
+
+    if tokenizer.pad_token is None:
+        smart_tokenizer_and_embedding_resize(
+            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+            tokenizer=tokenizer,
+            model=model,
+        )
+
+    tokenizer.add_special_tokens(
+        {
+            "eos_token": DEFAULT_EOS_TOKEN,
+            "bos_token": DEFAULT_BOS_TOKEN,
+            "unk_token": DEFAULT_UNK_TOKEN,
+        }
+    )
+
+    return tokenizer
+
+
+def smart_tokenizer_and_embedding_resize(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
+):
+    """Resize tokenizer and embedding.
+
+    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """ + + if tokenizer.pad_token is None: + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + \ No newline at end of file diff --git a/applications/ChatGPT/examples/train_sft.py b/applications/ChatGPT/examples/train_sft.py index 4b3f85a2a..83b34f9dd 100644 --- a/applications/ChatGPT/examples/train_sft.py +++ b/applications/ChatGPT/examples/train_sft.py @@ -4,15 +4,18 @@ import loralib as lora import torch import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler -from chatgpt.dataset import SFTDataset +from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator from chatgpt.models.base import RewardModel from chatgpt.models.bloom import BLOOMLM from chatgpt.models.gpt import GPTLM from chatgpt.models.opt import OPTLM +from chatgpt.models.llama import LlamaLM from chatgpt.trainer import SFTTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from chatgpt.utils import prepare_llama_tokenizer_and_embedding from datasets import load_dataset from torch.optim import Adam +from torch.utils.data import DataLoader from transformers import AutoTokenizer, BloomTokenizerFast from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer @@ -41,6 +44,8 @@ def train(args): model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() elif args.model == 'gpt2': model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + elif args.model == 'llama': + model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() else: raise ValueError(f'Unsupported model "{args.model}"') @@ -53,9 +58,19 @@ def train(args): tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'llama': + tokenizer = AutoTokenizer.from_pretrained( + args.pretrain, + padding_side="right", + use_fast=False, + ) else: raise ValueError(f'Unsupported model "{args.model}"') - tokenizer.pad_token = tokenizer.eos_token + + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) + else: + tokenizer.pad_token = tokenizer.eos_token max_len = 512 @@ -67,11 +82,19 @@ def train(args): logger = get_dist_logger() - train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') - eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') + # configure dataset + if args.dataset == 'yizhongw/self_instruct': + train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') + eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') - train_dataset = SFTDataset(train_data, tokenizer, max_len) - eval_dataset = SFTDataset(eval_data, tokenizer, max_len) + train_dataset = SFTDataset(train_data, tokenizer, max_len) + eval_dataset = SFTDataset(eval_data, tokenizer, max_len) + + elif 'alpaca' in args.dataset: + train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset) + eval_dataset 
+        data_collator = AlpacaDataCollator(tokenizer=tokenizer)
 
     if dist.is_initialized() and dist.get_world_size() > 1:
         sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
@@ -79,11 +102,15 @@
     else:
         sampler = None
 
+    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size, collate_fn=data_collator)
+    if eval_dataset is not None:
+        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, collate_fn=data_collator)
+    else:
+        eval_dataloader = None
+
     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
-                         train_dataset=train_dataset,
-                         eval_dataset=eval_dataset,
+                         train_dataloader=train_dataloader,
+                         eval_dataloader=eval_dataloader,
                          sampler=sampler,
                          batch_size=args.batch_size,
                          max_epochs=args.max_epochs)
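
Usage sketch (illustrative, not part of the patch): the new AlpacaDataset and AlpacaDataCollator can be exercised on their own before being wired into train_sft.py. The GPT-2 tokenizer and the alpaca_data.json path below are stand-ins for whichever tokenizer and Alpaca-format file are actually used.

# Illustrative only: any tokenizer with a pad token and any JSON file of
# {"instruction", "input", "output"} records will do here.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import AlpacaDataset, AlpacaDataCollator

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token        # the collator needs a pad id

dataset = AlpacaDataset(data_path="alpaca_data.json", tokenizer=tokenizer)   # placeholder path
collator = AlpacaDataCollator(tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
# input_ids/labels are padded to the longest sequence in the batch; label
# positions that belong to the prompt are already set to IGNORE_INDEX (-100).
print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)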
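
The core of the new dataset code is the label masking in preprocess(): the prompt portion of each sequence is overwritten with IGNORE_INDEX (-100), the index that PyTorch's cross-entropy loss skips, so the loss is computed only on response tokens. A minimal reproduction of that masking, using a made-up instruction/response pair and a GPT-2 tokenizer as a stand-in:

# Illustrative only: mirrors the masking done by preprocess() for one record.
import copy

from transformers import AutoTokenizer

IGNORE_INDEX = -100

tokenizer = AutoTokenizer.from_pretrained("gpt2")
source = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nName three primary colors.\n\n### Response:"
)
target = "Red, yellow and blue." + tokenizer.eos_token

example_ids = tokenizer(source + target, return_tensors="pt").input_ids[0]
source_len = tokenizer(source, return_tensors="pt").input_ids[0].numel()

labels = copy.deepcopy(example_ids)
labels[:source_len] = IGNORE_INDEX   # prompt tokens contribute nothing to the loss
assert (labels[source_len:] != IGNORE_INDEX).all()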
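
smart_tokenizer_and_embedding_resize handles the case where the LLaMA tokenizer ships without a pad token: it registers "[PAD]", grows the embedding matrices, and initializes the new rows with the mean of the existing rows rather than leaving them randomly initialized. A small sketch of the same idea, run on GPT-2 here only because a LLaMA checkpoint may not be available locally:

# Illustrative only: mean-initialized embedding resize after adding a pad token.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

num_new_tokens = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))

if num_new_tokens > 0:
    embeddings = model.get_input_embeddings().weight.data
    # the new rows start as the mean of the pre-existing rows, not random noise
    embeddings[-num_new_tokens:] = embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)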