☀ feat: 训练相关修改

pull/1066/head
DealiAxy 2023-05-19 17:35:06 +08:00
parent 5b55467895
commit 6a43132a39
4 changed files with 13 additions and 11 deletions

View File

@ -2,10 +2,10 @@ PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=3000
CUDA_VISIBLE_DEVICES=0 python3 main.py \
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
--do_predict \
--validation_file AdvertiseGen/dev.json \
--test_file AdvertiseGen/dev.json \
--validation_file data/AdvertiseGen/dev.json \
--test_file data/AdvertiseGen/dev.json \
--overwrite_cache \
--prompt_column content \
--response_column summary \

View File

@ -1,10 +1,12 @@
PRE_SEQ_LEN=128
LR=2e-2
CUDA_VISIBLE_DEVICES=0 python3 main.py \
export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32'
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
--do_train \
--train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \
--train_file data/AdvertiseGen/train.json \
--validation_file data/AdvertiseGen/dev.json \
--prompt_column content \
--response_column summary \
--overwrite_cache \

View File

@ -119,8 +119,7 @@ with gr.Blocks() as demo:
def main():
global model, tokenizer
parser = HfArgumentParser((
ModelArguments))
parser = HfArgumentParser((ModelArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
@ -158,7 +157,7 @@ def main():
model.transformer.prefix_encoder.float().cuda()
model = model.eval()
demo.queue().launch(share=False, inbrowser=True)
demo.queue().launch(share=False, inbrowser=True, server_port=11001)

View File

@ -1,7 +1,8 @@
PRE_SEQ_LEN=128
CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
CUDA_VISIBLE_DEVICES=0,1 python3 web_demo.py \
--model_name_or_path THUDM/chatglm-6b \
--ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
--pre_seq_len $PRE_SEQ_LEN
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4