mirror of https://github.com/InternLM/InternLM
[Feat]: Support batch inference in agents (#735)
Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
parent: 2db5604288
commit: 3be5894976
@@ -32,7 +32,8 @@ python streaming_inference.py \
   --backend=lmdeploy \  # For HuggingFace models: hf
   --model_path=internlm/internlm2-chat-20b \
   --tp=2 \
-  --temperature=0.0 \
+  --temperature=1.0 \
+  --top_k=1 \
   --dataset=math \
   --output_path=math_lmdeploy.jsonl \
   --do_eval

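The example command now pins decoding to the single most likely token. A quick standalone check (not from the repository) showing why `--top_k=1` makes the choice deterministic regardless of the temperature value:

# Standalone illustration (not part of the repository): with top_k=1 the
# single highest-probability token is always selected, so temperature no
# longer changes the outcome.
import numpy as np

def sample_next_token(logits, temperature=1.0, top_k=1, rng=None):
    """Temperature + top-k sampling over a logit vector."""
    rng = rng or np.random.default_rng(0)
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    top_ids = np.argsort(scaled)[-top_k:]          # keep only the top-k logits
    probs = np.exp(scaled[top_ids] - scaled[top_ids].max())
    probs /= probs.sum()
    return int(top_ids[rng.choice(len(top_ids), p=probs)])

logits = [0.3, 2.1, -1.0, 1.7]
assert sample_next_token(logits, temperature=1.0, top_k=1) == 1
assert sample_next_token(logits, temperature=0.7, top_k=1) == 1
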
@@ -30,7 +30,7 @@ import os
 import re
 import sys
 import traceback
-from math import isclose
+from math import isclose, ceil
 from typing import Union
 
 import jsonlines

@@ -38,28 +38,38 @@ import numpy as np
 from datasets import load_dataset
 from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
                     Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
-                    get_tool)
+                    IPythonInteractiveManager)
 from pebble import ProcessPool
 from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
 from tqdm import tqdm
 
+# --------------------- modify the system prompt as needed ---------------------
+# DEFAULT_PROMPT = (
+#     'Integrate step-by-step reasoning and Python code to solve math problems '
+#     'using the following guidelines:\n'
+#     '- Just write jupyter code to solve the problem without giving your thought;\n'
+#     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+#     'units. \n')
+
 DEFAULT_PROMPT = (
     'Integrate step-by-step reasoning and Python code to solve math problems '
     'using the following guidelines:\n'
-    '- Just write jupyter code to solve the problem without giving your thought;\n'
+    '- Analyze the question and write jupyter code to solve the problem;\n'
     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
     'units. \n')
+# ------------------------------------------------------------------------------
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Math Code Interpreter')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='lmdeploy',
-                        help='Which inference framework to use.',
-                        choices=['lmdeploy', 'hf'])
+    parser.add_argument(
+        '--backend',
+        type=str,
+        default='lmdeploy',
+        help='Which inference framework to use.',
+        choices=['lmdeploy', 'hf'])
     parser.add_argument(
         '--model_path',
         type=str,

@@ -71,37 +81,52 @@ def parse_args():
         type=str,
         required=True,
         help='Path to save inference results to, should be a `jsonl` file')
-    parser.add_argument('--dataset',
-                        type=str,
-                        default='math',
-                        choices=['gsm8k', 'math'],
-                        help='Dataset for inference')
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        default='math',
+        choices=['gsm8k', 'math'],
+        help='Dataset for inference')
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=100,
+        help='Agent inference batch size')
+    parser.add_argument(
+        '--max_turn',
+        type=int,
+        default=3,
+        help=
+        'Maximum number of interaction rounds between the agent and environment'
+    )
     parser.add_argument(
         '--tp',
         type=int,
         default=1,
         help='Number of tensor parallelism. It may be required in LMDelpoy.')
-    parser.add_argument('--temperature',
-                        type=float,
-                        default=0.1,
-                        help='Temperature in next token prediction')
-    parser.add_argument('--top_p',
-                        type=float,
-                        default=0.8,
-                        help='Parameter for Top-P Sampling.')
-    parser.add_argument('--top_k',
-                        type=int,
-                        default=None,
-                        help='Parameter for Top-K Sampling.')
-    parser.add_argument('--stop_words',
-                        type=str,
-                        default=['<|action_end|>', '<|im_end|>'],
-                        action='append',
-                        help='Stop words')
-    parser.add_argument('--max_new_tokens',
-                        type=int,
-                        default=512,
-                        help='Number of maximum generated tokens.')
+    parser.add_argument(
+        '--temperature',
+        type=float,
+        default=0.1,
+        help='Temperature in next token prediction')
+    parser.add_argument(
+        '--top_p',
+        type=float,
+        default=0.8,
+        help='Parameter for Top-P Sampling.')
+    parser.add_argument(
+        '--top_k', type=int, default=40, help='Parameter for Top-K Sampling.')
+    parser.add_argument(
+        '--stop_words',
+        type=str,
+        default=['<|action_end|>', '<|im_end|>'],
+        action='append',
+        help='Stop words')
+    parser.add_argument(
+        '--max_new_tokens',
+        type=int,
+        default=512,
+        help='Number of maximum generated tokens.')
     parser.add_argument(
         '--do_infer',
         default=True,

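Beyond the new --batch_size and --max_turn options, note that --stop_words combines action='append' with a non-empty default. A quick standalone argparse check (not part of the script) showing that user-supplied stop words are added on top of the defaults rather than replacing them:

# Standalone check: with action='append' and a non-empty default list,
# extra --stop_words values are appended to the defaults.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--stop_words',
    type=str,
    default=['<|action_end|>', '<|im_end|>'],
    action='append')

print(parser.parse_args([]).stop_words)
# ['<|action_end|>', '<|im_end|>']
print(parser.parse_args(['--stop_words', '<eos>']).stop_words)
# ['<|action_end|>', '<|im_end|>', '<eos>']
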
@@ -113,18 +138,21 @@ def parse_args():
     # action='store_false',
     # help='Disable the inference.'
     # )
-    parser.add_argument('--do_eval',
-                        default=False,
-                        action='store_true',
-                        help='Whether to evaluate the inference results.')
-    parser.add_argument('--overwrite',
-                        default=False,
-                        action='store_true',
-                        help='Whether to overwrite the existing result file')
-    parser.add_argument('--debug',
-                        default=False,
-                        action='store_true',
-                        help='Only infer the first 50 samples')
+    parser.add_argument(
+        '--do_eval',
+        default=False,
+        action='store_true',
+        help='Whether to evaluate the inference results.')
+    parser.add_argument(
+        '--overwrite',
+        default=False,
+        action='store_true',
+        help='Whether to overwrite the existing result file')
+    # parser.add_argument(
+    #     '--debug',
+    #     default=False,
+    #     action='store_true',
+    #     help='Only infer the first 50 samples')
     return parser.parse_args()
 
 

@@ -473,9 +501,8 @@ def symbolic_equal_process(a, b, output_queue):
 def call_with_timeout(func, *args, timeout=1, **kwargs):
     output_queue = multiprocessing.Queue()
     process_args = args + (output_queue, )
-    process = multiprocessing.Process(target=func,
-                                      args=process_args,
-                                      kwargs=kwargs)
+    process = multiprocessing.Process(
+        target=func, args=process_args, kwargs=kwargs)
     process.start()
     process.join(timeout)
 

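call_with_timeout runs the comparison worker in a separate process and waits at most timeout seconds for it. A minimal, self-contained sketch of the same run-in-a-subprocess-with-timeout pattern; the worker function and the timeout handling below are illustrative, not copied from the elided part of the file:

# Minimal sketch of the pattern used by call_with_timeout: run a worker in a
# child process, wait up to `timeout` seconds, and collect its result from a
# queue. The worker here is made up for illustration.
import multiprocessing

def slow_square(x, output_queue):
    # Pretend this might hang; the parent only waits `timeout` seconds.
    output_queue.put(x * x)

def run_with_timeout(func, *args, timeout=1):
    queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=func, args=args + (queue, ))
    process.start()
    process.join(timeout)
    if process.is_alive():       # still running after the deadline
        process.terminate()
        process.join()
        return None              # treat a timeout as "no result"
    return queue.get()

if __name__ == '__main__':
    print(run_with_timeout(slow_square, 7, timeout=1))  # 49
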
@@ -487,25 +514,33 @@ def call_with_timeout(func, *args, timeout=1, **kwargs):
     return output_queue.get()
 
 
-def init_agent(backend: str, model_path: str, tp: int, **kwargs):
+def init_agent(backend: str, max_turn: int, model_path: str, tp: int,
+               **kwargs):
     if backend == 'lmdeploy':
-        model = LMDeployPipeline(path=model_path,
-                                 meta_template=INTERNLM2_META,
-                                 tp=tp,
-                                 **kwargs)
+        from lmdeploy import TurbomindEngineConfig
+        model = LMDeployPipeline(
+            path=model_path,
+            model_name='internlm2-chat',
+            meta_template=INTERNLM2_META,
+            pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)),
+            **kwargs)
     elif backend == 'hf':
-        model = HFTransformer(path=model_path,
-                              meta_template=INTERNLM2_META,
-                              **kwargs)
+        model = HFTransformer(
+            path=model_path, meta_template=INTERNLM2_META, **kwargs)
     else:
         raise NotImplementedError
 
-    agent = Internlm2Agent(llm=model,
-                           protocol=Internlm2Protocol(
-                               meta_prompt=None,
-                               interpreter_prompt=DEFAULT_PROMPT),
-                           interpreter_executor=ActionExecutor(
-                               actions=[get_tool('IPythonInteractive')]))
+    agent = Internlm2Agent(
+        llm=model,
+        protocol=Internlm2Protocol(
+            meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT),
+        interpreter_executor=ActionExecutor(actions=[
+            IPythonInteractiveManager(
+                max_workers=200,
+                ci_lock=os.path.join(
+                    os.path.dirname(__file__), '.ipython.lock'))
+        ]),
+        max_turn=max_turn)
     return agent
 
 

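A hypothetical usage sketch of the updated init_agent, assembled from the call sites later in this diff; the queries are made up, and batch_chat is assumed to accept a list of prompts and return one response per prompt, each exposing inner_steps:

# Usage sketch only (queries invented; argument values taken from the
# README example and the script's defaults).
agent = init_agent(
    backend='lmdeploy',
    max_turn=3,                     # new: cap on agent/environment rounds
    model_path='internlm/internlm2-chat-20b',
    tp=2,
    temperature=1.0,
    top_k=1,
    max_new_tokens=512,
)
rets = agent.batch_chat(['What is 2 + 2?', 'Compute 3!'])
for ret in rets:
    print(ret.inner_steps)
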
@@ -522,8 +557,8 @@ def predict(args):
             d['pred'], d['steps'], d['error'] = [], [], None
             return d
 
-        dataset = load_dataset('gsm8k', 'main',
-                               split='test').map(process, True)
+        dataset = load_dataset(
+            'gsm8k', 'main', split='test').map(process, True)
 
     elif args.dataset == 'math':
 

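The positional True in .map(process, True) is assumed to be with_indices=True from datasets.Dataset.map, so process receives each example together with its index. A small sketch under that assumption:

# Sketch assuming the positional `True` is `with_indices=True`: the mapping
# function gets (example, index), which lets the script keep a per-sample idx.
from datasets import Dataset

def process(d, i):
    d['idx'] = i                    # index provided because of with_indices
    return d

toy = Dataset.from_dict({'query': ['1 + 1 = ?', '2 * 3 = ?']})
toy = toy.map(process, True)        # same call shape as in the diff
print(toy['idx'])                   # [0, 1]
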
@@ -543,14 +578,15 @@ def predict(args):
             d['error'] = None
             return d
 
-        dataset = load_dataset('lighteval/MATH',
-                               split='test').map(process, True)
+        dataset = load_dataset(
+            'lighteval/MATH', split='test').map(process, True)
 
     else:
         raise NotImplementedError
 
     agent = init_agent(
         backend=args.backend,
+        max_turn=args.max_turn,
         model_path=args.model_path,
         tp=args.tp,
         temperature=args.temperature,

@@ -559,26 +595,35 @@ def predict(args):
         top_k=args.top_k,
         max_new_tokens=args.max_new_tokens,
     )
-    with jsonlines.open(args.output_path, 'w') as f:
-        for item in tqdm(
-                dataset if not args.debug else dataset.select(range(50))):
+    num_batches = ceil(len(dataset) / args.batch_size)
+    with jsonlines.open(args.output_path, 'w', flush=True) as f:
+        for i in tqdm(range(num_batches)):
+            batch = dataset.select(
+                range(i * args.batch_size,
+                      min((i + 1) * args.batch_size, len(dataset))))
+            # for item in tqdm(
+            #         dataset if not args.debug else dataset.select(range(50))):
             try:
-                ret = agent.chat(item['query'])
-                item['steps'] = ret.inner_steps
+                rets = agent.batch_chat(batch['query'])
+                for item, ret in zip(batch, rets):
+                    item['steps'] = ret.inner_steps
 
                     lang = [
                         step for step in item['steps']
                         if step['role'] == 'language'
                     ]
-                item['pred'].append('😭' if not lang else
-                                    extract_answer(lang[-1]['content']) or '😭')
-                agent._interpreter_executor.actions[
-                    'IPythonInteractive'].reset()
+                    item['pred'].append('😭' if not lang else extract_answer(
+                        lang[-1]['content']) or '😭')
+                    f.write(item)
             except Exception as e:
                 err = str(traceback.format_exc())
-                print(f'Error processing index {item["idx"]}: {e}\n{err}')
-                item['error'] = err
-                f.write(item)
+                print(f'Processing batch data error: {e}\n{err}')
+                for item in batch:
+                    item['error'] = err
+                    f.write(item)
+            finally:
+                agent._interpreter_executor.actions[
+                    'IPythonInteractiveManager'].reset()
 
 
 def evaluate(args):

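The batching above boils down to ceil-division for the number of batches plus a min() clamp on the last slice. A plain-Python illustration of just that arithmetic (helper name invented for the example):

# Plain illustration of the batch slicing used above: ceil-division for the
# number of batches, with min() clamping the final (possibly partial) batch.
from math import ceil

def batch_ranges(n_items, batch_size):
    num_batches = ceil(n_items / batch_size)
    return [
        range(i * batch_size, min((i + 1) * batch_size, n_items))
        for i in range(num_batches)
    ]

print([list(r) for r in batch_ranges(7, 3)])
# [[0, 1, 2], [3, 4, 5], [6]]
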
@@ -606,7 +651,7 @@ def evaluate(args):
                 timeout_cnt += 1
             except Exception as error:
                 print(error.__traceback__)
-                sys.exit()
+                # sys.exit()
             progress_bar.update(1)
 
     idx = 0