mirror of https://github.com/THUDM/ChatGLM-6B
parent
9d7bcf62e8
commit
486edfed81
|
@ -382,9 +382,10 @@ def main():
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
|
max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95)
|
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
|
||||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
|
@ -393,8 +394,7 @@ def main():
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
|
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
|
||||||
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95)
|
|
||||||
metrics = predict_results.metrics
|
metrics = predict_results.metrics
|
||||||
max_predict_samples = (
|
max_predict_samples = (
|
||||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
|
Loading…
Reference in New Issue