diff --git a/ptuning/main.py b/ptuning/main.py index 43ecdf8..49b08b0 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -382,9 +382,10 @@ def main(): # Evaluation results = {} + max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95) + metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) @@ -393,8 +394,7 @@ def main(): if training_args.do_predict: logger.info("*** Predict ***") - - predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95) + predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) metrics = predict_results.metrics max_predict_samples = ( data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)