pull/1173/merge
Barbery 2024-07-28 21:18:06 +08:00 committed by GitHub
commit 4b35efdf50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -47,7 +47,7 @@ from arguments import ModelArguments, DataTrainingArguments
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def main(): def main():
torch.distributed.init_process_group(backend='nccl')
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file, # If we pass only one argument to the script and it's the path to a json file,

View File

@ -1,7 +1,7 @@
PRE_SEQ_LEN=128 PRE_SEQ_LEN=128
LR=2e-2 LR=2e-2
CUDA_VISIBLE_DEVICES=0 python3 main.py \ CUDA_VISIBLE_DEVICES=0 torchrun main.py \
--do_train \ --do_train \
--train_file AdvertiseGen/train.json \ --train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \ --validation_file AdvertiseGen/dev.json \