[Feat]: Support batch inference in agents (#735)

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
pull/740/head
BraisedPork 2024-04-22 17:03:47 +08:00 committed by GitHub
parent 2db5604288
commit 3be5894976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 133 additions and 86 deletions

View File

@ -32,7 +32,8 @@ python streaming_inference.py \
--backend=lmdeploy \ # For HuggingFace models: hf
--model_path=internlm/internlm2-chat-20b \
--tp=2 \
--temperature=0.0 \
--temperature=1.0 \
--top_k=1 \
--dataset=math \
--output_path=math_lmdeploy.jsonl \
--do_eval

View File

@ -32,7 +32,8 @@ python streaming_inference.py \
--backend=lmdeploy \ # For HuggingFace models: hf
--model_path=internlm/internlm2-chat-20b \
--tp=2 \
--temperature=0.0 \
--temperature=1.0 \
--top_k=1 \
--dataset=math \
--output_path=math_lmdeploy.jsonl \
--do_eval

View File

@ -30,7 +30,7 @@ import os
import re
import sys
import traceback
from math import isclose
from math import isclose, ceil
from typing import Union
import jsonlines
@ -38,24 +38,34 @@ import numpy as np
from datasets import load_dataset
from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
get_tool)
IPythonInteractiveManager)
from pebble import ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm
# --------------------- modify the system prompt as needed ---------------------
# DEFAULT_PROMPT = (
# 'Integrate step-by-step reasoning and Python code to solve math problems '
# 'using the following guidelines:\n'
# '- Just write jupyter code to solve the problem without giving your thought;\n'
# r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
# 'units. \n')
DEFAULT_PROMPT = (
'Integrate step-by-step reasoning and Python code to solve math problems '
'using the following guidelines:\n'
'- Just write jupyter code to solve the problem without giving your thought;\n'
'- Analyze the question and write jupyter code to solve the problem;\n'
r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
'units. \n')
# ------------------------------------------------------------------------------
def parse_args():
parser = argparse.ArgumentParser(description='Math Code Interpreter')
parser.add_argument('--backend',
parser.add_argument(
'--backend',
type=str,
default='lmdeploy',
help='Which inference framework to use.',
@ -71,34 +81,49 @@ def parse_args():
type=str,
required=True,
help='Path to save inference results to, should be a `jsonl` file')
parser.add_argument('--dataset',
parser.add_argument(
'--dataset',
type=str,
default='math',
choices=['gsm8k', 'math'],
help='Dataset for inference')
parser.add_argument(
'--batch_size',
type=int,
default=100,
help='Agent inference batch size')
parser.add_argument(
'--max_turn',
type=int,
default=3,
help=
'Maximum number of interaction rounds between the agent and environment'
)
parser.add_argument(
'--tp',
type=int,
default=1,
help='Number of tensor parallelism. It may be required in LMDelpoy.')
parser.add_argument('--temperature',
parser.add_argument(
'--temperature',
type=float,
default=0.1,
help='Temperature in next token prediction')
parser.add_argument('--top_p',
parser.add_argument(
'--top_p',
type=float,
default=0.8,
help='Parameter for Top-P Sampling.')
parser.add_argument('--top_k',
type=int,
default=None,
help='Parameter for Top-K Sampling.')
parser.add_argument('--stop_words',
parser.add_argument(
'--top_k', type=int, default=40, help='Parameter for Top-K Sampling.')
parser.add_argument(
'--stop_words',
type=str,
default=['<|action_end|>', '<|im_end|>'],
action='append',
help='Stop words')
parser.add_argument('--max_new_tokens',
parser.add_argument(
'--max_new_tokens',
type=int,
default=512,
help='Number of maximum generated tokens.')
@ -113,18 +138,21 @@ def parse_args():
# action='store_false',
# help='Disable the inference.'
# )
parser.add_argument('--do_eval',
parser.add_argument(
'--do_eval',
default=False,
action='store_true',
help='Whether to evaluate the inference results.')
parser.add_argument('--overwrite',
parser.add_argument(
'--overwrite',
default=False,
action='store_true',
help='Whether to overwrite the existing result file')
parser.add_argument('--debug',
default=False,
action='store_true',
help='Only infer the first 50 samples')
# parser.add_argument(
# '--debug',
# default=False,
# action='store_true',
# help='Only infer the first 50 samples')
return parser.parse_args()
@ -473,9 +501,8 @@ def symbolic_equal_process(a, b, output_queue):
def call_with_timeout(func, *args, timeout=1, **kwargs):
output_queue = multiprocessing.Queue()
process_args = args + (output_queue, )
process = multiprocessing.Process(target=func,
args=process_args,
kwargs=kwargs)
process = multiprocessing.Process(
target=func, args=process_args, kwargs=kwargs)
process.start()
process.join(timeout)
@ -487,25 +514,33 @@ def call_with_timeout(func, *args, timeout=1, **kwargs):
return output_queue.get()
def init_agent(backend: str, model_path: str, tp: int, **kwargs):
def init_agent(backend: str, max_turn: int, model_path: str, tp: int,
**kwargs):
if backend == 'lmdeploy':
model = LMDeployPipeline(path=model_path,
from lmdeploy import TurbomindEngineConfig
model = LMDeployPipeline(
path=model_path,
model_name='internlm2-chat',
meta_template=INTERNLM2_META,
tp=tp,
pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)),
**kwargs)
elif backend == 'hf':
model = HFTransformer(path=model_path,
meta_template=INTERNLM2_META,
**kwargs)
model = HFTransformer(
path=model_path, meta_template=INTERNLM2_META, **kwargs)
else:
raise NotImplementedError
agent = Internlm2Agent(llm=model,
agent = Internlm2Agent(
llm=model,
protocol=Internlm2Protocol(
meta_prompt=None,
interpreter_prompt=DEFAULT_PROMPT),
interpreter_executor=ActionExecutor(
actions=[get_tool('IPythonInteractive')]))
meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT),
interpreter_executor=ActionExecutor(actions=[
IPythonInteractiveManager(
max_workers=200,
ci_lock=os.path.join(
os.path.dirname(__file__), '.ipython.lock'))
]),
max_turn=max_turn)
return agent
@ -522,8 +557,8 @@ def predict(args):
d['pred'], d['steps'], d['error'] = [], [], None
return d
dataset = load_dataset('gsm8k', 'main',
split='test').map(process, True)
dataset = load_dataset(
'gsm8k', 'main', split='test').map(process, True)
elif args.dataset == 'math':
@ -543,14 +578,15 @@ def predict(args):
d['error'] = None
return d
dataset = load_dataset('lighteval/MATH',
split='test').map(process, True)
dataset = load_dataset(
'lighteval/MATH', split='test').map(process, True)
else:
raise NotImplementedError
agent = init_agent(
backend=args.backend,
max_turn=args.max_turn,
model_path=args.model_path,
tp=args.tp,
temperature=args.temperature,
@ -559,26 +595,35 @@ def predict(args):
top_k=args.top_k,
max_new_tokens=args.max_new_tokens,
)
with jsonlines.open(args.output_path, 'w') as f:
for item in tqdm(
dataset if not args.debug else dataset.select(range(50))):
num_batches = ceil(len(dataset) / args.batch_size)
with jsonlines.open(args.output_path, 'w', flush=True) as f:
for i in tqdm(range(num_batches)):
batch = dataset.select(
range(i * args.batch_size,
min((i + 1) * args.batch_size, len(dataset))))
# for item in tqdm(
# dataset if not args.debug else dataset.select(range(50))):
try:
ret = agent.chat(item['query'])
rets = agent.batch_chat(batch['query'])
for item, ret in zip(batch, rets):
item['steps'] = ret.inner_steps
lang = [
step for step in item['steps']
if step['role'] == 'language'
]
item['pred'].append('😭' if not lang else
extract_answer(lang[-1]['content']) or '😭')
agent._interpreter_executor.actions[
'IPythonInteractive'].reset()
item['pred'].append('😭' if not lang else extract_answer(
lang[-1]['content']) or '😭')
f.write(item)
except Exception as e:
err = str(traceback.format_exc())
print(f'Error processing index {item["idx"]}: {e}\n{err}')
print(f'Processing batch data error: {e}\n{err}')
for item in batch:
item['error'] = err
f.write(item)
finally:
agent._interpreter_executor.actions[
'IPythonInteractiveManager'].reset()
def evaluate(args):
@ -606,7 +651,7 @@ def evaluate(args):
timeout_cnt += 1
except Exception as error:
print(error.__traceback__)
sys.exit()
# sys.exit()
progress_bar.update(1)
idx = 0