diff --git a/ptuning/README.md b/ptuning/README.md
index f92a328..dbfca83 100644
--- a/ptuning/README.md
+++ b/ptuning/README.md
@@ -138,7 +138,7 @@ from transformers import AutoConfig, AutoModel, AutoTokenizer
 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 ```
 
-1. If the checkpoint to load is a new one (which only contains the PrefixEncoder parameters):
+1. If the checkpoint to load is a new one (which only needs to contain the PrefixEncoder parameters):
 
 ```python
 config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
diff --git a/ptuning/main.py b/ptuning/main.py
index 17e18b5..3ce58f6 100644
--- a/ptuning/main.py
+++ b/ptuning/main.py
@@ -47,6 +47,11 @@ from arguments import ModelArguments, DataTrainingArguments
 logger = logging.getLogger(__name__)
 
 def main():
+
+    if sys.platform == "win32":
+        torch.distributed.init_process_group(backend='gloo')
+    else:
+        torch.distributed.init_process_group(backend='nccl')
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
 
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
diff --git a/ptuning/train.sh b/ptuning/train.sh
index efc9a16..bf75216 100644
--- a/ptuning/train.sh
+++ b/ptuning/train.sh
@@ -1,7 +1,7 @@
 PRE_SEQ_LEN=128
 LR=2e-2
 
-CUDA_VISIBLE_DEVICES=0 python3 main.py \
+CUDA_VISIBLE_DEVICES=0 torchrun main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
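
Note: the backend selection patched into main.py only works when the script is started by a distributed launcher that sets the usual rendezvous environment variables, which is presumably why train.sh switches from `python3` to `torchrun`. Below is a minimal, self-contained sketch of that logic (the `init_distributed` helper name is ours, not part of the patch); it assumes it is launched via `torchrun`, which exports `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`, the variables `init_process_group()` reads when no explicit init method is given.

```python
# Minimal sketch (not part of the patch) of the platform-dependent backend choice.
# Assumes launch via torchrun, which sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT.
import sys

import torch.distributed as dist


def init_distributed():
    # NCCL is not available on Windows, so fall back to Gloo there; use NCCL elsewhere.
    backend = "gloo" if sys.platform == "win32" else "nccl"
    dist.init_process_group(backend=backend)
    return dist.get_rank(), dist.get_world_size()


if __name__ == "__main__":
    rank, world_size = init_distributed()
    print(f"rank {rank}/{world_size} initialized with backend {dist.get_backend()}")
```

A matching single-GPU launch would look something like `CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 main.py ...`; running the patched main.py with plain `python3` would leave those variables unset and `init_process_group()` would typically fail with a missing `MASTER_ADDR`/`RANK` error.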