[chatgpt] add supervised fine-tuning code (#3183)

* [chatgpt] add supervised fine-tuning code

* [chatgpt] delete unused code and revise comments

* [chatgpt] use pytorch distributed sampler instead

---------

Co-authored-by: zhangpengpeng <zhangpengpeng@joyy.com>
pgzhang 2 years ago committed by GitHub
parent e7f3bed2d3
commit b429529365

chatgpt/dataset/__init__.py
@@ -1,4 +1,5 @@
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
from .sft_dataset import SFTDataset
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']

chatgpt/dataset/sft_dataset.py
@@ -0,0 +1,40 @@
from typing import Callable
from torch.utils.data import Dataset
from tqdm import tqdm
from .utils import is_rank_0
class SFTDataset(Dataset):
"""
Dataset for sft model
Args:
dataset: dataset for supervised model
tokenizer: tokenizer for supervised model
max_length: max length of input
"""
    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
        super().__init__()
        self.prompts = []
        for data in tqdm(dataset, disable=not is_rank_0()):
            # concatenate prompt and completion into one training sequence,
            # terminated by the end-of-text token
            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
            prompt_token = tokenizer(prompt,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.prompts.append(prompt_token)
def __len__(self):
length = len(self.prompts)
return length
def __getitem__(self, idx):
return self.prompts[idx]
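
For reference, a minimal usage sketch of the new dataset (the toy record and the 'gpt2' tokenizer below are illustrative, not part of this PR):

from transformers import GPT2Tokenizer
from chatgpt.dataset import SFTDataset

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token    # GPT-2 has no pad token by default

toy_data = [{'prompt': 'Translate to French: cat ->', 'completion': ' chat'}]
dataset = SFTDataset(toy_data, tokenizer, max_length=512)

item = dataset[0]
# each item is the raw tokenizer output: input_ids / attention_mask of shape
# (1, max_length), which is why the trainer squeezes dim 1 before the forward pass
print(item['input_ids'].shape)    # torch.Size([1, 512])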

chatgpt/models/base/__init__.py
@@ -1,5 +1,6 @@
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
from .lm import LM
__all__ = ['Actor', 'Critic', 'RewardModel']
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']

chatgpt/models/base/lm.py
@@ -0,0 +1,33 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .actor import Actor
class LM(Actor):
"""
Language model base class.
Args:
model (nn.Module): Language Model.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)
def forward(self,
sequences: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns output log probs
"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
log_probs = F.log_softmax(logits, dim=-1)
return log_probs
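
Since forward returns log probabilities over the full vocabulary at every position, a caller can gather per-token log probs from them. A minimal sketch (the GPT-2 model and tokenizer here are illustrative only):

from transformers import GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from chatgpt.models.base import LM

lm = LM(GPT2LMHeadModel.from_pretrained('gpt2'))
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

ids = tokenizer('Hello world', return_tensors='pt')['input_ids']
log_probs = lm(ids)    # shape (batch, seq_len, vocab_size)
# log prob assigned at position t to the token actually observed at position t + 1
next_token_log_probs = log_probs[:, :-1].gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1)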

chatgpt/models/bloom/__init__.py
@@ -1,5 +1,6 @@
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
from .bloom_lm import BLOOMLM
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']

chatgpt/models/bloom/bloom_lm.py
@@ -0,0 +1,36 @@
from typing import Optional
from transformers import BloomConfig, BloomForCausalLM
from ..base import LM
class BLOOMLM(LM):
"""
BLOOM language model.
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
                 pretrained: Optional[str] = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = BloomForCausalLM(config)
else:
model = BloomForCausalLM(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

chatgpt/models/gpt/__init__.py
@@ -1,5 +1,6 @@
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
from .gpt_lm import GPTLM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']

chatgpt/models/gpt/gpt_lm.py
@@ -0,0 +1,36 @@
from typing import Optional
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from ..base import LM
class GPTLM(LM):
"""
GPT language model.
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
model = GPT2LMHeadModel(config)
else:
model = GPT2LMHeadModel(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
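
The three construction paths mirror each other across BLOOMLM, GPTLM and OPTLM; an illustrative sketch for GPTLM (model names and sizes are examples only):

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from chatgpt.models.gpt import GPTLM

lm_pretrained = GPTLM(pretrained='gpt2')               # load weights by name or local path
lm_from_config = GPTLM(config=GPT2Config(n_layer=2))   # randomly initialized, custom size
lm_default = GPTLM()                                   # falls back to the default GPT2Config
lm_lora = GPTLM(pretrained='gpt2', lora_rank=8)        # non-zero rank is forwarded to the Actor base class (LoRA handling lives there)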

chatgpt/models/opt/__init__.py
@@ -1,5 +1,6 @@
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .opt_lm import OPTLM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']

chatgpt/models/opt/opt_lm.py
@@ -0,0 +1,36 @@
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import LM
class OPTLM(LM):
"""
OPT language model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

chatgpt/trainer/__init__.py
@@ -1,5 +1,6 @@
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']

chatgpt/trainer/sft.py
@@ -0,0 +1,101 @@
from abc import ABC
from typing import Optional
import torch
import torch.distributed as dist
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from .strategies import Strategy
from .utils import is_rank_0
class SFTTrainer(ABC):
    """
    Trainer for supervised fine-tuning.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataset (SFTDataset): the dataset to use for training
        eval_dataset (SFTDataset): the dataset to use for evaluation
        sampler (DistributedSampler, optional): sampler for the training dataloader
        batch_size (int, defaults to 1): the batch size while training
        max_epochs (int, defaults to 2): the number of epochs to train
    """
def __init__(
self,
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: SFTDataset,
eval_dataset: SFTDataset,
sampler: Optional[DistributedSampler] = None,
batch_size: int = 1,
max_epochs: int = 2,
) -> None:
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.sampler = sampler
self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
sampler=sampler, batch_size=batch_size)
self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.loss_fn = GPTLMLoss()
self.optimizer = strategy.setup_optimizer(optim, self.model)
    def fit(self, logger, use_lora, log_interval=10):
        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
        for epoch in range(self.epochs):
            if isinstance(self.sampler, DistributedSampler):
                # reshuffle the per-rank data split every epoch
                self.sampler.set_epoch(epoch)
            # train
            self.model.train()
            for batch_id, batch in enumerate(self.train_dataloader):
                prompt_ids = batch["input_ids"]
                p_mask = batch["attention_mask"]
                # dataset items were tokenized with return_tensors="pt", so the batched
                # tensors have shape (batch, 1, max_length); drop the extra dimension
                prompt_ids = prompt_ids.squeeze(1).cuda()
                p_mask = p_mask.squeeze(1).cuda()
                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                loss = self.loss_fn(prompt_logits, prompt_ids)
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                if batch_id % log_interval == 0:
                    logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
            # eval
            self.model.eval()
            with torch.no_grad():
                loss_sum = 0
                num_seen = 0
                for batch in self.eval_dataloader:
                    prompt_ids = batch["input_ids"]
                    p_mask = batch["attention_mask"]
                    prompt_ids = prompt_ids.squeeze(1).cuda()
                    p_mask = p_mask.squeeze(1).cuda()
                    prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                    loss = self.loss_fn(prompt_logits, prompt_ids)
                    # loss.item() is already averaged within the batch, so weight it by
                    # the batch size to get a true per-sample mean over the eval set
                    loss_sum += loss.item() * prompt_ids.size(0)
                    num_seen += prompt_ids.size(0)
                loss_mean = loss_sum / num_seen
                if dist.get_rank() == 0:
                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
            epoch_bar.update()
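
The trainer feeds prompt_ids as both the model input and the loss target. A typical causal-LM objective shifts logits and labels by one position; the sketch below illustrates that convention and is only an assumption about what chatgpt.models.loss.GPTLMLoss computes:

import torch.nn as nn

class CausalLMLossSketch(nn.Module):
    """Next-token prediction loss (illustrative stand-in for GPTLMLoss)."""

    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()    # predictions for positions 0..n-2
        shift_labels = labels[..., 1:].contiguous()        # targets are the next tokens
        # a real implementation may additionally ignore padding positions
        return self.ce(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))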

train_sft.py
@@ -0,0 +1,114 @@
import argparse
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from chatgpt.dataset import SFTDataset
from chatgpt.models.bloom import BLOOMLM
from chatgpt.models.gpt import GPTLM
from chatgpt.models.opt import OPTLM
from chatgpt.trainer import SFTTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import get_dist_logger
def train(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
if args.model == 'bloom':
model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'opt':
model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'gpt2':
model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    # pad with the EOS token so padding reuses an existing special token
    tokenizer.pad_token = tokenizer.eos_token
    max_len = 512
# configure optimizer
if args.strategy.startswith('colossalai'):
optim = HybridAdam(model.parameters(), lr=5e-5)
else:
optim = Adam(model.parameters(), lr=5e-5)
logger = get_dist_logger()
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
train_dataset = SFTDataset(train_data, tokenizer, max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
if dist.is_initialized() and dist.get_world_size() > 1:
sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
logger.info("Using Distributed Sampler")
else:
sampler = None
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
sampler=sampler,
batch_size=args.batch_size,
max_epochs=args.max_epochs)
trainer.fit(logger=logger, use_lora=args.lora_rank, log_interval=args.log_interval)
    # save the model checkpoint after fitting, on rank 0 only
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save the optimizer checkpoint on all ranks
    strategy.save_optimizer(optim, 'sft_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
args = parser.parse_args()
train(args)
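
After training, the saved weights can be reloaded for inference. A hypothetical loading sketch, assuming the checkpoint written by strategy.save_model holds a plain state_dict of the unwrapped Hugging Face model (the exact format depends on the strategy implementation, which this PR does not touch):

import torch
from chatgpt.models.gpt import GPTLM

lm = GPTLM(pretrained='gpt2')
state_dict = torch.load('sft_ckpt.pth', map_location='cpu')
# assumption: keys match the inner Hugging Face model; if the wrapper itself was
# saved, load into `lm` instead of `lm.model`
lm.model.load_state_dict(state_dict)
lm.eval()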

train_sft.sh
@@ -0,0 +1,20 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
#torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
#torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10