diff --git a/ptuning/main.py b/ptuning/main.py index 17e18b5..3ce58f6 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -47,6 +47,11 @@ from arguments import ModelArguments, DataTrainingArguments logger = logging.getLogger(__name__) def main(): + + if sys.platform == "win32": + torch.distributed.init_process_group(backend='gloo') + else: + torch.distributed.init_process_group(backend='nccl') parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):