diff --git a/ptuning/train.sh b/ptuning/train.sh index efc9a16..bf75216 100644 --- a/ptuning/train.sh +++ b/ptuning/train.sh @@ -1,7 +1,7 @@ PRE_SEQ_LEN=128 LR=2e-2 -CUDA_VISIBLE_DEVICES=0 python3 main.py \ +CUDA_VISIBLE_DEVICES=0 torchrun main.py \ --do_train \ --train_file AdvertiseGen/train.json \ --validation_file AdvertiseGen/dev.json \