pull/1288/merge
Dr. Artificial曾小健 2024-08-22 17:35:46 +08:00 committed by GitHub
commit 2f991160c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 2 deletions

View File

@ -138,7 +138,7 @@ from transformers import AutoConfig, AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```
1. 如果需要加载的是新 Checkpoint只包含 PrefixEncoder 参数):
1. 如果需要加载的是新 Checkpoint包含 PrefixEncoder 参数):
```python
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)

View File

@ -47,6 +47,11 @@ from arguments import ModelArguments, DataTrainingArguments
logger = logging.getLogger(__name__)
def main():
if sys.platform == "win32":
torch.distributed.init_process_group(backend='gloo')
else:
torch.distributed.init_process_group(backend='nccl')
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):

View File

@ -1,7 +1,7 @@
PRE_SEQ_LEN=128
LR=2e-2
CUDA_VISIBLE_DEVICES=0 python3 main.py \
CUDA_VISIBLE_DEVICES=0 torchrun main.py \
--do_train \
--train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \