[Feat]: Support batch inference in agents (#735)

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
BraisedPork (committed by GitHub), 2024-04-22 17:03:47 +08:00
parent 2db5604288
commit 3be5894976
3 changed files with 133 additions and 86 deletions

File 1 of 3:

@@ -32,7 +32,8 @@ python streaming_inference.py \
     --backend=lmdeploy \  # For HuggingFace models: hf
     --model_path=internlm/internlm2-chat-20b \
     --tp=2 \
-    --temperature=0.0 \
+    --temperature=1.0 \
+    --top_k=1 \
     --dataset=math \
     --output_path=math_lmdeploy.jsonl \
     --do_eval
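Note on the flag change: with top_k=1 only the single highest-probability token can be sampled, so decoding stays effectively greedy even at temperature 1.0 (some samplers reject or special-case a literal temperature of 0.0). A minimal numpy sketch of top-k sampling to illustrate the idea, not the sampler lmdeploy actually implements:

import numpy as np

def sample_token(logits, temperature=1.0, top_k=None):
    # Illustrative top-k sampling; with top_k=1 this reduces to argmax.
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k is not None:
        cutoff = np.sort(logits)[-top_k]          # k-th largest logit
        logits = np.where(logits >= cutoff, logits, -np.inf)
    probs = np.exp(logits - logits.max())         # numerically stable softmax
    probs /= probs.sum()
    return int(np.random.choice(len(logits), p=probs))

print(sample_token([2.0, 0.5, -1.0], temperature=1.0, top_k=1))  # always 0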

File 2 of 3 (same change applied to the second copy of the docs):

@@ -32,7 +32,8 @@ python streaming_inference.py \
     --backend=lmdeploy \  # For HuggingFace models: hf
     --model_path=internlm/internlm2-chat-20b \
     --tp=2 \
-    --temperature=0.0 \
+    --temperature=1.0 \
+    --top_k=1 \
     --dataset=math \
     --output_path=math_lmdeploy.jsonl \
     --do_eval

File 3 of 3 (the streaming_inference.py script invoked above):

@@ -30,7 +30,7 @@ import os
 import re
 import sys
 import traceback
-from math import isclose
+from math import isclose, ceil
 from typing import Union

 import jsonlines
@@ -38,28 +38,38 @@ import numpy as np
 from datasets import load_dataset
 from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
                     Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
-                    get_tool)
+                    IPythonInteractiveManager)
 from pebble import ProcessPool
 from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
 from tqdm import tqdm

+# --------------------- modify the system prompt as needed ---------------------
+# DEFAULT_PROMPT = (
+#     'Integrate step-by-step reasoning and Python code to solve math problems '
+#     'using the following guidelines:\n'
+#     '- Just write jupyter code to solve the problem without giving your thought;\n'
+#     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+#     'units. \n')
 DEFAULT_PROMPT = (
     'Integrate step-by-step reasoning and Python code to solve math problems '
     'using the following guidelines:\n'
-    '- Just write jupyter code to solve the problem without giving your thought;\n'
+    '- Analyze the question and write jupyter code to solve the problem;\n'
     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
     'units. \n')
+# ------------------------------------------------------------------------------


 def parse_args():
     parser = argparse.ArgumentParser(description='Math Code Interpreter')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='lmdeploy',
-                        help='Which inference framework to use.',
-                        choices=['lmdeploy', 'hf'])
+    parser.add_argument(
+        '--backend',
+        type=str,
+        default='lmdeploy',
+        help='Which inference framework to use.',
+        choices=['lmdeploy', 'hf'])
     parser.add_argument(
         '--model_path',
         type=str,
@@ -71,37 +81,52 @@ def parse_args():
         type=str,
         required=True,
         help='Path to save inference results to, should be a `jsonl` file')
-    parser.add_argument('--dataset',
-                        type=str,
-                        default='math',
-                        choices=['gsm8k', 'math'],
-                        help='Dataset for inference')
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        default='math',
+        choices=['gsm8k', 'math'],
+        help='Dataset for inference')
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=100,
+        help='Agent inference batch size')
+    parser.add_argument(
+        '--max_turn',
+        type=int,
+        default=3,
+        help=
+        'Maximum number of interaction rounds between the agent and environment'
+    )
     parser.add_argument(
         '--tp',
         type=int,
         default=1,
         help='Number of tensor parallelism. It may be required in LMDelpoy.')
-    parser.add_argument('--temperature',
-                        type=float,
-                        default=0.1,
-                        help='Temperature in next token prediction')
-    parser.add_argument('--top_p',
-                        type=float,
-                        default=0.8,
-                        help='Parameter for Top-P Sampling.')
-    parser.add_argument('--top_k',
-                        type=int,
-                        default=None,
-                        help='Parameter for Top-K Sampling.')
-    parser.add_argument('--stop_words',
-                        type=str,
-                        default=['<|action_end|>', '<|im_end|>'],
-                        action='append',
-                        help='Stop words')
-    parser.add_argument('--max_new_tokens',
-                        type=int,
-                        default=512,
-                        help='Number of maximum generated tokens.')
+    parser.add_argument(
+        '--temperature',
+        type=float,
+        default=0.1,
+        help='Temperature in next token prediction')
+    parser.add_argument(
+        '--top_p',
+        type=float,
+        default=0.8,
+        help='Parameter for Top-P Sampling.')
+    parser.add_argument(
+        '--top_k', type=int, default=40, help='Parameter for Top-K Sampling.')
+    parser.add_argument(
+        '--stop_words',
+        type=str,
+        default=['<|action_end|>', '<|im_end|>'],
+        action='append',
+        help='Stop words')
+    parser.add_argument(
+        '--max_new_tokens',
+        type=int,
+        default=512,
+        help='Number of maximum generated tokens.')
     parser.add_argument(
         '--do_infer',
         default=True,
@@ -113,18 +138,21 @@ def parse_args():
     #     action='store_false',
     #     help='Disable the inference.'
     # )
-    parser.add_argument('--do_eval',
-                        default=False,
-                        action='store_true',
-                        help='Whether to evaluate the inference results.')
-    parser.add_argument('--overwrite',
-                        default=False,
-                        action='store_true',
-                        help='Whether to overwrite the existing result file')
-    parser.add_argument('--debug',
-                        default=False,
-                        action='store_true',
-                        help='Only infer the first 50 samples')
+    parser.add_argument(
+        '--do_eval',
+        default=False,
+        action='store_true',
+        help='Whether to evaluate the inference results.')
+    parser.add_argument(
+        '--overwrite',
+        default=False,
+        action='store_true',
+        help='Whether to overwrite the existing result file')
+    # parser.add_argument(
+    #     '--debug',
+    #     default=False,
+    #     action='store_true',
+    #     help='Only infer the first 50 samples')
     return parser.parse_args()
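The new --batch_size flag drives the chunked loop added to predict further below. A small sketch of the slicing arithmetic, mirroring the ceil-based indexing in the diff (iter_batches is a hypothetical helper; dataset is the HuggingFace datasets.Dataset loaded in predict):

from math import ceil

def iter_batches(dataset, batch_size=100):
    # Contiguous chunks of the dataset; the final chunk may be shorter.
    for i in range(ceil(len(dataset) / batch_size)):
        yield dataset.select(
            range(i * batch_size, min((i + 1) * batch_size, len(dataset))))

With the 5,000 MATH test problems and the default batch_size of 100, this yields 50 batches.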
@@ -473,9 +501,8 @@ def symbolic_equal_process(a, b, output_queue):
 def call_with_timeout(func, *args, timeout=1, **kwargs):
     output_queue = multiprocessing.Queue()
     process_args = args + (output_queue, )
-    process = multiprocessing.Process(target=func,
-                                      args=process_args,
-                                      kwargs=kwargs)
+    process = multiprocessing.Process(
+        target=func, args=process_args, kwargs=kwargs)
     process.start()
     process.join(timeout)
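call_with_timeout runs the check in a throwaway process so a sympy comparison that hangs cannot stall evaluation; note that the output queue is appended to the positional args automatically. A hypothetical call site, matching the (a, b, output_queue) signature in the hunk header (the timeout return value is not shown in this excerpt, so a falsy result is treated as "could not verify"):

# Hypothetical: compare two parsed expressions with a 1-second budget.
equal = call_with_timeout(symbolic_equal_process, expr_a, expr_b, timeout=1)
if not equal:
    pass  # timed out or unequal; caller decides how to score this sample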
@@ -487,25 +514,33 @@ def call_with_timeout(func, *args, timeout=1, **kwargs):
     return output_queue.get()


-def init_agent(backend: str, model_path: str, tp: int, **kwargs):
+def init_agent(backend: str, max_turn: int, model_path: str, tp: int,
+               **kwargs):
     if backend == 'lmdeploy':
-        model = LMDeployPipeline(path=model_path,
-                                 meta_template=INTERNLM2_META,
-                                 tp=tp,
-                                 **kwargs)
+        from lmdeploy import TurbomindEngineConfig
+        model = LMDeployPipeline(
+            path=model_path,
+            model_name='internlm2-chat',
+            meta_template=INTERNLM2_META,
+            pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)),
+            **kwargs)
     elif backend == 'hf':
-        model = HFTransformer(path=model_path,
-                              meta_template=INTERNLM2_META,
-                              **kwargs)
+        model = HFTransformer(
+            path=model_path, meta_template=INTERNLM2_META, **kwargs)
     else:
         raise NotImplementedError
-    agent = Internlm2Agent(llm=model,
-                           protocol=Internlm2Protocol(
-                               meta_prompt=None,
-                               interpreter_prompt=DEFAULT_PROMPT),
-                           interpreter_executor=ActionExecutor(
-                               actions=[get_tool('IPythonInteractive')]))
+    agent = Internlm2Agent(
+        llm=model,
+        protocol=Internlm2Protocol(
+            meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT),
+        interpreter_executor=ActionExecutor(actions=[
+            IPythonInteractiveManager(
+                max_workers=200,
+                ci_lock=os.path.join(
+                    os.path.dirname(__file__), '.ipython.lock'))
+        ]),
+        max_turn=max_turn)
     return agent
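Besides threading max_turn through, the lmdeploy branch now pins model_name='internlm2-chat' and passes tensor parallelism via TurbomindEngineConfig, and the single IPython session is replaced by an IPythonInteractiveManager pooling up to 200 interpreter sessions behind a lock file, which is what makes batched rollouts possible. A sketch of calling the updated helper with the values from the README example above (all values are taken from the docs or the argparse defaults in this diff):

agent = init_agent(
    backend='lmdeploy',
    max_turn=3,                                # agent/interpreter rounds
    model_path='internlm/internlm2-chat-20b',
    tp=2,                                      # tensor parallel degree
    temperature=1.0,
    top_k=1,                                   # effectively greedy decoding
    max_new_tokens=512,
)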
@@ -522,8 +557,8 @@ def predict(args):
             d['pred'], d['steps'], d['error'] = [], [], None
             return d

-        dataset = load_dataset('gsm8k', 'main',
-                               split='test').map(process, True)
+        dataset = load_dataset(
+            'gsm8k', 'main', split='test').map(process, True)
     elif args.dataset == 'math':
@@ -543,14 +578,15 @@ def predict(args):
             d['error'] = None
             return d

-        dataset = load_dataset('lighteval/MATH',
-                               split='test').map(process, True)
+        dataset = load_dataset(
+            'lighteval/MATH', split='test').map(process, True)
     else:
         raise NotImplementedError
     agent = init_agent(
         backend=args.backend,
+        max_turn=args.max_turn,
         model_path=args.model_path,
         tp=args.tp,
         temperature=args.temperature,
@@ -559,26 +595,35 @@ def predict(args):
         top_k=args.top_k,
         max_new_tokens=args.max_new_tokens,
     )
-    with jsonlines.open(args.output_path, 'w') as f:
-        for item in tqdm(
-                dataset if not args.debug else dataset.select(range(50))):
+    num_batches = ceil(len(dataset) / args.batch_size)
+    with jsonlines.open(args.output_path, 'w', flush=True) as f:
+        for i in tqdm(range(num_batches)):
+            batch = dataset.select(
+                range(i * args.batch_size,
+                      min((i + 1) * args.batch_size, len(dataset))))
+            # for item in tqdm(
+            #         dataset if not args.debug else dataset.select(range(50))):
             try:
-                ret = agent.chat(item['query'])
-                item['steps'] = ret.inner_steps
-                lang = [
-                    step for step in item['steps']
-                    if step['role'] == 'language'
-                ]
-                item['pred'].append('😭' if not lang else
-                                    extract_answer(lang[-1]['content']) or '😭')
-                agent._interpreter_executor.actions[
-                    'IPythonInteractive'].reset()
+                rets = agent.batch_chat(batch['query'])
+                for item, ret in zip(batch, rets):
+                    item['steps'] = ret.inner_steps
+                    lang = [
+                        step for step in item['steps']
+                        if step['role'] == 'language'
+                    ]
+                    item['pred'].append('😭' if not lang else extract_answer(
+                        lang[-1]['content']) or '😭')
+                    f.write(item)
             except Exception as e:
                 err = str(traceback.format_exc())
-                print(f'Error processing index {item["idx"]}: {e}\n{err}')
-                item['error'] = err
-                f.write(item)
+                print(f'Processing batch data error: {e}\n{err}')
+                for item in batch:
+                    item['error'] = err
+                    f.write(item)
+            finally:
+                agent._interpreter_executor.actions[
+                    'IPythonInteractiveManager'].reset()


 def evaluate(args):
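The rewritten loop swaps per-sample agent.chat calls for one agent.batch_chat per chunk, writes each finished record as soon as it is available (jsonlines opened with flush=True), and resets the interpreter pool in finally so IPython sessions do not leak across batches. One behavioral consequence: a single exception now stamps the same traceback onto every item in the batch. If finer-grained recovery were wanted, a hypothetical wrapper could retry item by item (batch_chat and chat are the methods used in this diff; the helper itself is not part of the PR):

def batch_chat_with_fallback(agent, queries):
    # Hypothetical: fall back to per-query chat when the whole batch fails.
    try:
        return agent.batch_chat(queries)
    except Exception:
        results = []
        for query in queries:
            try:
                results.append(agent.chat(query))
            except Exception:
                results.append(None)  # caller records the error for this item
        return results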
@@ -606,7 +651,7 @@ def evaluate(args):
                 timeout_cnt += 1
             except Exception as error:
                 print(error.__traceback__)
-                sys.exit()
+                # sys.exit()
             progress_bar.update(1)

     idx = 0
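Commenting out sys.exit() makes evaluation best-effort: a sample whose check raises is logged and skipped rather than aborting the whole run. A self-contained sketch of the surrounding pattern with pebble, assuming a hypothetical check function (score_all is not in the PR; the loop's variable names follow the visible context):

from concurrent.futures import TimeoutError

from pebble import ProcessPool
from tqdm import tqdm

def score_all(samples, check, timeout=10):
    # Best-effort scoring: count timeouts, log other failures, never abort.
    timeout_cnt, results = 0, []
    with ProcessPool() as pool:
        futures = [
            pool.schedule(check, args=(s, ), timeout=timeout) for s in samples
        ]
        progress_bar = tqdm(total=len(futures))
        for future in futures:
            try:
                results.append(future.result())
            except TimeoutError:
                timeout_cnt += 1  # sample exceeded its time budget
            except Exception as error:
                print(error.__traceback__)
                # intentionally no sys.exit(): keep scoring the rest
            progress_bar.update(1)
    return results, timeout_cnt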