mirror of https://github.com/THUDM/ChatGLM-6B
add code
parent
d835c4b001
commit
00046a5750
|
@ -1,28 +1,28 @@
|
||||||
|
PRE_SEQ_LEN=128
|
||||||
|
LR=2e-2
|
||||||
|
|
||||||
LR=1e-4
|
deepspeed --include="localhost:0,1" main.py \
|
||||||
|
|
||||||
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
|
|
||||||
|
|
||||||
deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
|
|
||||||
--deepspeed deepspeed.json \
|
--deepspeed deepspeed.json \
|
||||||
--do_train \
|
--do_train \
|
||||||
--train_file AdvertiseGen/train.json \
|
--train_file AdvertiseGen/train.json \
|
||||||
--test_file AdvertiseGen/dev.json \
|
--test_file AdvertiseGen/dev.json \
|
||||||
--prompt_column content \
|
--prompt_column content \
|
||||||
--response_column summary \
|
--response_column summary \
|
||||||
|
--preprocessing_num_workers 10 \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--model_name_or_path THUDM/chatglm-6b \
|
--model_name_or_path THUDM/chatglm-6b \
|
||||||
--output_dir ./output/adgen-chatglm-6b-ft-$LR \
|
--output_dir ./output/ds-chatglm-6b-ptuning-$LR \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--max_source_length 64 \
|
--max_source_length 64 \
|
||||||
--max_target_length 64 \
|
--max_target_length 64 \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 2 \
|
||||||
--per_device_eval_batch_size 1 \
|
--per_device_eval_batch_size 1 \
|
||||||
--gradient_accumulation_steps 1 \
|
--gradient_accumulation_steps 1 \
|
||||||
--predict_with_generate \
|
--predict_with_generate \
|
||||||
--max_steps 5000 \
|
--max_steps 1000 \
|
||||||
--logging_steps 10 \
|
--logging_steps 10 \
|
||||||
--save_steps 1000 \
|
--save_steps 1000 \
|
||||||
--learning_rate $LR \
|
--learning_rate $LR \
|
||||||
|
--pre_seq_len $PRE_SEQ_LEN \
|
||||||
|
--save_total_limit 1 \
|
||||||
--fp16
|
--fp16
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,436 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Fine-tuning the library models for sequence to sequence.
|
||||||
|
"""
|
||||||
|
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
import jieba
|
||||||
|
from rouge_chinese import Rouge
|
||||||
|
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
AutoTokenizer,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
HfArgumentParser,
|
||||||
|
Seq2SeqTrainingArguments,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
from trainer_seq2seq import Seq2SeqTrainer
|
||||||
|
|
||||||
|
from arguments import ModelArguments, DataTrainingArguments
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
|
# let's parse it to get our arguments.
|
||||||
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
else:
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_args.should_log:
|
||||||
|
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||||
|
transformers.utils.logging.set_verbosity_info()
|
||||||
|
|
||||||
|
log_level = training_args.get_process_log_level()
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
# datasets.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.enable_default_handler()
|
||||||
|
transformers.utils.logging.enable_explicit_format()
|
||||||
|
|
||||||
|
# Log on each process the small summary:
|
||||||
|
logger.warning(
|
||||||
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||||
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
|
)
|
||||||
|
logger.info(f"Training/evaluation parameters {training_args}")
|
||||||
|
|
||||||
|
transformers.logging.set_verbosity_info()
|
||||||
|
# Set seed before initializing model.
|
||||||
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
data_files = {}
|
||||||
|
if data_args.train_file is not None:
|
||||||
|
data_files["train"] = data_args.train_file
|
||||||
|
extension = data_args.train_file.split(".")[-1]
|
||||||
|
if data_args.validation_file is not None:
|
||||||
|
data_files["validation"] = data_args.validation_file
|
||||||
|
extension = data_args.validation_file.split(".")[-1]
|
||||||
|
if data_args.test_file is not None:
|
||||||
|
data_files["test"] = data_args.test_file
|
||||||
|
extension = data_args.test_file.split(".")[-1]
|
||||||
|
|
||||||
|
raw_datasets = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load pretrained model and tokenizer
|
||||||
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
||||||
|
config.pre_seq_len = model_args.pre_seq_len
|
||||||
|
config.prefix_projection = model_args.prefix_projection
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
if model_args.ptuning_checkpoint is not None:
|
||||||
|
# Evaluation
|
||||||
|
# Loading extra state dict of prefix encoder
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
|
||||||
|
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
|
||||||
|
new_prefix_state_dict = {}
|
||||||
|
for k, v in prefix_state_dict.items():
|
||||||
|
if k.startswith("transformer.prefix_encoder."):
|
||||||
|
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||||
|
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||||
|
else:
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config,
|
||||||
|
trust_remote_code=True, device_map="auto")
|
||||||
|
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
model.is_parallelizable = True
|
||||||
|
model.model_parallel = True
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
|
if model_args.quantization_bit is not None:
|
||||||
|
print(f"Quantized to {model_args.quantization_bit} bit")
|
||||||
|
model = model.quantize(model_args.quantization_bit)
|
||||||
|
if model_args.pre_seq_len is not None:
|
||||||
|
# P-tuning v2
|
||||||
|
model = model.half()
|
||||||
|
model.transformer.prefix_encoder.float()
|
||||||
|
else:
|
||||||
|
# Finetune
|
||||||
|
model = model.float()
|
||||||
|
|
||||||
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||||
|
|
||||||
|
# Preprocessing the datasets.
|
||||||
|
# We need to tokenize inputs and targets.
|
||||||
|
if training_args.do_train:
|
||||||
|
column_names = raw_datasets["train"].column_names
|
||||||
|
elif training_args.do_eval:
|
||||||
|
column_names = raw_datasets["validation"].column_names
|
||||||
|
elif training_args.do_predict:
|
||||||
|
column_names = raw_datasets["test"].column_names
|
||||||
|
else:
|
||||||
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the column names for input/target.
|
||||||
|
prompt_column = data_args.prompt_column
|
||||||
|
response_column = data_args.response_column
|
||||||
|
history_column = data_args.history_column
|
||||||
|
|
||||||
|
# Temporarily set max_target_length for training.
|
||||||
|
max_target_length = data_args.max_target_length
|
||||||
|
|
||||||
|
def preprocess_function_eval(examples):
|
||||||
|
inputs, targets = [], []
|
||||||
|
for i in range(len(examples[prompt_column])):
|
||||||
|
if examples[prompt_column][i] and examples[response_column][i]:
|
||||||
|
query = examples[prompt_column][i]
|
||||||
|
if history_column is None or len(examples[history_column][i]) == 0:
|
||||||
|
prompt = query
|
||||||
|
else:
|
||||||
|
prompt = ""
|
||||||
|
history = examples[history_column][i]
|
||||||
|
for turn_idx, (old_query, response) in enumerate(history):
|
||||||
|
prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
|
||||||
|
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
||||||
|
inputs.append(prompt)
|
||||||
|
targets.append(examples[response_column][i])
|
||||||
|
|
||||||
|
inputs = [prefix + inp for inp in inputs]
|
||||||
|
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
|
||||||
|
labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
|
||||||
|
|
||||||
|
if data_args.ignore_pad_token_for_loss:
|
||||||
|
labels["input_ids"] = [
|
||||||
|
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||||
|
]
|
||||||
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def preprocess_function_train(examples):
|
||||||
|
max_seq_length = data_args.max_source_length + data_args.max_target_length
|
||||||
|
|
||||||
|
model_inputs = {
|
||||||
|
"input_ids": [],
|
||||||
|
"labels": [],
|
||||||
|
}
|
||||||
|
for i in range(len(examples[prompt_column])):
|
||||||
|
if examples[prompt_column][i] and examples[response_column][i]:
|
||||||
|
query, answer = examples[prompt_column][i], examples[response_column][i]
|
||||||
|
|
||||||
|
if history_column is None:
|
||||||
|
prompt = query
|
||||||
|
else:
|
||||||
|
prompt = ""
|
||||||
|
history = examples[history_column][i]
|
||||||
|
for turn_idx, (old_query, response) in enumerate(history):
|
||||||
|
prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
|
||||||
|
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
||||||
|
|
||||||
|
prompt = prefix + prompt
|
||||||
|
a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||||
|
b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||||
|
|
||||||
|
if len(a_ids) > data_args.max_source_length - 1:
|
||||||
|
a_ids = a_ids[: data_args.max_source_length - 1]
|
||||||
|
|
||||||
|
if len(b_ids) > data_args.max_target_length - 2:
|
||||||
|
b_ids = b_ids[: data_args.max_target_length - 2]
|
||||||
|
|
||||||
|
input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
|
||||||
|
|
||||||
|
context_length = input_ids.index(tokenizer.bos_token_id)
|
||||||
|
mask_position = context_length - 1
|
||||||
|
labels = [-100] * context_length + input_ids[mask_position+1:]
|
||||||
|
|
||||||
|
pad_len = max_seq_length - len(input_ids)
|
||||||
|
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
|
||||||
|
labels = labels + [tokenizer.pad_token_id] * pad_len
|
||||||
|
if data_args.ignore_pad_token_for_loss:
|
||||||
|
labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
|
||||||
|
|
||||||
|
model_inputs["input_ids"].append(input_ids)
|
||||||
|
model_inputs["labels"].append(labels)
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def print_dataset_example(example):
|
||||||
|
print("input_ids",example["input_ids"])
|
||||||
|
print("inputs", tokenizer.decode(example["input_ids"]))
|
||||||
|
print("label_ids", example["labels"])
|
||||||
|
print("labels", tokenizer.decode(example["labels"]))
|
||||||
|
|
||||||
|
if training_args.do_train:
|
||||||
|
if "train" not in raw_datasets:
|
||||||
|
raise ValueError("--do_train requires a train dataset")
|
||||||
|
train_dataset = raw_datasets["train"]
|
||||||
|
if data_args.max_train_samples is not None:
|
||||||
|
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||||
|
train_dataset = train_dataset.select(range(max_train_samples))
|
||||||
|
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
preprocess_function_train,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
desc="Running tokenizer on train dataset",
|
||||||
|
)
|
||||||
|
print_dataset_example(train_dataset[0])
|
||||||
|
|
||||||
|
if training_args.do_eval:
|
||||||
|
max_target_length = data_args.val_max_target_length
|
||||||
|
if "validation" not in raw_datasets:
|
||||||
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
|
eval_dataset = raw_datasets["validation"]
|
||||||
|
if data_args.max_eval_samples is not None:
|
||||||
|
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||||
|
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||||
|
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
preprocess_function_eval,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
desc="Running tokenizer on validation dataset",
|
||||||
|
)
|
||||||
|
print_dataset_example(eval_dataset[0])
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
max_target_length = data_args.val_max_target_length
|
||||||
|
if "test" not in raw_datasets:
|
||||||
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
|
predict_dataset = raw_datasets["test"]
|
||||||
|
if data_args.max_predict_samples is not None:
|
||||||
|
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||||
|
predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||||
|
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
|
||||||
|
predict_dataset = predict_dataset.map(
|
||||||
|
preprocess_function_eval,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
desc="Running tokenizer on prediction dataset",
|
||||||
|
)
|
||||||
|
print_dataset_example(predict_dataset[0])
|
||||||
|
|
||||||
|
# Data collator
|
||||||
|
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||||
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
|
tokenizer,
|
||||||
|
model=model,
|
||||||
|
label_pad_token_id=label_pad_token_id,
|
||||||
|
pad_to_multiple_of=None,
|
||||||
|
padding=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Metric
|
||||||
|
def compute_metrics(eval_preds):
|
||||||
|
preds, labels = eval_preds
|
||||||
|
if isinstance(preds, tuple):
|
||||||
|
preds = preds[0]
|
||||||
|
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||||
|
if data_args.ignore_pad_token_for_loss:
|
||||||
|
# Replace -100 in the labels as we can't decode them.
|
||||||
|
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||||
|
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||||
|
|
||||||
|
score_dict = {
|
||||||
|
"rouge-1": [],
|
||||||
|
"rouge-2": [],
|
||||||
|
"rouge-l": [],
|
||||||
|
"bleu-4": []
|
||||||
|
}
|
||||||
|
for pred, label in zip(decoded_preds, decoded_labels):
|
||||||
|
hypothesis = list(jieba.cut(pred))
|
||||||
|
reference = list(jieba.cut(label))
|
||||||
|
rouge = Rouge()
|
||||||
|
scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
|
||||||
|
result = scores[0]
|
||||||
|
|
||||||
|
for k, v in result.items():
|
||||||
|
score_dict[k].append(round(v["f"] * 100, 4))
|
||||||
|
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||||
|
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||||
|
|
||||||
|
for k, v in score_dict.items():
|
||||||
|
score_dict[k] = float(np.mean(v))
|
||||||
|
return score_dict
|
||||||
|
|
||||||
|
# Override the decoding parameters of Seq2SeqTrainer
|
||||||
|
# training_args.generation_max_length = (
|
||||||
|
# training_args.generation_max_length
|
||||||
|
# if training_args.generation_max_length is not None
|
||||||
|
# else data_args.val_max_target_length
|
||||||
|
# )
|
||||||
|
# training_args.generation_num_beams = (
|
||||||
|
# data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||||
|
# )
|
||||||
|
# Initialize our Trainer
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset if training_args.do_train else None,
|
||||||
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||||
|
save_prefixencoder=model_args.pre_seq_len is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
checkpoint = None
|
||||||
|
if training_args.resume_from_checkpoint is not None:
|
||||||
|
checkpoint = training_args.resume_from_checkpoint
|
||||||
|
# elif last_checkpoint is not None:
|
||||||
|
# checkpoint = last_checkpoint
|
||||||
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
# trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
|
metrics = train_result.metrics
|
||||||
|
max_train_samples = (
|
||||||
|
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||||
|
)
|
||||||
|
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||||
|
|
||||||
|
trainer.log_metrics("train", metrics)
|
||||||
|
trainer.save_metrics("train", metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
results = {}
|
||||||
|
max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
|
||||||
|
if training_args.do_eval:
|
||||||
|
logger.info("*** Evaluate ***")
|
||||||
|
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
|
||||||
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
|
trainer.log_metrics("eval", metrics)
|
||||||
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
logger.info("*** Predict ***")
|
||||||
|
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
|
||||||
|
metrics = predict_results.metrics
|
||||||
|
max_predict_samples = (
|
||||||
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
)
|
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
|
trainer.log_metrics("predict", metrics)
|
||||||
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
if training_args.predict_with_generate:
|
||||||
|
predictions = tokenizer.batch_decode(
|
||||||
|
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
predictions = [pred.strip() for pred in predictions]
|
||||||
|
labels = tokenizer.batch_decode(
|
||||||
|
predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
labels = [label.strip() for label in labels]
|
||||||
|
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||||
|
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||||
|
for p, l in zip(predictions, labels):
|
||||||
|
res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
|
||||||
|
writer.write(f"{res}\n")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _mp_fn(index):
|
||||||
|
# For xla_spawn (TPUs)
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,28 @@
|
||||||
|
PRE_SEQ_LEN=128
|
||||||
|
LR=2e-2
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=1,2 python3 main_parallel.py \
|
||||||
|
--do_train \
|
||||||
|
--train_file AdvertiseGen/train.json \
|
||||||
|
--test_file AdvertiseGen/dev.json \
|
||||||
|
--prompt_column content \
|
||||||
|
--response_column summary \
|
||||||
|
--preprocessing_num_workers 10 \
|
||||||
|
--overwrite_cache \
|
||||||
|
--model_name_or_path THUDM/chatglm-6b \
|
||||||
|
--output_dir ./output/parallel-chatglm-6b-ptuning-$LR \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--max_source_length 64 \
|
||||||
|
--max_target_length 64 \
|
||||||
|
--per_device_train_batch_size 2 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--predict_with_generate \
|
||||||
|
--max_steps 1000 \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 1000 \
|
||||||
|
--learning_rate $LR \
|
||||||
|
--pre_seq_len $PRE_SEQ_LEN \
|
||||||
|
--save_total_limit 1 \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--fp16
|
|
@ -3,10 +3,12 @@ LR=2e-2
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||||
--do_train \
|
--do_train \
|
||||||
|
--fp16 \
|
||||||
--train_file AdvertiseGen/train.json \
|
--train_file AdvertiseGen/train.json \
|
||||||
--validation_file AdvertiseGen/dev.json \
|
--validation_file AdvertiseGen/dev.json \
|
||||||
--prompt_column content \
|
--prompt_column content \
|
||||||
--response_column summary \
|
--response_column summary \
|
||||||
|
--preprocessing_num_workers 10 \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--model_name_or_path THUDM/chatglm-6b \
|
--model_name_or_path THUDM/chatglm-6b \
|
||||||
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
|
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
|
||||||
|
@ -21,6 +23,5 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||||
--logging_steps 10 \
|
--logging_steps 10 \
|
||||||
--save_steps 1000 \
|
--save_steps 1000 \
|
||||||
--learning_rate $LR \
|
--learning_rate $LR \
|
||||||
--pre_seq_len $PRE_SEQ_LEN \
|
--pre_seq_len $PRE_SEQ_LEN
|
||||||
--quantization_bit 4
|
# --quantization_bit 4
|
||||||
|
|
||||||
|
|
1414
ptuning/trainer.py
1414
ptuning/trainer.py
File diff suppressed because it is too large
Load Diff
|
@ -12,22 +12,68 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
from transformers.data.data_collator import DataCollator
|
||||||
|
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from trainer import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PredictionOutput
|
from transformers.modeling_utils import PreTrainedModel, unwrap_model, WEIGHTS_NAME
|
||||||
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
from transformers.trainer_utils import EvalPrediction, PredictionOutput
|
||||||
|
from transformers.training_args import TrainingArguments
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
TRAINING_ARGS_NAME = "training_args.bin"
|
||||||
|
|
||||||
class Seq2SeqTrainer(Trainer):
|
class Seq2SeqTrainer(Trainer):
|
||||||
|
def __init__(self, save_prefixencoder=None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.save_prefixencoder = save_prefixencoder
|
||||||
|
|
||||||
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
|
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||||
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||||
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
|
# They can then be reloaded using `from_pretrained()`
|
||||||
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
|
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
unwrap_model(self.model).save_pretrained(output_dir, state_dict=filtered_state_dict)
|
||||||
|
else:
|
||||||
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
|
else:
|
||||||
|
if self.save_prefixencoder:
|
||||||
|
print("Saving PrefixEncoder")
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
filtered_state_dict = {}
|
||||||
|
for k, v in self.model.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
filtered_state_dict[k] = state_dict[k]
|
||||||
|
self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
|
||||||
|
else:
|
||||||
|
print("Saving the whole model")
|
||||||
|
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||||
|
if self.tokenizer is not None:
|
||||||
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
# Good practice: save your training arguments together with the trained model
|
||||||
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Dataset] = None,
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
protobuf
|
protobuf
|
||||||
transformers==4.27.1
|
transformers==4.32.0
|
||||||
cpm_kernels
|
cpm_kernels
|
||||||
torch>=1.10
|
torch>=1.10
|
||||||
gradio
|
gradio
|
||||||
mdtex2html
|
mdtex2html
|
||||||
sentencepiece
|
sentencepiece
|
||||||
accelerate
|
accelerate
|
||||||
|
datasets
|
||||||
|
jieba
|
||||||
|
rouge_chinese
|
||||||
|
deepspeed
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue