change `max_tokens` to `max_new_tokens`

pull/695/head
wangzy 2024-02-05 16:24:03 +08:00
parent be4507361a
commit 62713e2a83
1 changed files with 2 additions and 2 deletions

View File

@ -98,7 +98,7 @@ def parse_args():
default=['<|action_end|>', '<|im_end|>'], default=['<|action_end|>', '<|im_end|>'],
action='append', action='append',
help='Stop words') help='Stop words')
parser.add_argument('--max_tokens', parser.add_argument('--max_new_tokens',
type=int, type=int,
default=512, default=512,
help='Number of maximum generated tokens.') help='Number of maximum generated tokens.')
@ -557,7 +557,7 @@ def predict(args):
stop_words=args.stop_words, stop_words=args.stop_words,
top_p=args.top_p, top_p=args.top_p,
top_k=args.top_k, top_k=args.top_k,
max_tokens=args.max_tokens, max_new_tokens=args.max_new_tokens,
) )
with jsonlines.open(args.output_path, 'w') as f: with jsonlines.open(args.output_path, 'w') as f:
for item in tqdm( for item in tqdm(