mirror of https://github.com/hpcaitech/ColossalAI
* Update ppo.py: fix the bug of fetching the wrong batch data
* Add peft model support in SFT and Prompts training: in stage-1 and stage-3, peft model support is added, so the trained artifacts are only small LoRA additions instead of the whole set of model files
* Delete test_prompts.txt
* Delete test_pretrained.txt
* Move the peft code to a community folder
* Move the demo SFT to community
* Delete dirty files
* Add instructions to install peft from source
* Remove Chinese comments
Committed by YY Lin via GitHub, 2 years ago
6 changed files with 781 additions and 3 deletions
@ -0,0 +1,24 @@
# Add Peft support for SFT and Prompts model training

The original implementation just adopts loralib and merges the LoRA layers into the final model. The Hugging Face peft library is a better LoRA implementation and can easily be trained and distributed.

Since the reward model is relatively small, I just keep it as the original one. I suggest training the full model to get a proper reward/critic model.

# Preliminary installation

Since the current PyPI peft package (0.2) has some bugs, please install the peft package from source.

```
git clone https://github.com/huggingface/peft
cd peft
pip install .
```

# Usage

For SFT training, just call train_peft_sft.py.

Its arguments are almost identical to those of train_sft.py, except that it adds a new eval_dataset argument for an optional evaluation data file. The data file is just a plain text file; please check the expected format in easy_dataset.py.
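
A possible invocation is sketched below; the launcher and all model/data paths are placeholders, only the flag names come from train_peft_sft.py:

```
torchrun --standalone --nproc_per_node=1 train_peft_sft.py \
    --model bloom \
    --strategy colossalai_zero2 \
    --pretrain /path/to/bloom-base-model \
    --dataset /path/to/train_data.txt \
    --eval_dataset /path/to/eval_data.txt \
    --save_path ./sft_lora_output \
    --batch_size 4 \
    --max_epochs 3
```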

For stage-3 RLHF training, call train_peft_prompts.py.
Its arguments are almost identical to those of train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data files. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
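
A possible stage-3 invocation is sketched below; flag names are from train_peft_prompts.py, while the launcher and paths are placeholders (the --rm_path checkpoint is a reward model trained separately):

```
torchrun --standalone --nproc_per_node=1 train_peft_prompts.py \
    --model bloom \
    --strategy colossalai_zero2 \
    --pretrain /path/to/bloom-base-model \
    --sft_lora_path ./sft_lora_output \
    --rm_model bloom \
    --rm_pretrain /path/to/bloom-base-model \
    --rm_path /path/to/reward_model_checkpoint.pt \
    --prompt_path /path/to/test_prompts.txt \
    --pretrain_dataset /path/to/test_pretrained.txt \
    --save_path actor_checkpoint_prompts
```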

# Data format

Please refer to the formats in test_sft.txt, test_prompts.txt and test_pretrained.txt.
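
As a rough guide (inferred from the loaders in easy_dataset.py rather than from the sample files themselves): the SFT and prompt files hold one sample per line, and when a line contains the separator `回答:`, everything up to and including the separator is treated as the prompt and everything after it as the answer; the reward data loader (EasyRewardDataset) instead expects one JSON object per line with `prompt`, `chosen` and `rejected` fields.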
@ -0,0 +1,242 @@
import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources: Sequence[str],
               targets: Sequence[str],
               tokenizer: AutoTokenizer,
               max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # mask out the source (prompt) tokens so that the loss is only computed on the target tokens
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # split each line into source and target: the source is everything up to and including "回答:",
        # the target is everything after "回答:"
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # keep only the prompt part of each line (everything up to and including "回答:")
        all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line,
                      return_tensors='pt',
                      max_length=max_length,
                      padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # read all lines of train_file; each line is a JSON object with "prompt", "chosen" and "rejected" fields
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data['prompt'] + " 回答:"

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], \
               self.reject[idx]["input_ids"], self.reject[idx]["attention_mask"]

    # python representation of the object and the string representation of the object
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

'''
EasySFTDataset just accepts a text file which can be read line by line. However, the dataset will group texts together
up to max_length so that the LLM can learn the relations between texts better.
If individual lines are not related, just set is_group_texts to False.
'''


class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # read the data_file line by line
        with open(data_file, "r", encoding="UTF-8") as f:
            # encode the text line by line and keep the raw python lists of token ids in raw_input_ids
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # if encoded_ids is longer than max_length, split it into chunks of at most max_length
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # pad the current group to max_length with tokenizer.pad_token_id and emit it,
                    # then start the next group with the chunk that did not fit
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # no grouping: just pad each line to max_length on its own
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # get item from dataset
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # generate the dataset description to be printed by print in python
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    # generate the dataset description to be printed by print in python
    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
@ -0,0 +1,97 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import Module

from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits, masked_mean
from transformers import BloomConfig, BloomForCausalLM
from peft import PeftModel


class Actor(Module):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask, None
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

    def get_base_model(self):
        return self.model


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_path (str): Path to a saved peft LoRA adapter to load on top of the pretrained model.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_path: str = None) -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if lora_path is not None:
            model = PeftModel.from_pretrained(model, lora_path)
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)

    def print_trainable_parameters(self):
        self.get_base_model().print_trainable_parameters()
@ -0,0 +1,227 @@
import argparse

import pandas as pd
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from easy_models import BLOOMActor
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam
from peft import PeftModel
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset


def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
        state_dict = torch.load(args.rm_path, map_location='cpu')

    # configure model
    if args.model == 'bloom':
        # initial_model = BLOOMActor(pretrained=args.pretrain)
        print('Using peft lora to load Bloom model as initial_model')
        initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
        print('Using peft lora to load Bloom model as initial_model (Done)')
    else:
        raise ValueError(f'Unsupported actor model "{args.model}"')

    if args.rm_model is None:
        rm_model_name = args.model
    else:
        rm_model_name = args.rm_model

    if rm_model_name == 'gpt2':
        reward_model = GPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'bloom':
        print("load bloom reward model ", args.rm_pretrain)
        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'opt':
        reward_model = OPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == 'llama':
        reward_model = LlamaRM(pretrained=args.rm_pretrain)
    else:
        raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print('Loading reward model from', args.rm_path)
        reward_model.load_state_dict(state_dict)

    if args.strategy != 'colossalai_gemini':
        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

    with strategy.model_init_context():
        if args.model == 'bloom':
            # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
            print('Using peft lora to load Bloom model as Actor')
            actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
            print('Using peft lora to load Bloom model as Actor (Done)')
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
            print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
            print("load bloom critic (Done) ")
        elif rm_model_name == 'opt':
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'llama':
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print('Loading reward model from', args.rm_path)
        critic.load_state_dict(state_dict)
        del state_dict

    if args.strategy != 'colossalai_gemini':
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
        tokenizer.eos_token = '</s>'    # LLaMA uses '</s>' as its end-of-sequence token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
    else:
        tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(prompt_dataset,
                                   shuffle=(prompt_sampler is None),
                                   sampler=prompt_sampler,
                                   batch_size=args.train_batch_size)

    pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(pretrain_dataset,
                                     shuffle=(pretrain_sampler is None),
                                     sampler=pretrain_sampler,
                                     batch_size=args.ptx_batch_size,
                                     collate_fn=data_collator)

    def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs of all ranks have the same length;
        # different lengths may lead to a hang when using gemini, as the generation steps would differ
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        max_epochs=args.max_epochs,
        train_batch_size=args.train_batch_size,
        experience_batch_size=args.experience_batch_size,
        tokenizer=tokenize_fn,
        max_length=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    # save model checkpoint after fitting
    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive',
                        help='strategy to use')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--sft_lora_path', type=str, default=None)
    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--rm_path', type=str, default=None)
    parser.add_argument('--rm_pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--kl_coef', type=float, default=0.1)
    parser.add_argument('--ptx_coef', type=float, default=0.9)
    args = parser.parse_args()
    main(args)
@ -0,0 +1,187 @@
import argparse
import os

import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMLM
from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, AutoModelForCausalLM
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter

from torch.utils.data.dataloader import default_collate
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
# easy_dataset.py defines EasySFTDataset; alias it to the name used below
from easy_dataset import EasySFTDataset as EasyDataset


def train(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        print('Warning: currently only bloom is tested, gpt2, llama and opt are not tested')
        model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
        # if args.save_path already holds a saved peft adapter (adapter_config.json + adapter_model.bin), resume from it
        if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
                and os.path.exists(args.save_path + '/adapter_model.bin'):
            print("loading from saved peft model ", args.save_path)
            model = PeftModel.from_pretrained(model, args.save_path)
        else:
            # otherwise wrap the base model with a fresh peft LoRA adapter
            lora_rank = args.lora_rank if args.lora_rank > 0 else 32
            # configure lora with rank of lora_rank
            lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                     inference_mode=False,
                                     r=lora_rank,
                                     lora_alpha=32,
                                     lora_dropout=0.1)
            model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'llama':
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.eos_token = '</s>'    # LLaMA uses '</s>' as its end-of-sequence token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    tokenizer.pad_token = tokenizer.eos_token
    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)

        if args.strategy == 'colossalai_gemini':
            # this is a hack to deal with the resized embedding
            # to make sure all parameters are ColoParameter for Colossal-AI Gemini compatibility
            for name, param in model.named_parameters():
                if not isinstance(param, ColoParameter):
                    sub_module_name = '.'.join(name.split('.')[:-1])
                    weight_name = name.split('.')[-1]
                    sub_module = model.get_submodule(sub_module_name)
                    setattr(sub_module, weight_name, ColoParameter(param))
    else:
        tokenizer.pad_token = tokenizer.eos_token

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    logger = get_dist_logger()
    logger.set_level('WARNING')

    # configure dataset
    law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    train_dataset = law_dataset
    print(train_dataset)
    eval_dataset = None
    if args.eval_dataset is not None:
        eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    data_collator = default_collate
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(eval_dataset,
                                              shuffle=False,
                                              seed=42,
                                              drop_last=False,
                                              rank=dist.get_rank(),
                                              num_replicas=dist.get_world_size())
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=data_collator,
                                  pin_memory=True)
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset,
                                     shuffle=(eval_sampler is None),
                                     sampler=eval_sampler,
                                     batch_size=args.batch_size,
                                     collate_fn=data_collator,
                                     pin_memory=True)
    else:
        eval_dataloader = None

    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         train_dataloader=train_dataloader,
                         eval_dataloader=eval_dataloader,
                         batch_size=args.batch_size,
                         max_epochs=args.max_epochs,
                         accimulation_steps=args.accimulation_steps)

    trainer.fit(logger=logger, log_interval=args.log_interval)

    # save model checkpoint after fitting on only rank0
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--eval_dataset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='output')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accimulation_steps', type=int, default=8)
    parser.add_argument('--enable_peft_lora', action='store_true', default=False)
    parser.add_argument("--is_short_text", action='store_true', default=False)
    args = parser.parse_args()
    train(args)