add community example dictionary (#3465)

pull/3343/head
Fazzie-Maqianli 2023-04-06 15:04:48 +08:00 committed by GitHub
parent 80eba05b0a
commit 6afeb1202a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 94 additions and 92 deletions

View File

@ -0,0 +1 @@
# Community Examples

View File

@ -10,7 +10,7 @@ Since the current pypi peft package(0.2) has some bugs, please install the peft
git clone https://github.com/huggingface/peft
cd peft
pip install .
```
```
# Usage
For SFT training, just call train_peft_sft.py
@ -21,4 +21,4 @@ For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
# Dataformat
Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt.
Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt.

View File

@ -1,19 +1,17 @@
import copy
import json
from typing import Dict, Sequence
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
import torch
from tqdm import tqdm
import json
from tqdm import tqdm
import json
IGNORE_INDEX = -100
def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :int = 512) -> Dict:
def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
@ -36,15 +34,12 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :in
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: AutoTokenizer,
max_length :int = 512
) -> Dict:
def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer,max_length) for strings in (examples, sources)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
@ -53,59 +48,60 @@ def preprocess(
class EasySupervisedDataset(Dataset):
def __init__(self, data_file :str, tokenizer :AutoTokenizer,max_length :int = 512) -> None:
super(EasySupervisedDataset,self).__init__()
with open(data_file,"r",encoding="UTF-8") as f:
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
#split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
sources,targets = [],[]
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
sources.append(line[:sep_index+3])
targets.append(line[sep_index+3:]+tokenizer.eos_token)
sources.append(line[:sep_index + 3])
targets.append(line[sep_index + 3:] + tokenizer.eos_token)
else:
sources.append(line)
targets.append(""+tokenizer.eos_token)
data_dict = preprocess(sources, targets, tokenizer,max_length)
targets.append("" + tokenizer.eos_token)
data_dict = preprocess(sources, targets, tokenizer, max_length)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.data_file = data_file
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __repr__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"
def __str__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"
class EasyPromptsDataset(Dataset):
def __init__(self,data_file :str, tokenizer :AutoTokenizer, max_length :int = 96) -> None:
super(EasyPromptsDataset,self).__init__()
with open(data_file,"r",encoding="UTF-8") as f:
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[:line.index("回答:")+3] for line in all_lines]
all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
self.prompts = [
tokenizer(line,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return self.prompts[idx]
def __repr__(self):
return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"
@ -114,8 +110,9 @@ class EasyPromptsDataset(Dataset):
class EasyRewardDataset(Dataset):
def __init__(self,train_file :str,tokenizer :AutoTokenizer, special_token = None,max_length = 512) -> None:
super(EasyRewardDataset,self).__init__()
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
self.reject = []
if special_token is None:
@ -124,11 +121,11 @@ class EasyRewardDataset(Dataset):
self.end_token = special_token
print(self.end_token)
#read all lines in the train_file to a list
with open(train_file,"r",encoding="UTF-8") as f:
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
prompt = "提问:"+data['prompt']+" 回答:"
prompt = "提问:" + data['prompt'] + " 回答:"
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
@ -159,7 +156,7 @@ class EasyRewardDataset(Dataset):
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
#python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@ -167,27 +164,30 @@ class EasyRewardDataset(Dataset):
def __str__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
'''
Easy SFT just accept a text file which can be read line by line. However the datasest will group texts together to max_length so LLM will learn the texts meaning better.
If individual lines are not related, just set is_group_texts to False.
'''
class EasySFTDataset(Dataset):
def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_group_texts = True) -> None:
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
#read the data_file line by line
with open(data_file,"r",encoding="UTF-8") as f:
with open(data_file, "r", encoding="UTF-8") as f:
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
#if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0,len(encoded_ids),max_length):
raw_input_ids.append(encoded_ids[i:i+max_length])
for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i:i + max_length])
else:
raw_input_ids.append(encoded_ids)
grouped_inpup_ids = []
current_input_ids = []
attention_mask = []
@ -199,44 +199,42 @@ class EasySFTDataset(Dataset):
#pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long))
attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
current_input_ids = []
else:
current_input_ids.extend(input_ids)
if len(current_input_ids) > 0:
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long))
attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
else:
#just append the raw_input_ids to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(input_ids,dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_inpup_ids
self.labels = copy.deepcopy(self.input_ids)
self.file_name = data_file
self.attention_mask = attention_mask
def __len__(self):
return len(self.input_ids)
#get item from dataset
def __getitem__(self,idx):
return dict(input_ids=self.input_ids[idx],labels=self.labels[idx],attention_mask=self.attention_mask[idx])
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
#generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
#generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

View File

@ -3,12 +3,12 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import Module
from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits,masked_mean
from transformers import BloomConfig,BloomForCausalLM
from coati.models.utils import log_probs_from_logits, masked_mean
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
class Actor(Module):
"""
@ -87,11 +87,10 @@ class BLOOMActor(Actor):
else:
model = BloomForCausalLM(BloomConfig())
if lora_path is not None:
model = PeftModel.from_pretrained(model,lora_path)
model = PeftModel.from_pretrained(model, lora_path)
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model)
def print_trainable_parameters(self):
self.get_base_model().print_trainable_parameters()

View File

@ -5,21 +5,22 @@ import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from easy_models import BLOOMActor
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from colossalai.nn.optimizer import HybridAdam
from peft import PeftModel
from easy_dataset import EasyPromptsDataset,EasySupervisedDataset
def main(args):
# configure strategy
@ -41,7 +42,7 @@ def main(args):
if args.model == 'bloom':
# initial_model = BLOOMActor(pretrained=args.pretrain)
print('Using peft lora to load Bloom model as inital_model')
initial_model = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path)
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as initial_model (Done)')
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@ -54,7 +55,7 @@ def main(args):
if rm_model_name == 'gpt2':
reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
print("load bloom reward model ",args.rm_pretrain)
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
reward_model = OPTRM(pretrained=args.rm_pretrain)
@ -75,7 +76,7 @@ def main(args):
if args.model == 'bloom':
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print('Using peft lora to load Bloom model as Actor')
actor = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path)
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as Actor (Done)')
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@ -83,7 +84,7 @@ def main(args):
if rm_model_name == 'gpt2':
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
print("load bloom critic ",args.rm_pretrain," lora_rank ",args.lora_rank," use_action_mask ",True)
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
elif rm_model_name == 'opt':
@ -130,7 +131,7 @@ def main(args):
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
prompt_dataset = EasyPromptsDataset(args.prompt_path,tokenizer)
prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:

View File

@ -14,19 +14,19 @@ from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast,AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter
from torch.utils.data.dataloader import default_collate
from peft import LoraConfig, TaskType,get_peft_model,PeftModel
from easy_dataset import EasyDataset
def train(args):
# configure strategy
@ -48,17 +48,20 @@ def train(args):
#if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \
and os.path.exists(args.save_path+'/adapter_model.bin'):
print("loading from saved peft model ",args.save_path)
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
#we'll use peft lora library to do the lora
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
#config lora with rank of lora_rank
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1)
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=lora_rank,
lora_alpha=32,
lora_dropout=0.1)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
@ -103,12 +106,12 @@ def train(args):
logger.set_level('WARNING')
# configure dataset
law_dataset = EasyDataset(args.dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text)
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
train_dataset = law_dataset
print(train_dataset)
eval_dataset = None
if args.eval_dataset is not None:
eval_dataset = EasyDataset(args.eval_dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text)
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,
@ -181,7 +184,7 @@ if __name__ == '__main__':
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accimulation_steps', type=int, default=8)
parser.add_argument('--enable_peft_lora',action='store_true', default=False)
parser.add_argument("--is_short_text",action='store_true', default=False)
parser.add_argument('--enable_peft_lora', action='store_true', default=False)
parser.add_argument("--is_short_text", action='store_true', default=False)
args = parser.parse_args()
train(args)