mirror of https://github.com/THUDM/ChatGLM-6B
Merge 363fd5ac4f
into 401bf3a8a7
commit
2f991160c1
|
@ -138,7 +138,7 @@ from transformers import AutoConfig, AutoModel, AutoTokenizer
|
|||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
||||
```
|
||||
|
||||
1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):
|
||||
1. 如果需要加载的是新 Checkpoint(只需包含 PrefixEncoder 参数):
|
||||
|
||||
```python
|
||||
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
|
||||
|
|
|
@ -47,6 +47,11 @@ from arguments import ModelArguments, DataTrainingArguments
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
|
||||
if sys.platform == "win32":
|
||||
torch.distributed.init_process_group(backend='gloo')
|
||||
else:
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
PRE_SEQ_LEN=128
|
||||
LR=2e-2
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||
CUDA_VISIBLE_DEVICES=0 torchrun main.py \
|
||||
--do_train \
|
||||
--train_file AdvertiseGen/train.json \
|
||||
--validation_file AdvertiseGen/dev.json \
|
||||
|
|
Loading…
Reference in New Issue