From 893706a82d5529e6a99861c31335c99b440f8842 Mon Sep 17 00:00:00 2001 From: rainatam Date: Fri, 31 Mar 2023 18:12:04 +0800 Subject: [PATCH] Update train script --- ptuning/arguments.py | 4 ++-- ptuning/train.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ptuning/arguments.py b/ptuning/arguments.py index 1c61f97..95d766f 100644 --- a/ptuning/arguments.py +++ b/ptuning/arguments.py @@ -203,8 +203,8 @@ class DataTrainingArguments: def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: + raise ValueError("Need either a dataset name or a training/validation/test file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] diff --git a/ptuning/train.sh b/ptuning/train.sh index 1d03a25..3189829 100644 --- a/ptuning/train.sh +++ b/ptuning/train.sh @@ -9,7 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \ --response_column summary \ --overwrite_cache \ --model_name_or_path THUDM/chatglm-6b \ - --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR-dev \ + --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ --overwrite_output_dir \ --max_source_length 64 \ --max_target_length 64 \