Update main.py

多平台彻底解决 Default process group has not been initialized, please make sure to call init_process_group 问题。
issue里面有多个这个问题了。
pull/1288/head
曾小健 2023-06-30 11:04:12 +08:00 committed by GitHub
parent 5d3f823bcc
commit 363fd5ac4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -48,6 +48,11 @@ logger = logging.getLogger(__name__)
def main():
if sys.platform == "win32":
torch.distributed.init_process_group(backend='gloo')
else:
torch.distributed.init_process_group(backend='nccl')
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,