diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py
index 833930987..78fd2c070 100644
--- a/applications/ChatGPT/chatgpt/dataset/__init__.py
+++ b/applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -1,4 +1,5 @@
 from .reward_dataset import RmStaticDataset, HhRlhfDataset
 from .utils import is_rank_0
+from .sft_dataset import SFTDataset
 
-__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
+__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
diff --git a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
new file mode 100644
index 000000000..53ad20507
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
@@ -0,0 +1,40 @@
+from typing import Callable
+import random
+from torch.utils.data import Dataset
+import torch.distributed as dist
+from tqdm import tqdm
+import torch
+
+from .utils import is_rank_0
+
+
+class SFTDataset(Dataset):
+    """
+    Dataset for sft model
+
+    Args:
+        dataset: dataset for supervised model
+        tokenizer: tokenizer for supervised model
+        max_length: max length of input
+    """
+
+    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
+        super().__init__()
+        self.prompts = []
+
+        for data in tqdm(dataset, disable=not is_rank_0()):
+            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
+            prompt_token = tokenizer(prompt,
+                                     max_length=max_length,
+                                     padding="max_length",
+                                     truncation=True,
+                                     return_tensors="pt")
+
+            self.prompts.append(prompt_token)
+
+    def __len__(self):
+        length = len(self.prompts)
+        return length
+
+    def __getitem__(self, idx):
+        return self.prompts[idx]
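Note: a minimal sketch of how the new SFTDataset is meant to be driven; the toy records and the GPT-2 tokenizer below are stand-ins, not part of this patch. Each item comes back as the tokenizer's dict of [1, max_length] tensors, which is why the trainer later squeezes dimension 1.

# Illustrative smoke test only; the records and tokenizer are hypothetical stand-ins.
from transformers import GPT2Tokenizer

from chatgpt.dataset import SFTDataset

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token      # GPT-2 ships without a pad token

records = [
    {'prompt': 'Translate to French: cheese', 'completion': ' fromage'},
    {'prompt': 'What is the capital of Japan?', 'completion': ' Tokyo'},
]
dataset = SFTDataset(records, tokenizer, max_length=32)

item = dataset[0]
print(len(dataset))                            # 2
print(item['input_ids'].shape)                 # torch.Size([1, 32]); squeezed to [32] by the trainer
print(item['attention_mask'].shape)            # torch.Size([1, 32])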
diff --git a/applications/ChatGPT/chatgpt/models/base/__init__.py b/applications/ChatGPT/chatgpt/models/base/__init__.py
index 86f403556..7c7b1ceba 100644
--- a/applications/ChatGPT/chatgpt/models/base/__init__.py
+++ b/applications/ChatGPT/chatgpt/models/base/__init__.py
@@ -1,5 +1,6 @@
 from .actor import Actor
 from .critic import Critic
 from .reward_model import RewardModel
+from .lm import LM
 
-__all__ = ['Actor', 'Critic', 'RewardModel']
+__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
diff --git a/applications/ChatGPT/chatgpt/models/base/lm.py b/applications/ChatGPT/chatgpt/models/base/lm.py
new file mode 100644
index 000000000..b6bd7aff8
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/base/lm.py
@@ -0,0 +1,33 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..generation import generate
+from .actor import Actor
+
+
+class LM(Actor):
+    """
+    Language model base class.
+
+    Args:
+        model (nn.Module): Language Model.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+        super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)
+
+    def forward(self,
+                sequences: torch.LongTensor,
+                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Returns output log probs
+        """
+        output = self.model(sequences, attention_mask=attention_mask)
+        logits = output['logits']
+        log_probs = F.log_softmax(logits, dim=-1)
+        return log_probs
+
diff --git a/applications/ChatGPT/chatgpt/models/bloom/__init__.py b/applications/ChatGPT/chatgpt/models/bloom/__init__.py
index d0e7f7b1e..7d6d7753b 100644
--- a/applications/ChatGPT/chatgpt/models/bloom/__init__.py
+++ b/applications/ChatGPT/chatgpt/models/bloom/__init__.py
@@ -1,5 +1,6 @@
 from .bloom_actor import BLOOMActor
 from .bloom_critic import BLOOMCritic
 from .bloom_rm import BLOOMRM
+from .bloom_lm import BLOOMLM
 
-__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
+__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py
new file mode 100644
index 000000000..81e17f27c
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import torch
+from transformers import BloomConfig, BloomForCausalLM, BloomModel
+
+from ..base import LM
+
+
+class BLOOMLM(LM):
+    """
+    BLOOM language model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (BloomConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[BloomConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+        if pretrained is not None:
+            model = BloomForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = BloomForCausalLM(config)
+        else:
+            model = BloomForCausalLM(BloomConfig())
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+        super().__init__(model, lora_rank, lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/models/gpt/__init__.py b/applications/ChatGPT/chatgpt/models/gpt/__init__.py
index 63dc5ab0f..c6ae05113 100644
--- a/applications/ChatGPT/chatgpt/models/gpt/__init__.py
+++ b/applications/ChatGPT/chatgpt/models/gpt/__init__.py
@@ -1,5 +1,6 @@
 from .gpt_actor import GPTActor
 from .gpt_critic import GPTCritic
 from .gpt_rm import GPTRM
+from .gpt_lm import GPTLM
 
-__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
+__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py
new file mode 100644
index 000000000..5740c80d3
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+from ..base import LM
+
+
+class GPTLM(LM):
+    """
+    GPT language model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (GPT2Config): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): Rank of the LoRA layer.
+        lora_train_bias (str): Bias training strategy for the LoRA layer.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[GPT2Config] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+        if pretrained is not None:
+            model = GPT2LMHeadModel.from_pretrained(pretrained)
+        elif config is not None:
+            model = GPT2LMHeadModel(config)
+        else:
+            model = GPT2LMHeadModel(GPT2Config())
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+        super().__init__(model, lora_rank, lora_train_bias)
+
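Note: what the new LM wrappers return. The sketch below builds GPTLM from a deliberately tiny GPT2Config so nothing is downloaded (the sizes are illustrative, not recommendations); the forward pass yields per-token log-probabilities rather than raw logits.

# Illustrative only: a tiny config keeps this runnable offline.
import torch
from transformers import GPT2Config

from chatgpt.models.gpt import GPTLM

config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=1000, n_positions=64)
model = GPTLM(config=config)

input_ids = torch.randint(0, 1000, (2, 16))    # [batch, seq_len]
attention_mask = torch.ones_like(input_ids)
log_probs = model(input_ids, attention_mask=attention_mask)

print(log_probs.shape)                         # torch.Size([2, 16, 1000])
print(torch.exp(log_probs).sum(-1))            # each row sums to ~1: log-softmax output, not logits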
diff --git a/applications/ChatGPT/chatgpt/models/opt/__init__.py b/applications/ChatGPT/chatgpt/models/opt/__init__.py
index 334f4df00..fccec3bdf 100644
--- a/applications/ChatGPT/chatgpt/models/opt/__init__.py
+++ b/applications/ChatGPT/chatgpt/models/opt/__init__.py
@@ -1,5 +1,6 @@
 from .opt_actor import OPTActor
 from .opt_critic import OPTCritic
 from .opt_rm import OPTRM
+from .opt_lm import OPTLM
 
-__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
+__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_lm.py b/applications/ChatGPT/chatgpt/models/opt/opt_lm.py
new file mode 100644
index 000000000..35bfe198a
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_lm.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from transformers.models.opt.configuration_opt import OPTConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM
+
+from ..base import LM
+
+
+class OPTLM(LM):
+    """
+    OPT language model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (OPTConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): Rank of the low-rank approximation.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[OPTConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+        if pretrained is not None:
+            model = OPTForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = OPTForCausalLM(config)
+        else:
+            model = OPTForCausalLM(OPTConfig())
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+        super().__init__(model, lora_rank, lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/trainer/__init__.py b/applications/ChatGPT/chatgpt/trainer/__init__.py
index c47c76347..525b57bf2 100644
--- a/applications/ChatGPT/chatgpt/trainer/__init__.py
+++ b/applications/ChatGPT/chatgpt/trainer/__init__.py
@@ -1,5 +1,6 @@
 from .base import Trainer
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
+from .sft import SFTTrainer
 
-__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']
+__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
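Note: BLOOMLM, GPTLM and OPTLM share the same fallback chain (pretrained name, explicit config, library default), so they can be smoke-tested from small configs without fetching weights; the config values below are illustrative only.

# Illustrative only: small configs, no downloads.
from transformers import BloomConfig, OPTConfig

from chatgpt.models.bloom import BLOOMLM
from chatgpt.models.opt import OPTLM

bloom = BLOOMLM(config=BloomConfig(hidden_size=64, n_layer=2, n_head=2, vocab_size=1000))
opt = OPTLM(config=OPTConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                             ffn_dim=128, word_embed_proj_dim=64,
                             max_position_embeddings=64, vocab_size=1000))

# Passing `pretrained` would call from_pretrained(); passing neither falls back to the
# library default config, which is much larger than these toy settings.
print(sum(p.numel() for p in bloom.parameters()),
      sum(p.numel() for p in opt.parameters()))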
diff --git a/applications/ChatGPT/chatgpt/trainer/sft.py b/applications/ChatGPT/chatgpt/trainer/sft.py
new file mode 100644
index 000000000..e3913d46b
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/sft.py
@@ -0,0 +1,101 @@
+from abc import ABC
+from typing import Optional
+import loralib as lora
+import torch
+from chatgpt.dataset import SFTDataset
+from chatgpt.models.loss import GPTLMLoss
+from torch.optim import Adam, Optimizer
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+import torch.distributed as dist
+from .strategies import Strategy
+from .utils import is_rank_0
+from colossalai.logging import get_dist_logger
+
+
+class SFTTrainer(ABC):
+    """
+        Trainer for supervised fine-tuning (SFT) of a language model.
+
+    Args:
+        model (torch.nn.Module): the model to train
+        strategy (Strategy): the strategy to use for training
+        optim (Optimizer): the optimizer to use for training
+        train_dataset (SFTDataset): the dataset to use for training
+        eval_dataset (SFTDataset): the dataset to use for evaluation
+        sampler (DistributedSampler, optional): the sampler for the training dataloader
+        batch_size (int, defaults to 1): the batch size while training
+        max_epochs (int, defaults to 2): the number of epochs to train
+    """
+
+    def __init__(
+        self,
+        model,
+        strategy: Strategy,
+        optim: Optimizer,
+        train_dataset: SFTDataset,
+        eval_dataset: SFTDataset,
+        sampler: Optional[DistributedSampler] = None,
+        batch_size: int = 1,
+        max_epochs: int = 2,
+    ) -> None:
+        super().__init__()
+        self.strategy = strategy
+        self.epochs = max_epochs
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.sampler = sampler
+
+        self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
+                                           sampler=sampler, batch_size=batch_size)
+        self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
+
+        self.model = strategy.setup_model(model)
+        if "DDP" in str(self.strategy):
+            self.model = self.model.module
+        self.loss_fn = GPTLMLoss()
+        self.optimizer = strategy.setup_optimizer(optim, self.model)
+
+    def fit(self, logger, use_lora, log_interval=10):
+        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
+        for epoch in range(self.epochs):
+            if isinstance(self.sampler, DistributedSampler):
+                self.sampler.set_epoch(epoch)
+            # train
+            self.model.train()
+            for batch_id, batch in enumerate(self.train_dataloader):
+                prompt_ids = batch["input_ids"]
+                p_mask = batch["attention_mask"]
+                prompt_ids = prompt_ids.squeeze(1).cuda()
+                p_mask = p_mask.squeeze(1).cuda()
+                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+
+                loss = self.loss_fn(prompt_logits, prompt_ids)
+                self.strategy.backward(loss, self.model, self.optimizer)
+                self.strategy.optimizer_step(self.optimizer)
+                self.optimizer.zero_grad()
+                if batch_id % log_interval == 0:
+                    logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
+
+            # eval
+            self.model.eval()
+            with torch.no_grad():
+                loss_sum = 0
+                num_seen = 0
+                for batch in self.eval_dataloader:
+                    prompt_ids = batch["input_ids"]
+                    p_mask = batch["attention_mask"]
+                    prompt_ids = prompt_ids.squeeze(1).cuda()
+                    p_mask = p_mask.squeeze(1).cuda()
+
+                    prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+                    loss = self.loss_fn(prompt_logits, prompt_ids)
+                    loss_sum += loss.item() * prompt_ids.size(0)
+                    num_seen += prompt_ids.size(0)
+
+                loss_mean = loss_sum / num_seen
+                if dist.get_rank() == 0:
+                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+            epoch_bar.update()
+
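Note: the trainer delegates the objective to GPTLMLoss, which lives in chatgpt.models.loss and is not part of this patch. As a reference point, a standard next-token cross-entropy of the kind it presumably implements looks like the sketch below; treat it as an assumption about the behaviour, not the actual source. It is also worth noting that the LM wrapper feeds log-softmaxed outputs into this loss, and CrossEntropyLoss applies its own log-softmax on top, which a reviewer may want to double-check.

# Assumed shape of the objective behind GPTLMLoss: shift-by-one causal-LM cross-entropy.
import torch
import torch.nn as nn


class CausalLMLossSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Predict token t+1 from positions <= t: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


logits = torch.randn(2, 16, 1000)              # [batch, seq_len, vocab]
labels = torch.randint(0, 1000, (2, 16))
print(CausalLMLossSketch()(logits, labels))    # scalar loss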
diff --git a/applications/ChatGPT/examples/train_sft.py b/applications/ChatGPT/examples/train_sft.py
new file mode 100644
index 000000000..4b3f85a2a
--- /dev/null
+++ b/applications/ChatGPT/examples/train_sft.py
@@ -0,0 +1,114 @@
+import argparse
+
+import loralib as lora
+import torch
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+from chatgpt.dataset import SFTDataset
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMLM
+from chatgpt.models.gpt import GPTLM
+from chatgpt.models.opt import OPTLM
+from chatgpt.trainer import SFTTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from datasets import load_dataset
+from torch.optim import Adam
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.logging import get_dist_logger
+
+
+def train(args):
+    # configure strategy
+    if args.strategy == 'naive':
+        strategy = NaiveStrategy()
+    elif args.strategy == 'ddp':
+        strategy = DDPStrategy()
+    elif args.strategy == 'colossalai_gemini':
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+    elif args.strategy == 'colossalai_zero2':
+        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    else:
+        raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+    # configure model
+    with strategy.model_init_context():
+        if args.model == 'bloom':
+            model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        elif args.model == 'opt':
+            model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        elif args.model == 'gpt2':
+            model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        else:
+            raise ValueError(f'Unsupported model "{args.model}"')
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'bloom':
+        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+    tokenizer.pad_token = tokenizer.eos_token
+
+    max_len = 512
+
+    # configure optimizer
+    if args.strategy.startswith('colossalai'):
+        optim = HybridAdam(model.parameters(), lr=5e-5)
+    else:
+        optim = Adam(model.parameters(), lr=5e-5)
+
+    logger = get_dist_logger()
+
+    train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
+    eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+
+    train_dataset = SFTDataset(train_data, tokenizer, max_len)
+    eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
+        logger.info("Using Distributed Sampler")
+    else:
+        sampler = None
+
+    trainer = SFTTrainer(model=model,
+                         strategy=strategy,
+                         optim=optim,
+                         train_dataset=train_dataset,
+                         eval_dataset=eval_dataset,
+                         sampler=sampler,
+                         batch_size=args.batch_size,
+                         max_epochs=args.max_epochs)
+
+    trainer.fit(logger=logger, use_lora=args.lora_rank, log_interval=args.log_interval)
+
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(model, args.save_path, only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(optim, 'sft_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--strategy',
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        default='naive')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
+    parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
+    parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
+    parser.add_argument('--max_epochs', type=int, default=1)
+    parser.add_argument('--batch_size', type=int, default=4)
+    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
+    args = parser.parse_args()
+    train(args)
+
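Note: one way the checkpoint written by strategy.save_model might be reloaded for inference. Exactly what the strategy writes (wrapper state dict vs. unwrapped Hugging Face weights) is decided outside this patch, so the snippet below is an assumption to verify rather than a documented API.

# Assumption: sft_ckpt.pth holds a state_dict compatible with the GPTLM wrapper.
import torch

from chatgpt.models.gpt import GPTLM

model = GPTLM(pretrained='gpt2')               # same architecture the run was started from
state_dict = torch.load('sft_ckpt.pth', map_location='cpu')
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing, unexpected)                     # inspect before trusting the load; keys may need a 'model.' prefix adjusted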
diff --git a/applications/ChatGPT/examples/train_sft.sh b/applications/ChatGPT/examples/train_sft.sh
new file mode 100755
index 000000000..9f747b246
--- /dev/null
+++ b/applications/ChatGPT/examples/train_sft.sh
@@ -0,0 +1,20 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+    local n=${1:-"9999"}
+    echo "GPU Memory Usage:"
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+        | tail -n +2 \
+        | nl -v 0 \
+        | tee /dev/tty \
+        | sort -g -k 2 \
+        | awk '{print $1}' \
+        | head -n $n)
+    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+    echo "Now CUDA_VISIBLE_DEVICES is set to:"
+    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
+
+#torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
+#torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
+torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10