Mirror of https://github.com/InternLM/InternLM

[Feat]: Support batch inference in agents (#735)

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
Branch: pull/740/head
Parent: 2db5604288
Commit: 3be5894976
```diff
@@ -32,7 +32,8 @@ python streaming_inference.py \
     --backend=lmdeploy \  # For HuggingFace models: hf
     --model_path=internlm/internlm2-chat-20b \
     --tp=2 \
-    --temperature=0.0 \
+    --temperature=1.0 \
+    --top_k=1 \
     --dataset=math \
     --output_path=math_lmdeploy.jsonl \
     --do_eval
```
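With the two flag changes above applied (and assuming the surrounding README lines are unchanged, as the hunk context suggests), the documented invocation becomes:

```bash
python streaming_inference.py \
    --backend=lmdeploy \  # For HuggingFace models: hf
    --model_path=internlm/internlm2-chat-20b \
    --tp=2 \
    --temperature=1.0 \
    --top_k=1 \
    --dataset=math \
    --output_path=math_lmdeploy.jsonl \
    --do_eval
```

Note that `--temperature=1.0` together with `--top_k=1` is effectively greedy decoding: when only the single highest-probability token is kept, the temperature no longer affects the choice.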
```diff
@@ -30,7 +30,7 @@ import os
 import re
 import sys
 import traceback
-from math import isclose
+from math import isclose, ceil
 from typing import Union
 
 import jsonlines
```
```diff
@@ -38,28 +38,38 @@ import numpy as np
 from datasets import load_dataset
 from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
                     Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
-                    get_tool)
+                    IPythonInteractiveManager)
 from pebble import ProcessPool
 from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
 from tqdm import tqdm
 
+# --------------------- modify the system prompt as needed ---------------------
+# DEFAULT_PROMPT = (
+#     'Integrate step-by-step reasoning and Python code to solve math problems '
+#     'using the following guidelines:\n'
+#     '- Just write jupyter code to solve the problem without giving your thought;\n'
+#     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+#     'units. \n')
+
 DEFAULT_PROMPT = (
     'Integrate step-by-step reasoning and Python code to solve math problems '
     'using the following guidelines:\n'
-    '- Just write jupyter code to solve the problem without giving your thought;\n'
+    '- Analyze the question and write jupyter code to solve the problem;\n'
     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
     'units. \n')
+# ------------------------------------------------------------------------------
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Math Code Interpreter')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='lmdeploy',
-                        help='Which inference framework to use.',
-                        choices=['lmdeploy', 'hf'])
+    parser.add_argument(
+        '--backend',
+        type=str,
+        default='lmdeploy',
+        help='Which inference framework to use.',
+        choices=['lmdeploy', 'hf'])
     parser.add_argument(
         '--model_path',
         type=str,
```
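The revised `DEFAULT_PROMPT` still asks the model to put the final result in `\boxed{}`, which is what the script's answer extraction keys on (see `extract_answer` in the batch loop further down). As a hypothetical illustration only — the actual `extract_answer` helper is not shown in this diff — pulling the last boxed expression out of a completion can look like:

```python
import re
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    # Illustrative only: grab the contents of the last \boxed{...} span,
    # tolerating one level of nested braces such as \frac{1}{2}.
    matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', text)
    return matches[-1] if matches else None


assert extract_boxed(r'The answer is \boxed{\frac{1}{2}}.') == r'\frac{1}{2}'
```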
```diff
@@ -71,37 +81,52 @@ def parse_args():
         type=str,
         required=True,
         help='Path to save inference results to, should be a `jsonl` file')
-    parser.add_argument('--dataset',
-                        type=str,
-                        default='math',
-                        choices=['gsm8k', 'math'],
-                        help='Dataset for inference')
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        default='math',
+        choices=['gsm8k', 'math'],
+        help='Dataset for inference')
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=100,
+        help='Agent inference batch size')
+    parser.add_argument(
+        '--max_turn',
+        type=int,
+        default=3,
+        help=
+        'Maximum number of interaction rounds between the agent and environment'
+    )
     parser.add_argument(
         '--tp',
         type=int,
         default=1,
         help='Number of tensor parallelism. It may be required in LMDelpoy.')
-    parser.add_argument('--temperature',
-                        type=float,
-                        default=0.1,
-                        help='Temperature in next token prediction')
-    parser.add_argument('--top_p',
-                        type=float,
-                        default=0.8,
-                        help='Parameter for Top-P Sampling.')
-    parser.add_argument('--top_k',
-                        type=int,
-                        default=None,
-                        help='Parameter for Top-K Sampling.')
-    parser.add_argument('--stop_words',
-                        type=str,
-                        default=['<|action_end|>', '<|im_end|>'],
-                        action='append',
-                        help='Stop words')
-    parser.add_argument('--max_new_tokens',
-                        type=int,
-                        default=512,
-                        help='Number of maximum generated tokens.')
+    parser.add_argument(
+        '--temperature',
+        type=float,
+        default=0.1,
+        help='Temperature in next token prediction')
+    parser.add_argument(
+        '--top_p',
+        type=float,
+        default=0.8,
+        help='Parameter for Top-P Sampling.')
+    parser.add_argument(
+        '--top_k', type=int, default=40, help='Parameter for Top-K Sampling.')
+    parser.add_argument(
+        '--stop_words',
+        type=str,
+        default=['<|action_end|>', '<|im_end|>'],
+        action='append',
+        help='Stop words')
+    parser.add_argument(
+        '--max_new_tokens',
+        type=int,
+        default=512,
+        help='Number of maximum generated tokens.')
     parser.add_argument(
         '--do_infer',
         default=True,
```
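The `--top_k` default moves from `None` to 40, while the README example pins `--top_k=1`. A toy sketch of top-k sampling (an assumption for illustration, not LMDeploy's or HF's implementation) shows why `k=1` reduces to plain argmax whatever the temperature:

```python
import numpy as np


def sample_top_k(logits, k, temperature=1.0, seed=0):
    # Scale logits by temperature, keep the k largest, renormalize, sample.
    rng = np.random.default_rng(seed)
    scaled = np.asarray(logits, dtype=float) / temperature
    top = np.argsort(scaled)[-k:]          # indices of the k best tokens
    probs = np.exp(scaled[top] - scaled[top].max())
    probs /= probs.sum()
    return int(rng.choice(top, p=probs))


logits = [2.0, 5.0, 3.5]
# With k=1 only the argmax survives, so temperature is irrelevant.
assert sample_top_k(logits, k=1, temperature=1.0) == 1
assert sample_top_k(logits, k=1, temperature=100.0) == 1
```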
```diff
@@ -113,18 +138,21 @@ def parse_args():
     #     action='store_false',
     #     help='Disable the inference.'
     # )
-    parser.add_argument('--do_eval',
-                        default=False,
-                        action='store_true',
-                        help='Whether to evaluate the inference results.')
-    parser.add_argument('--overwrite',
-                        default=False,
-                        action='store_true',
-                        help='Whether to overwrite the existing result file')
-    parser.add_argument('--debug',
-                        default=False,
-                        action='store_true',
-                        help='Only infer the first 50 samples')
+    parser.add_argument(
+        '--do_eval',
+        default=False,
+        action='store_true',
+        help='Whether to evaluate the inference results.')
+    parser.add_argument(
+        '--overwrite',
+        default=False,
+        action='store_true',
+        help='Whether to overwrite the existing result file')
+    # parser.add_argument(
+    #     '--debug',
+    #     default=False,
+    #     action='store_true',
+    #     help='Only infer the first 50 samples')
     return parser.parse_args()
 
 
```
```diff
@@ -473,9 +501,8 @@ def symbolic_equal_process(a, b, output_queue):
 def call_with_timeout(func, *args, timeout=1, **kwargs):
     output_queue = multiprocessing.Queue()
     process_args = args + (output_queue, )
-    process = multiprocessing.Process(target=func,
-                                      args=process_args,
-                                      kwargs=kwargs)
+    process = multiprocessing.Process(
+        target=func, args=process_args, kwargs=kwargs)
     process.start()
     process.join(timeout)
 
```
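The reformatted helper keeps the usual run-in-a-subprocess-with-deadline pattern: the worker writes its result to a `multiprocessing.Queue`, and the parent joins with a timeout. The hunk only shows the happy path; a self-contained sketch of the full pattern (the timeout branch below is an assumption, not shown in this diff) is:

```python
import multiprocessing


def _square(x, output_queue):
    # Workers report results through the queue instead of a return value.
    output_queue.put(x * x)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue, ), kwargs=kwargs)
    process.start()
    process.join(timeout)
    if process.is_alive():      # deadline hit: kill the child, signal failure
        process.terminate()
        process.join()
        return None
    return output_queue.get()


if __name__ == '__main__':
    print(call_with_timeout(_square, 7))  # 49
```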
```diff
@@ -487,25 +514,33 @@ def call_with_timeout(func, *args, timeout=1, **kwargs):
     return output_queue.get()
 
 
-def init_agent(backend: str, model_path: str, tp: int, **kwargs):
+def init_agent(backend: str, max_turn: int, model_path: str, tp: int,
+               **kwargs):
     if backend == 'lmdeploy':
-        model = LMDeployPipeline(path=model_path,
-                                 meta_template=INTERNLM2_META,
-                                 tp=tp,
-                                 **kwargs)
+        from lmdeploy import TurbomindEngineConfig
+        model = LMDeployPipeline(
+            path=model_path,
+            model_name='internlm2-chat',
+            meta_template=INTERNLM2_META,
+            pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)),
+            **kwargs)
     elif backend == 'hf':
-        model = HFTransformer(path=model_path,
-                              meta_template=INTERNLM2_META,
-                              **kwargs)
+        model = HFTransformer(
+            path=model_path, meta_template=INTERNLM2_META, **kwargs)
     else:
         raise NotImplementedError
 
-    agent = Internlm2Agent(llm=model,
-                           protocol=Internlm2Protocol(
-                               meta_prompt=None,
-                               interpreter_prompt=DEFAULT_PROMPT),
-                           interpreter_executor=ActionExecutor(
-                               actions=[get_tool('IPythonInteractive')]))
+    agent = Internlm2Agent(
+        llm=model,
+        protocol=Internlm2Protocol(
+            meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT),
+        interpreter_executor=ActionExecutor(actions=[
+            IPythonInteractiveManager(
+                max_workers=200,
+                ci_lock=os.path.join(
+                    os.path.dirname(__file__), '.ipython.lock'))
+        ]),
+        max_turn=max_turn)
     return agent
 
 
```
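A minimal usage sketch for the new signature, mirroring the call `predict()` makes further down in this diff (the import path is an assumption, and model weights plus GPUs must be available locally):

```python
# Hypothetical driver: wire init_agent() to batch_chat(), the entry point
# this commit introduces for batched agent inference.
from streaming_inference import init_agent

agent = init_agent(
    backend='lmdeploy',
    max_turn=3,
    model_path='internlm/internlm2-chat-20b',
    tp=2,
    temperature=1.0,
    top_k=1,
    max_new_tokens=512,
)
# batch_chat() takes a list of queries and returns one result per query,
# each exposing the agent's trajectory via `inner_steps`.
rets = agent.batch_chat(['What is 2 + 2?', 'Factor x**2 - 1.'])
for ret in rets:
    print(ret.inner_steps)
```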
```diff
@@ -522,8 +557,8 @@ def predict(args):
             d['pred'], d['steps'], d['error'] = [], [], None
             return d
 
-        dataset = load_dataset('gsm8k', 'main',
-                               split='test').map(process, True)
+        dataset = load_dataset(
+            'gsm8k', 'main', split='test').map(process, True)
 
     elif args.dataset == 'math':
 
```
```diff
@@ -543,14 +578,15 @@ def predict(args):
             d['error'] = None
             return d
 
-        dataset = load_dataset('lighteval/MATH',
-                               split='test').map(process, True)
+        dataset = load_dataset(
+            'lighteval/MATH', split='test').map(process, True)
 
     else:
         raise NotImplementedError
 
     agent = init_agent(
         backend=args.backend,
+        max_turn=args.max_turn,
         model_path=args.model_path,
         tp=args.tp,
         temperature=args.temperature,
```
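For context, the second positional argument of `datasets.Dataset.map` is `with_indices`, so `map(process, True)` calls `process(example, index)`. A sketch of the preprocessing this relies on (the source field name for `lighteval/MATH` is an assumption; the diff only shows the tail of `process`):

```python
from datasets import load_dataset


def process(d, i):
    d['idx'] = i                   # keep the index so failures are traceable
    d['query'] = d['problem']      # assumed source field for the question
    d['pred'], d['steps'], d['error'] = [], [], None
    return d


# map(process, True) is shorthand for map(process, with_indices=True).
dataset = load_dataset('lighteval/MATH', split='test').map(process, True)
```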
```diff
@@ -559,26 +595,35 @@ def predict(args):
         top_k=args.top_k,
         max_new_tokens=args.max_new_tokens,
     )
-    with jsonlines.open(args.output_path, 'w') as f:
-        for item in tqdm(
-                dataset if not args.debug else dataset.select(range(50))):
+    num_batches = ceil(len(dataset) / args.batch_size)
+    with jsonlines.open(args.output_path, 'w', flush=True) as f:
+        for i in tqdm(range(num_batches)):
+            batch = dataset.select(
+                range(i * args.batch_size,
+                      min((i + 1) * args.batch_size, len(dataset))))
+            # for item in tqdm(
+            #         dataset if not args.debug else dataset.select(range(50))):
             try:
-                ret = agent.chat(item['query'])
-                item['steps'] = ret.inner_steps
-                lang = [
-                    step for step in item['steps']
-                    if step['role'] == 'language'
-                ]
-                item['pred'].append('😭' if not lang else
-                                    extract_answer(lang[-1]['content']) or '😭')
-                agent._interpreter_executor.actions[
-                    'IPythonInteractive'].reset()
-                f.write(item)
+                rets = agent.batch_chat(batch['query'])
+                for item, ret in zip(batch, rets):
+                    item['steps'] = ret.inner_steps
+
+                    lang = [
+                        step for step in item['steps']
+                        if step['role'] == 'language'
+                    ]
+                    item['pred'].append('😭' if not lang else extract_answer(
+                        lang[-1]['content']) or '😭')
+                    f.write(item)
             except Exception as e:
                 err = str(traceback.format_exc())
-                print(f'Error processing index {item["idx"]}: {e}\n{err}')
-                item['error'] = err
-                f.write(item)
+                print(f'Processing batch data error: {e}\n{err}')
+                for item in batch:
+                    item['error'] = err
+                    f.write(item)
+            finally:
+                agent._interpreter_executor.actions[
+                    'IPythonInteractiveManager'].reset()
 
 
 def evaluate(args):
```
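The batching arithmetic above is the heart of the change; isolated, it is just fixed-size index windows with a truncated tail (a sketch, not the script itself):

```python
from math import ceil


def iter_batches(n, batch_size):
    # Yield index windows [i*b, min((i+1)*b, n)) exactly as predict() does.
    num_batches = ceil(n / batch_size)
    for i in range(num_batches):
        yield range(i * batch_size, min((i + 1) * batch_size, n))


# 250 samples at batch_size=100 -> windows of 100, 100 and 50.
assert [len(r) for r in iter_batches(250, 100)] == [100, 100, 50]
```

Two other details in this hunk are worth spelling out: `jsonlines.open(..., flush=True)` flushes after every write so partial results survive a crash, and the executor reset moves into a `finally:` block so the IPython workers are recycled even when a whole batch fails.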
```diff
@@ -606,7 +651,7 @@ def evaluate(args):
                 timeout_cnt += 1
             except Exception as error:
                 print(error.__traceback__)
-                sys.exit()
+                # sys.exit()
             progress_bar.update(1)
 
     idx = 0
```
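Commenting out `sys.exit()` means one malformed sample now only logs its traceback instead of aborting the whole evaluation. Judging from the `ProcessPool` import and the `timeout_cnt` counter in context, the enclosing loop follows pebble's standard map-with-timeout pattern; a hedged sketch of that pattern (the `check` function is a stand-in, not the script's symbolic comparison):

```python
from concurrent.futures import TimeoutError

from pebble import ProcessPool


def check(pair):
    a, b = pair
    return a == b   # stand-in for the real symbolic answer comparison


pairs = [(1, 1), (2, 3)]
timeout_cnt = 0
with ProcessPool() as pool:
    iterator = pool.map(check, pairs, timeout=5).result()
    while True:
        try:
            print(next(iterator))
        except StopIteration:
            break
        except TimeoutError:
            timeout_cnt += 1            # count the timeout and move on
        except Exception as error:
            print(error.__traceback__)  # log and continue (no sys.exit())
```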