mirror of https://github.com/hpcaitech/ColossalAI
[Chat]Add Peft support & fix the ptx bug (#3433)
* Update ppo.py: fix the bug of fetching the wrong batch data (each call to next(iter(self.pretrain_dataloader)) created a fresh iterator, so input_ids, labels and attention_mask could come from different batches).
* Add peft model support in SFT and Prompts training: in stage-1 and stage-3, peft model support is added, so the trained artifact is only a small set of LoRA weights instead of the whole model.
* Delete test_prompts.txt
* Delete test_pretrained.txt
* Move the peft code to a community folder.
* Move the demo SFT to community.
* Delete dirty files.
* Add instructions to install peft from source.
* Remove Chinese comments.
parent
73afb63594
commit
62f4e2eb07
@@ -92,9 +92,10 @@ class PPOTrainer(Trainer):
         # ptx loss
         if self.ptx_coef != 0:
-            ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
-            label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
-            attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
+            batch = next(iter(self.pretrain_dataloader))
+            ptx = batch['input_ids'].to(torch.cuda.current_device())
+            label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
+            attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
             ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
             ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
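
The fix above replaces three separate next(iter(self.pretrain_dataloader)) calls with a single fetch. Each next(iter(...)) call builds a brand-new iterator, so with a shuffling sampler the input_ids, labels and attention_mask could come from three different batches. A minimal, self-contained illustration with toy data (not the trainer code):

```
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Buggy pattern: each next(iter(...)) builds a brand-new iterator, so the two
# "first batches" below frequently differ.
a = next(iter(loader))[0]
b = next(iter(loader))[0]
print(a, b)

# Fixed pattern (what the patched ppo.py does): fetch the batch once,
# then read all of its fields from that single object.
batch = next(iter(loader))
print(batch[0])
```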

@@ -0,0 +1,24 @@
# Add PEFT support for SFT and Prompts model training

The original implementation adopts loralib and merges the LoRA layers into the final model. The Hugging Face peft library is a better LoRA implementation: it is easier to train and to distribute, and the trained artifact is only a small set of LoRA weights instead of the whole model.
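
As a rough sketch of the pattern this folder relies on (placeholders only: the checkpoint name and the output directory are illustrative, and the real hyperparameters live in train_peft_sft.py), a peft LoRA wrapper looks like:

```
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                         inference_mode=False,
                         r=32,
                         lora_alpha=32,
                         lora_dropout=0.1)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()    # only the LoRA matrices are trainable

# After fine-tuning, this writes just adapter_config.json + adapter_model.bin,
# which train_peft_prompts.py can later load via PeftModel.from_pretrained.
model.save_pretrained("output_dir")   # placeholder output directory
```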

Since the reward model is relatively small, I keep it as the original full model. I suggest training the full model to get a proper reward/critic model.

# Preliminary installation

Since the current PyPI peft package (0.2) has some bugs, please install peft from source:
```
git clone https://github.com/huggingface/peft
cd peft
pip install .
```

# Usage

For SFT training, just call train_peft_sft.py.

Its arguments are almost identical to those of train_sft.py, except that it adds a new eval_dataset argument for an optional evaluation data file. The data file is a plain text file; please check the expected format in easy_dataset.py.

For stage-3 RLHF training, call train_peft_prompts.py.

Its arguments are almost identical to those of train_prompts.py. The only difference is that plain text files are used for the prompt and pretrain data. The models are defined in easy_models.py. Currently only BLOOM models are tested, but technically gpt2/opt/llama should also be supported.

# Data format

Please refer to the formats in test_sft.txt, test_prompts.txt and test_pretrained.txt.
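
As a rough illustration (the example line below is made up, not taken from those files): each line may contain the separator "回答:" ("Answer:"), and easy_dataset.py splits it into a source part (everything up to and including "回答:") and a target part (the rest plus the EOS token):

```
# Made-up example line, mirroring the split logic in easy_dataset.py.
line = "Explain LoRA in one sentence. 回答:It adds small low-rank update matrices to a frozen model."

eos_token = "</s>"    # placeholder; the real value comes from the tokenizer
if "回答:" in line:
    sep_index = line.index("回答:")
    source = line[:sep_index + 3]             # keep "回答:" in the source
    target = line[sep_index + 3:] + eos_token
else:                                         # lines without the separator are all source
    source, target = line, "" + eos_token
```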

@@ -0,0 +1,242 @@
import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: AutoTokenizer,
    max_length: int = 512
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including "回答:" ("Answer:"), the target is everything after it.
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
            # Keep only the prompt part of each line (up to and including "回答:").
            all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line,
                      return_tensors='pt',
                      max_length=max_length,
                      padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # Read all lines of train_file; each line is a JSON object with
        # "prompt", "chosen" and "rejected" fields.
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
            for line in tqdm(all_lines):
                data = json.loads(line)
                prompt = "提问:" + data['prompt'] + " 回答:"

                chosen = prompt + data['chosen'] + self.end_token
                chosen_token = tokenizer(chosen,
                                         max_length=max_length,
                                         padding="max_length",
                                         truncation=True,
                                         return_tensors="pt")
                self.chosen.append({
                    "input_ids": chosen_token['input_ids'],
                    "attention_mask": chosen_token['attention_mask']
                })

                reject = prompt + data['rejected'] + self.end_token
                reject_token = tokenizer(reject,
                                         max_length=max_length,
                                         padding="max_length",
                                         truncation=True,
                                         return_tensors="pt")
                self.reject.append({
                    "input_ids": reject_token['input_ids'],
                    "attention_mask": reject_token['attention_mask']
                })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

    # Python representation and string representation of the object.
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


'''
Easy SFT just accepts a text file that can be read line by line. However, the dataset
groups lines together up to max_length so the LLM can learn the texts' meaning better.
If the individual lines are not related, just set is_group_texts to False.
'''


class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # Read data_file line by line, encode each line and collect the raw
        # python lists of input ids.
        with open(data_file, "r", encoding="UTF-8") as f:
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # If encoded_ids is longer than max_length, split it into several chunks.
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # Pad current_input_ids to max_length with tokenizer.pad_token_id.
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    current_input_ids = []
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # No grouping: just pad each raw_input_ids entry to max_length.
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # Get one padded example from the dataset.
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # Dataset description used by print().
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

@@ -0,0 +1,97 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import Module

from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits, masked_mean
from transformers import BloomConfig, BloomForCausalLM
from peft import PeftModel


class Actor(Module):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask, None
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs"""
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

    def get_base_model(self):
        return self.model


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_path (str): Path to a saved peft (LoRA) adapter to load onto the base model.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_path: str = None) -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if lora_path is not None:
            model = PeftModel.from_pretrained(model, lora_path)
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)

    def print_trainable_parameters(self):
        # Only meaningful when a peft adapter was loaded (the base model is then a PeftModel).
        self.get_base_model().print_trainable_parameters()
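

# Illustrative usage (editor's sketch, not part of the original file); the
# checkpoint name and adapter path are placeholders:
#
#   actor = BLOOMActor(pretrained="bigscience/bloom-560m", lora_path="sft_lora_output")
#   actor.print_trainable_parameters()   # only valid when lora_path points to a peft adapter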

@@ -0,0 +1,227 @@
import argparse

import pandas as pd
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from easy_models import BLOOMActor
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam
from peft import PeftModel
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset


def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
        state_dict = torch.load(args.rm_path, map_location='cpu')

    # configure model
    if args.model == 'bloom':
        # initial_model = BLOOMActor(pretrained=args.pretrain)
        print('Using peft lora to load Bloom model as initial_model')
        initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
        print('Using peft lora to load Bloom model as initial_model (Done)')
    else:
        raise ValueError(f'Unsupported actor model "{args.model}"')

    if args.rm_model is None:
        rm_model_name = args.model
    else:
        rm_model_name = args.rm_model

    if rm_model_name == 'gpt2':
        reward_model = GPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'bloom':
        print("load bloom reward model ", args.rm_pretrain)
        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'opt':
        reward_model = OPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'llama':
        reward_model = LlamaRM(pretrained=args.rm_pretrain)
    else:
        raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print('Loading reward model from', args.rm_path)
        reward_model.load_state_dict(state_dict)

    if args.strategy != 'colossalai_gemini':
        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

    with strategy.model_init_context():
        if args.model == 'bloom':
            # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
            print('Using peft lora to load Bloom model as Actor')
            actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
            print('Using peft lora to load Bloom model as Actor (Done)')
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
            print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
            print("load bloom critic (Done) ")
        elif rm_model_name == 'opt':
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'llama':
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print('Loading reward model from', args.rm_path)
        critic.load_state_dict(state_dict)
        del state_dict

    if args.strategy != 'colossalai_gemini':
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
        tokenizer.eos_token = '</s>'
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
    else:
        tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(prompt_dataset,
                                   shuffle=(prompt_sampler is None),
                                   sampler=prompt_sampler,
                                   batch_size=args.train_batch_size)

    pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(pretrain_dataset,
                                     shuffle=(pretrain_sampler is None),
                                     sampler=pretrain_sampler,
                                     batch_size=args.ptx_batch_size,
                                     collate_fn=data_collator)

    def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs on all ranks have the same length.
        # Different lengths may lead to a hang when using gemini, as the generation steps would differ.
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        max_epochs=args.max_epochs,
        train_batch_size=args.train_batch_size,
        experience_batch_size=args.experience_batch_size,
        tokenizer=tokenize_fn,
        max_length=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    # save model checkpoint after fitting
    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive',
                        help='strategy to use')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--sft_lora_path', type=str, default=None)
    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--rm_path', type=str, default=None)
    parser.add_argument('--rm_pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--kl_coef', type=float, default=0.1)
    parser.add_argument('--ptx_coef', type=float, default=0.9)
    args = parser.parse_args()
    main(args)

@@ -0,0 +1,187 @@
import argparse
import os

import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMLM
from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, AutoModelForCausalLM
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter

from torch.utils.data.dataloader import default_collate
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from easy_dataset import EasySFTDataset


def train(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        print('Warning: currently only bloom is tested, gpt2, llama and opt are not tested')
        model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
        # If args.save_path already contains a saved peft adapter
        # (adapter_config.json + adapter_model.bin), resume from it.
        if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
                and os.path.exists(args.save_path + '/adapter_model.bin'):
            print("loading from saved peft model ", args.save_path)
            model = PeftModel.from_pretrained(model, args.save_path)
        else:
            # otherwise wrap the base model with a fresh LoRA adapter via peft
            lora_rank = args.lora_rank if args.lora_rank > 0 else 32
            # configure LoRA with rank lora_rank
            lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                     inference_mode=False,
                                     r=lora_rank,
                                     lora_alpha=32,
                                     lora_dropout=0.1)
            model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'llama':
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.eos_token = '</s>'
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    tokenizer.pad_token = tokenizer.eos_token

    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)

        if args.strategy == 'colossalai_gemini':
            # this is a hack to deal with the resized embedding,
            # to make sure all parameters are ColoParameter for Colossal-AI Gemini compatibility
            for name, param in model.named_parameters():
                if not isinstance(param, ColoParameter):
                    sub_module_name = '.'.join(name.split('.')[:-1])
                    weight_name = name.split('.')[-1]
                    sub_module = model.get_submodule(sub_module_name)
                    setattr(sub_module, weight_name, ColoParameter(param))
    else:
        tokenizer.pad_token = tokenizer.eos_token

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    logger = get_dist_logger()
    logger.set_level('WARNING')

    # configure dataset
    train_dataset = EasySFTDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    print(train_dataset)
    eval_dataset = None
    if args.eval_dataset is not None:
        eval_dataset = EasySFTDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    data_collator = default_collate
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(eval_dataset,
                                              shuffle=False,
                                              seed=42,
                                              drop_last=False,
                                              rank=dist.get_rank(),
                                              num_replicas=dist.get_world_size())
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=data_collator,
                                  pin_memory=True)
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset,
                                     shuffle=(eval_sampler is None),
                                     sampler=eval_sampler,
                                     batch_size=args.batch_size,
                                     collate_fn=data_collator,
                                     pin_memory=True)
    else:
        eval_dataloader = None

    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         train_dataloader=train_dataloader,
                         eval_dataloader=eval_dataloader,
                         batch_size=args.batch_size,
                         max_epochs=args.max_epochs,
                         accimulation_steps=args.accimulation_steps)

    trainer.fit(logger=logger, log_interval=args.log_interval)

    # save model checkpoint after fitting, on rank0 only
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--eval_dataset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='output')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accimulation_steps', type=int, default=8)
    parser.add_argument('--enable_peft_lora', action='store_true', default=False)
    parser.add_argument("--is_short_text", action='store_true', default=False)
    args = parser.parse_args()
    train(args)