From 9afc36ae5080e6199abd91534954b77895b797a1 Mon Sep 17 00:00:00 2001 From: Barbery <380032007@qq.com> Date: Fri, 2 Jun 2023 14:24:15 +0800 Subject: [PATCH 1/2] use torchrun --- ptuning/train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ptuning/train.sh b/ptuning/train.sh index efc9a16..bf75216 100644 --- a/ptuning/train.sh +++ b/ptuning/train.sh @@ -1,7 +1,7 @@ PRE_SEQ_LEN=128 LR=2e-2 -CUDA_VISIBLE_DEVICES=0 python3 main.py \ +CUDA_VISIBLE_DEVICES=0 torchrun main.py \ --do_train \ --train_file AdvertiseGen/train.json \ --validation_file AdvertiseGen/dev.json \ From 921d9d4f1355c57ad6dfd3604ee310725fb55027 Mon Sep 17 00:00:00 2001 From: Barbery <380032007@qq.com> Date: Fri, 2 Jun 2023 14:28:58 +0800 Subject: [PATCH 2/2] add init_process_group --- ptuning/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ptuning/main.py b/ptuning/main.py index 17e18b5..ff01de2 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -47,7 +47,7 @@ from arguments import ModelArguments, DataTrainingArguments logger = logging.getLogger(__name__) def main(): - + torch.distributed.init_process_group(backend='nccl') parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file,