mirror of https://github.com/THUDM/ChatGLM-6B
Merge 921d9d4f13
into 401bf3a8a7
commit
4b35efdf50
|
@ -47,7 +47,7 @@ from arguments import ModelArguments, DataTrainingArguments
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
PRE_SEQ_LEN=128
|
||||
LR=2e-2
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||
CUDA_VISIBLE_DEVICES=0 torchrun main.py \
|
||||
--do_train \
|
||||
--train_file AdvertiseGen/train.json \
|
||||
--validation_file AdvertiseGen/dev.json \
|
||||
|
|
Loading…
Reference in New Issue