mirror of https://github.com/hpcaitech/ColossalAI
add community example dictionary (#3465)
parent 80eba05b0a
commit 6afeb1202a

@@ -0,0 +1 @@
# Community Examples

@@ -1,19 +1,17 @@
import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(

@@ -36,15 +34,12 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :in
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):

@@ -53,21 +48,22 @@ def preprocess(


class EasySupervisedDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # split each line into source and target: source is everything up to and including "回答:",
        # target is everything after "回答:"
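        # hypothetical example of one raw line: "提问:什么是合同? 回答:合同是当事人之间的协议"
        # -> source = "提问:什么是合同? 回答:", target = "合同是当事人之间的协议" + tokenizer.eos_token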
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

@@ -85,21 +81,21 @@ class EasySupervisedDataset(Dataset):
    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

@@ -114,8 +110,9 @@ class EasyPromptsDataset(Dataset):


class EasyRewardDataset(Dataset):
    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:

@@ -124,11 +121,11 @@ class EasyRewardDataset(Dataset):
            self.end_token = special_token
        print(self.end_token)
        # read all lines in the train_file into a list
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data['prompt'] + " 回答:"

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
@@ -167,24 +164,27 @@ class EasyRewardDataset(Dataset):
    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


'''
Easy SFT just accepts a text file that can be read line by line. However, the dataset will group texts together up to max_length so the LLM learns the texts' meaning better.
If individual lines are not related, just set is_group_texts to False.
'''
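# Minimal usage sketch (file name and tokenizer are hypothetical, not part of this commit):
#
#     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
#     dataset = EasySFTDataset("train.txt", tokenizer, max_length=512, is_group_texts=True)
#     sample = dataset[0]    # dict with input_ids, labels and attention_mask, each padded to max_length
#
# Set is_group_texts=False when the lines are unrelated short texts that should not be concatenated.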
|
||||||
|
|
||||||
|
|
||||||
class EasySFTDataset(Dataset):
|
class EasySFTDataset(Dataset):
|
||||||
|
|
||||||
def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_group_texts = True) -> None:
|
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
#read the data_file line by line
|
#read the data_file line by line
|
||||||
with open(data_file,"r",encoding="UTF-8") as f:
|
with open(data_file, "r", encoding="UTF-8") as f:
|
||||||
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
|
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
|
||||||
raw_input_ids = []
|
raw_input_ids = []
|
||||||
for line in f:
|
for line in f:
|
||||||
encoded_ids = tokenizer.encode(line)
|
encoded_ids = tokenizer.encode(line)
|
||||||
#if the encoded_ids is longer than max_length, then split it into several parts
|
#if the encoded_ids is longer than max_length, then split it into several parts
|
||||||
if len(encoded_ids) > max_length:
|
if len(encoded_ids) > max_length:
|
||||||
for i in range(0,len(encoded_ids),max_length):
|
for i in range(0, len(encoded_ids), max_length):
|
||||||
raw_input_ids.append(encoded_ids[i:i+max_length])
|
raw_input_ids.append(encoded_ids[i:i + max_length])
|
||||||
else:
|
else:
|
||||||
raw_input_ids.append(encoded_ids)
|
raw_input_ids.append(encoded_ids)
|
||||||
|
|
||||||
|
@@ -199,23 +199,26 @@ class EasySFTDataset(Dataset):
                    # pad current_input_ids to max_length with tokenizer.pad_token_id
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    current_input_ids = []
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # no grouping: just pad each raw_input_ids entry to max_length
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_inpup_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file

@@ -225,8 +228,8 @@ class EasySFTDataset(Dataset):
        return len(self.input_ids)

    # get one item from the dataset
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # generate the dataset description printed by repr() in python
    def __repr__(self):

@@ -235,8 +238,3 @@ class EasySFTDataset(Dataset):
    # generate the dataset description printed by print() in python
    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

@@ -3,12 +3,12 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F

from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits, masked_mean
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM


class Actor(Module):
    """
@@ -87,11 +87,10 @@ class BLOOMActor(Actor):
        else:
            model = BloomForCausalLM(BloomConfig())
        if lora_path is not None:
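            # load the LoRA adapter weights saved at lora_path on top of the base BLOOM model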
            model = PeftModel.from_pretrained(model, lora_path)
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)

    def print_trainable_parameters(self):
        self.get_base_model().print_trainable_parameters()

@@ -5,21 +5,22 @@ import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure strategy

@@ -41,7 +42,7 @@ def main(args):
        if args.model == 'bloom':
            # initial_model = BLOOMActor(pretrained=args.pretrain)
            print('Using peft lora to load Bloom model as initial_model')
            initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
            print('Using peft lora to load Bloom model as initial_model (Done)')
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

@@ -54,7 +55,7 @@ def main(args):
        if rm_model_name == 'gpt2':
            reward_model = GPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'bloom':
            print("load bloom reward model ", args.rm_pretrain)
            reward_model = BLOOMRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'opt':
            reward_model = OPTRM(pretrained=args.rm_pretrain)

@@ -75,7 +76,7 @@ def main(args):
        if args.model == 'bloom':
            # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
            print('Using peft lora to load Bloom model as Actor')
            actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
            print('Using peft lora to load Bloom model as Actor (Done)')
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

@@ -83,7 +84,7 @@ def main(args):
        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
            print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
            print("load bloom critic (Done) ")
        elif rm_model_name == 'opt':

@@ -130,7 +131,7 @@ def main(args):

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:

@@ -14,19 +14,19 @@ from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter


def train(args):
    # configure strategy

@@ -48,17 +48,20 @@ def train(args):
    # if args.save_path exists and already holds a saved peft adapter (adapter_config.json and adapter_model.bin), load it
    if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \
            and os.path.exists(args.save_path+'/adapter_model.bin'):
        print("loading from saved peft model ", args.save_path)
        model = PeftModel.from_pretrained(model, args.save_path)
    else:
        # use the peft lora library to apply LoRA to the model
        lora_rank = args.lora_rank if args.lora_rank > 0 else 32
        # configure lora with rank lora_rank
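        # (r is the LoRA rank, lora_alpha the scaling factor, lora_dropout the dropout applied to the LoRA layers)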
        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                 inference_mode=False,
                                 r=lora_rank,
                                 lora_alpha=32,
                                 lora_dropout=0.1)
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

@@ -103,12 +106,12 @@ def train(args):
        logger.set_level('WARNING')

    # configure dataset
    law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    train_dataset = law_dataset
    print(train_dataset)
    eval_dataset = None
    if args.eval_dataset is not None:
        eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    data_collator = default_collate
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,

@@ -181,7 +184,7 @@ if __name__ == '__main__':
    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accimulation_steps', type=int, default=8)
    parser.add_argument('--enable_peft_lora', action='store_true', default=False)
    parser.add_argument("--is_short_text", action='store_true', default=False)
    args = parser.parse_args()
    train(args)