mirror of https://github.com/hpcaitech/ColossalAI
add community example dictionary (#3465)
parent 80eba05b0a
commit 6afeb1202a

@@ -0,0 +1 @@
# Community Examples

@@ -10,7 +10,7 @@ Since the current pypi peft package(0.2) has some bugs, please install the peft
git clone https://github.com/huggingface/peft
cd peft
pip install .
```
```

# Usage
For SFT training, just run train_peft_sft.py.

@@ -21,4 +21,4 @@ For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost identical to those of train_prompts.py. The only difference is that text files are used to indicate the prompt and pretrained data files. The models are included in easy_models.py. Currently only BLOOM models have been tested, but gpt2/opt/llama should technically be supported as well.

# Dataformat
Please refer to the formats in test_sft.txt, test_prompts.txt and test_pretrained.txt.
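
The sample files themselves are not reproduced in this diff, but the dataset code below shows how each line is interpreted: everything up to and including the "回答:" marker becomes the prompt/source, and everything after it becomes the answer/target. A minimal sketch of that split (the question/answer text here is invented for illustration):

```
# Illustrative only: how one line of SFT data is split into source (prompt) and target (answer),
# mirroring the "回答:" handling in EasySupervisedDataset below.
line = "提问:什么是LoRA? 回答:LoRA是一种参数高效的微调方法。"

if "回答:" in line:
    sep_index = line.index("回答:")
    source = line[:sep_index + 3]    # prompt, up to and including "回答:"
    target = line[sep_index + 3:]    # answer; the dataset appends tokenizer.eos_token to it
else:
    source, target = line, ""        # lines without the marker get an empty target

print(source)    # 提问:什么是LoRA? 回答:
print(target)    # LoRA是一种参数高效的微调方法。
```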

@@ -1,19 +1,17 @@
import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
import torch
from tqdm import tqdm
import json

from tqdm import tqdm
import json

IGNORE_INDEX = -100

def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :int = 512) -> Dict:
def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(

@@ -36,15 +34,12 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :in
)

def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: AutoTokenizer,
max_length :int = 512
) -> Dict:
def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer,max_length) for strings in (examples, sources)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):

@@ -53,59 +48,60 @@

class EasySupervisedDataset(Dataset):
def __init__(self, data_file :str, tokenizer :AutoTokenizer,max_length :int = 512) -> None:
super(EasySupervisedDataset,self).__init__()
with open(data_file,"r",encoding="UTF-8") as f:
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
# split each line into source and target: the source is everything up to and including "回答:", the target is everything after "回答:"
sources,targets = [],[]
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
sources.append(line[:sep_index+3])
targets.append(line[sep_index+3:]+tokenizer.eos_token)
sources.append(line[:sep_index + 3])
targets.append(line[sep_index + 3:] + tokenizer.eos_token)
else:
sources.append(line)
targets.append(""+tokenizer.eos_token)
data_dict = preprocess(sources, targets, tokenizer,max_length)
targets.append("" + tokenizer.eos_token)
data_dict = preprocess(sources, targets, tokenizer, max_length)

self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.data_file = data_file

def __len__(self):
return len(self.input_ids)

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])

def __repr__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

def __str__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

class EasyPromptsDataset(Dataset):
def __init__(self,data_file :str, tokenizer :AutoTokenizer, max_length :int = 96) -> None:
super(EasyPromptsDataset,self).__init__()
with open(data_file,"r",encoding="UTF-8") as f:
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[:line.index("回答:")+3] for line in all_lines]
all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
self.prompts = [
tokenizer(line,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file

def __len__(self):
return len(self.prompts)

def __getitem__(self, idx):
return self.prompts[idx]

def __repr__(self):
return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

@@ -114,8 +110,9 @@ class EasyPromptsDataset(Dataset):

class EasyRewardDataset(Dataset):
def __init__(self,train_file :str,tokenizer :AutoTokenizer, special_token = None,max_length = 512) -> None:
super(EasyRewardDataset,self).__init__()
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
self.reject = []
if special_token is None:

@@ -124,11 +121,11 @@ class EasyRewardDataset(Dataset):
self.end_token = special_token
print(self.end_token)
#read all lines in the train_file to a list
with open(train_file,"r",encoding="UTF-8") as f:
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
prompt = "提问:"+data['prompt']+" 回答:"
prompt = "提问:" + data['prompt'] + " 回答:"

chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,

@@ -159,7 +156,7 @@ class EasyRewardDataset(Dataset):
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]

# developer (__repr__) and user-facing (__str__) representations of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

@@ -167,27 +164,30 @@ class EasyRewardDataset(Dataset):
def __str__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

'''
Easy SFT just accepts a text file which can be read line by line. However, the dataset will group texts together up to max_length so the LLM will learn the meaning of the texts better.
If individual lines are not related, just set is_group_texts to False.
'''

class EasySFTDataset(Dataset):

def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_group_texts = True) -> None:
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
#read the data_file line by line
with open(data_file,"r",encoding="UTF-8") as f:
with open(data_file, "r", encoding="UTF-8") as f:
#encode the text line by line and collect the raw python lists of input_ids in raw_input_ids
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
#if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0,len(encoded_ids),max_length):
raw_input_ids.append(encoded_ids[i:i+max_length])
for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i:i + max_length])
else:
raw_input_ids.append(encoded_ids)

grouped_inpup_ids = []
current_input_ids = []
attention_mask = []

@@ -199,44 +199,42 @@ class EasySFTDataset(Dataset):
#pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long))
attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
current_input_ids = []
else:
current_input_ids.extend(input_ids)
if len(current_input_ids) > 0:
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long))
attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
else:
#just pad each entry of raw_input_ids to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(input_ids,dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_inpup_ids
self.labels = copy.deepcopy(self.input_ids)
self.file_name = data_file
self.attention_mask = attention_mask

def __len__(self):
return len(self.input_ids)

#get item from dataset
def __getitem__(self,idx):
return dict(input_ids=self.input_ids[idx],labels=self.labels[idx],attention_mask=self.attention_mask[idx])
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

#generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

#generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
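
A hedged usage sketch of the dataset classes above (the checkpoint name and the pairing of sample files with classes are assumptions, not prescribed by this commit). It shows the grouping behaviour described in the EasySFTDataset docstring and the "回答:" split used by the supervised and prompt datasets:

```
# Hedged sketch, not part of this diff: instantiating the datasets defined above.
from transformers import AutoTokenizer

from easy_dataset import EasyPromptsDataset, EasySFTDataset, EasySupervisedDataset

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")    # placeholder BLOOM checkpoint

# plain text read line by line, then re-chunked/padded to max_length;
# set is_group_texts=False if the individual lines are unrelated
pretrain_dataset = EasySFTDataset("test_pretrained.txt", tokenizer, max_length=512, is_group_texts=True)

# "提问:... 回答:..." lines; the answer part becomes the supervised target
sft_dataset = EasySupervisedDataset("test_sft.txt", tokenizer, max_length=512)

# prompt-only view of the same format for stage-3 RLHF
# (needs a CUDA device, since the prompts are moved to torch.cuda.current_device())
prompt_dataset = EasyPromptsDataset("test_prompts.txt", tokenizer, max_length=96)

print(pretrain_dataset)
print(sft_dataset)
```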

@@ -3,12 +3,12 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import Module

from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits,masked_mean
from transformers import BloomConfig,BloomForCausalLM
from coati.models.utils import log_probs_from_logits, masked_mean
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM

class Actor(Module):
"""

@@ -87,11 +87,10 @@ class BLOOMActor(Actor):
else:
model = BloomForCausalLM(BloomConfig())
if lora_path is not None:
model = PeftModel.from_pretrained(model,lora_path)
model = PeftModel.from_pretrained(model, lora_path)
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model)

def print_trainable_parameters(self):
self.get_base_model().print_trainable_parameters()
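
For reference, a hedged sketch of how this actor is constructed in train_peft_prompts.py later in this commit (both paths below are placeholders):

```
# Hedged sketch: wrap a BLOOM checkpoint plus a previously saved PEFT LoRA adapter as the RLHF actor.
from easy_models import BLOOMActor

actor = BLOOMActor(pretrained="bigscience/bloom-560m",    # base checkpoint (placeholder)
                   lora_path="./peft_sft_adapter")        # adapter directory from the SFT stage (placeholder)
actor.print_trainable_parameters()
```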

@@ -5,21 +5,22 @@ import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from easy_models import BLOOMActor
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam
from peft import PeftModel
from easy_dataset import EasyPromptsDataset,EasySupervisedDataset

def main(args):
# configure strategy

@@ -41,7 +42,7 @@ def main(args):
if args.model == 'bloom':
# initial_model = BLOOMActor(pretrained=args.pretrain)
print('Using peft lora to load Bloom model as initial_model')
initial_model = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path)
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as initial_model (Done)')
else:
raise ValueError(f'Unsupported actor model "{args.model}"')

@@ -54,7 +55,7 @@ def main(args):
if rm_model_name == 'gpt2':
reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
print("load bloom reward model ",args.rm_pretrain)
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
reward_model = OPTRM(pretrained=args.rm_pretrain)

@@ -75,7 +76,7 @@ def main(args):
if args.model == 'bloom':
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print('Using peft lora to load Bloom model as Actor')
actor = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path)
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as Actor (Done)')
else:
raise ValueError(f'Unsupported actor model "{args.model}"')

@@ -83,7 +84,7 @@ def main(args):
if rm_model_name == 'gpt2':
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
print("load bloom critic ",args.rm_pretrain," lora_rank ",args.lora_rank," use_action_mask ",True)
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
elif rm_model_name == 'opt':

@@ -130,7 +131,7 @@ def main(args):
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

prompt_dataset = EasyPromptsDataset(args.prompt_path,tokenizer)
prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:

@@ -14,19 +14,19 @@ from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast,AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter

from torch.utils.data.dataloader import default_collate
from peft import LoraConfig, TaskType,get_peft_model,PeftModel
from easy_dataset import EasyDataset

def train(args):
# configure strategy

@@ -48,17 +48,20 @@ def train(args):
#if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \
and os.path.exists(args.save_path+'/adapter_model.bin'):
print("loading from saved peft model ",args.save_path)
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
#we'll use the peft lora library to apply lora
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
#config lora with rank of lora_rank
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1)
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=lora_rank,
lora_alpha=32,
lora_dropout=0.1)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

@@ -103,12 +106,12 @@ def train(args):
logger.set_level('WARNING')

# configure dataset
law_dataset = EasyDataset(args.dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text)
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
train_dataset = law_dataset
print(train_dataset)
eval_dataset = None
if args.eval_dataset is not None:
eval_dataset = EasyDataset(args.eval_dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text)
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,

@@ -181,7 +184,7 @@ if __name__ == '__main__':
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accimulation_steps', type=int, default=8)
parser.add_argument('--enable_peft_lora',action='store_true', default=False)
parser.add_argument("--is_short_text",action='store_true', default=False)
parser.add_argument('--enable_peft_lora', action='store_true', default=False)
parser.add_argument("--is_short_text", action='store_true', default=False)
args = parser.parse_args()
train(args)
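
One step not shown in the lines above is how the LoRA adapter ends up under args.save_path in the first place. With the peft library this is normally a save_pretrained call on the wrapped model, which writes the adapter_config.json and adapter_model.bin files that the resume check in train_peft_sft.py (and --sft_lora_path in train_peft_prompts.py) looks for. A hedged, self-contained sketch, with the model name and output directory as placeholders:

```
# Hedged sketch, not part of this diff: build a LoRA-wrapped BLOOM model and save its adapter.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")    # placeholder checkpoint
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                         inference_mode=False,
                         r=32,
                         lora_alpha=32,
                         lora_dropout=0.1)
peft_model = get_peft_model(base_model, lora_config)

# writes adapter_config.json and adapter_model.bin into the directory
peft_model.save_pretrained("./peft_sft_adapter")
```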