Update main.py

fix:修复预测时输入输出长度和超过512时,导致所有预测结果为空的bug
pull/954/head
邓钢清 2023-05-08 10:58:24 +08:00 committed by GitHub
parent 9d7bcf62e8
commit 486edfed81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -382,9 +382,10 @@ def main():
# Evaluation
results = {}
max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95)
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@ -393,8 +394,7 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95)
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)