Merge branch 'feature/math_ci' into 'main'

Update the agent section

See merge request openmmlab/bigmodel/InternLM!17
lvchengqi 2024-07-02 11:52:07 +00:00
commit bb28b48c5e
7 changed files with 175 additions and 203 deletions


@@ -73,7 +73,7 @@ The InternLM2.5 series models are officially released in this repository, with the following features:
Currently, only the 7B model of the InternLM2.5 series has been released; we will open source the 1.8B and 20B versions next. The 7B model is lightweight yet performs well for research and applications, while the 20B model is more powerful overall and can effectively support more complex practical scenarios. The relationships among the models of each size are as follows:
1. **InternLM2.5**: the foundation model pre-trained on large-scale corpora, which we recommend as the base for most applications.
2. **InternLM2.5-Chat**: the chat model, built on the InternLM2.5 base with supervised fine-tuning and online RLHF. InternLM25-Chat is optimized for conversational interaction, with strong instruction-following, empathetic chat, and tool-calling abilities, and is the model we recommend for direct use in downstream applications.
2. **InternLM2.5-Chat**: the chat model, built on the InternLM2.5 base with supervised fine-tuning and online RLHF. InternLM2.5-Chat is optimized for conversational interaction, with strong instruction-following, empathetic chat, and tool-calling abilities, and is the model we recommend for direct use in downstream applications.
3. **InternLM2.5-Chat-1M**: InternLM2.5-Chat-1M supports an ultra-long context of up to one million tokens and delivers overall performance on par with InternLM2.5-Chat.
**Limitations:** Although we have paid close attention to safety during training and tried to ensure the model outputs ethically and legally compliant text, the model may still produce unexpected outputs due to its size and probabilistic generation paradigm; for example, responses may contain bias, discrimination, or other harmful content. Please do not spread such content. This project assumes no responsibility for any consequences caused by the dissemination of harmful information.


@@ -4,77 +4,80 @@ English | [简体中文](README_zh-CN.md)
## Introduction
InternLM-Chat-7B v1.1 has been released as the first open-source model with code interpreter capabilities, supporting external tools such as Python code interpreter and search engine.
InternLM2.5-Chat, open sourced on June 30, 2024, further enhances its capabilities in code interpreter and general tool utilization. With improved and more generalized instruction understanding, tool selection, and reflection abilities, InternLM2.5-Chat can more reliably support complex agents and multi-step tool calling for more intricate tasks. When combined with a code interpreter, InternLM2.5-Chat obtains comparable results to GPT-4 on MATH. Leveraging strong foundational capabilities in mathematics and tools, InternLM2.5-Chat provides practical data analysis capabilities.
InternLM2-Chat, open sourced on January 17, 2024, further enhances its capabilities in code interpreter and general tool utilization. With improved and more generalized instruction understanding, tool selection, and reflection abilities, InternLM2-Chat can more reliably support complex agents and multi-step tool calling for more intricate tasks. InternLM2-Chat exhibits decent computational and reasoning abilities even without external tools, surpassing ChatGPT in mathematical performance. When combined with a code interpreter, InternLM2-Chat-20B obtains comparable results to GPT-4 on GSM8K and MATH. Leveraging strong foundational capabilities in mathematics and tools, InternLM2-Chat provides practical data analysis capabilities.
The results of InternLM2.5-Chat with the math code interpreter are shown below:
The results of InternLM2-Chat-20B with the math code interpreter are shown below:
| | GSM8K | MATH |
| :--------------------------------------: | :---: | :---: |
| InternLM2-Chat-20B | 79.6 | 32.5 |
| InternLM2-Chat-20B with Code Interpreter | 84.5 | 51.2 |
| ChatGPT (GPT-3.5) | 78.2 | 28.0 |
| GPT-4 | 91.4 | 45.8 |
| Models | Tool-Integrated | MATH |
| :-----------------: | :-------------: | :--: |
| InternLM2-Chat-7B | w/ | 45.1 |
| InternLM2-Chat-20B | w/ | 51.2 |
| InternLM2.5-7B-Chat | w/ | 63.0 |
| gpt-4-0125-preview | w/o | 64.2 |
## Usages
We offer an example using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call the code interpreter. First, install the extra dependencies:
We offer an example using [Lagent](lagent.md) to build agents based on InternLM2.5-Chat to call the code interpreter. First, install the extra dependencies:
```bash
pip install -r requirements.txt
```
Run the following script to perform inference and evaluation on the GSM8K and MATH test sets:
Run the following script to perform inference and evaluation on the MATH test set:
```bash
# Use --backend=hf to run a HuggingFace model instead of LMDeploy.
python streaming_inference.py \
  --backend=lmdeploy \
  --model_path=internlm/internlm2-chat-20b \
  --tp=2 \
  --model_path=internlm/internlm2_5-7b-chat \
  --tp=1 \
  --temperature=1.0 \
  --top_k=1 \
  --dataset=math \
  --output_path=math_lmdeploy.jsonl \
  --do_eval
```
`output_path` is a JSONL-format file that stores the inference results. Each line looks like:
```json
````json
{
    "idx": 41,
    "query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
    "gt": "0",
    "pred": ["0"],
    "problem": "The midpoint of the line segment between $(x,y)$ and $(2,4)$ is $(-7,0)$. Find $(x,y)$.",
    "level": "Level 3",
    "type": "Algebra",
    "solution": "Applying the midpoint formula gives $$\\left(\\frac{2+x}{2},\\frac{4+y}{2}\\right)=(-7,0).$$Solving $\\frac{2+x}{2}=-7$ for $x$ and $\\frac{4+y}{2} = 0$ for $y$ we find $(x,y)$ to be $\\boxed{(-16,-4)}$.",
    "idx": 306,
    "query": "The midpoint of the line segment between $(x,y)$ and $(2,4)$ is $(-7,0)$. Find $(x,y)$.",
    "gt": "(-16,-4)",
    "pred": [
        "(-16,-4)"
    ],
    "steps": [
        {
            "role": "language",
            "content": ""
            "content": "To find the coordinates of the point $(x,y)$, we can use the midpoint formula. The midpoint $M$ of a line segment with endpoints $(x_1, y_1)$ and $(x_2, y_2)$ is given by:\n\n\\[M = \\left(\\frac{x_1 + x_2}{2}, \\frac{y_1 + y_2}{2}\\right)\\]\n\nGiven that the midpoint is $(-7,0)$ and one endpoint is $(2,4)$, we can set up equations to solve for $x$ and $y$.\n\nLet's start by setting up the equations based on the midpoint formula:\n\n\\[\\frac{x + 2}{2} = -7\\]\n\\[\\frac{y + 4}{2} = 0\\]\n\nNow, let's solve these equations using Python to find the values of $x$ and $y$."
        },
        {
            "role": "tool",
            "content": {
                "name": "IPythonInteractive",
                "name": "AsyncIPythonInteractiveManager",
                "parameters": {
                    "command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n    x, y = symbols('x y')\n    equation = 3*x + 2*y - 12\n    b = solve(equation.subs(x, 4), y)[0]\n\n    return b\n\nresult = find_b()\nprint(result)\n```"
                    "command": "```python\nfrom sympy import symbols, Eq, solve\n\n# Define symbols\nx, y = symbols('x y')\n\n# Define equations based on the midpoint formula\neq1 = Eq((x + 2)/2, -7)\neq2 = Eq((y + 4)/2, 0)\n\n# Solve equations\nsolution = solve((eq1, eq2), (x, y))\n\nsolution\n```"
                }
            },
            "name": "interpreter"
        },
        {
            "role": "environment",
            "content": "0",
            "content": "{x: -16, y: -4}",
            "name": "interpreter"
        },
        {
            "role": "language",
            "content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
            "content": "After solving the equations, we find that the coordinates of the point $(x,y)$ are $(-16, -4)$. Therefore, the solution to the problem is:\n\n\\[\\boxed{(-16, -4)}\\]"
        }
    ],
    "error": null
}
```
````
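Before running evaluation, you can sanity-check the results file with a few lines of Python. A minimal sketch, assuming only the `pred`/`gt` fields shown above; the script's own `--do_eval` applies fuller normalization and symbolic matching:

```python
import jsonlines  # already listed in requirements.txt

correct = total = 0
with jsonlines.open('math_lmdeploy.jsonl') as reader:
    for item in reader:
        total += 1
        # Exact string match on the first prediction only; no symbolic checking.
        correct += bool(item['pred']) and item['pred'][0] == item['gt']
print(f'exact-match accuracy: {correct / total:.4f}')
```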
Once the results file is prepared, you can skip the inference stage as follows.


@@ -4,77 +4,80 @@
## Introduction
InternLM-Chat-7B v1.1 was released as the first open-source chat model with code-interpreter capabilities, supporting external tools such as the Python interpreter and search engines.
InternLM2.5-Chat further improves its capabilities in code interpretation and general tool calling. With stronger and more generalizable instruction understanding, tool selection, and result reflection, the new model can more reliably support the construction of complex agents and effective multi-step tool calling for more intricate tasks. When paired with a code interpreter, InternLM2.5-Chat reaches a level comparable to GPT-4 on MATH. Building on strong foundational capabilities in mathematics and tools, InternLM2.5-Chat provides practical data analysis capabilities.
InternLM2-Chat further improves its capabilities in code interpretation and general tool calling. With stronger and more generalizable instruction understanding, tool selection, and result reflection, the new model can more reliably support the construction of complex agents and effective multi-step tool calling for more intricate tasks. The model already shows decent computation and reasoning abilities without external tools, surpassing ChatGPT in mathematical performance. When paired with a code interpreter, InternLM2-Chat-20B reaches a level comparable to GPT-4 on GSM8K and MATH. Building on strong foundational capabilities in mathematics and tools, InternLM2-Chat provides practical data analysis capabilities.
The results of InternLM2.5-Chat with the math code interpreter are shown below:
The results of InternLM2-Chat-20B with the math code interpreter are shown below:
|                                              | GSM8K | MATH |
| :------------------------------------------: | :---: | :--: |
| InternLM2-Chat-20B (internal abilities only) | 79.6  | 32.5 |
| InternLM2-Chat-20B (with code interpreter)   | 84.5  | 51.2 |
| ChatGPT (GPT-3.5)                            | 78.2  | 28.0 |
| GPT-4                                        | 91.4  | 45.8 |
| Models              | Tool-Integrated | MATH |
| :-----------------: | :-------------: | :--: |
| InternLM2-Chat-7B   | w/              | 45.1 |
| InternLM2-Chat-20B  | w/              | 51.2 |
| InternLM2.5-7B-Chat | w/              | 63.0 |
| gpt-4-0125-preview  | w/o             | 64.2 |
## Usage
We provide an example of using [Lagent](lagent_zh-CN.md) to build an agent on top of InternLM2-Chat that calls the code interpreter. First, install the extra dependencies:
We provide an example of using [Lagent](lagent_zh-CN.md) to build an agent on top of InternLM2.5-Chat that calls the code interpreter. First, install the extra dependencies:
```bash
pip install -r requirements.txt
```
Run the following script to perform inference and evaluation on the GSM8K and MATH test sets:
Run the following script to perform inference and evaluation on the MATH test set:
```bash
# Use --backend=hf to run a HuggingFace model instead of LMDeploy.
python streaming_inference.py \
  --backend=lmdeploy \
  --model_path=internlm/internlm2-chat-20b \
  --tp=2 \
  --model_path=internlm/internlm2_5-7b-chat \
  --tp=1 \
  --temperature=1.0 \
  --top_k=1 \
  --dataset=math \
  --output_path=math_lmdeploy.jsonl \
  --do_eval
```
`output_path` is a JSONL-format file that stores the inference results; each line looks like:
```json
````json
{
    "idx": 41,
    "query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
    "gt": "0",
    "pred": ["0"],
    "problem": "The midpoint of the line segment between $(x,y)$ and $(2,4)$ is $(-7,0)$. Find $(x,y)$.",
    "level": "Level 3",
    "type": "Algebra",
    "solution": "Applying the midpoint formula gives $$\\left(\\frac{2+x}{2},\\frac{4+y}{2}\\right)=(-7,0).$$Solving $\\frac{2+x}{2}=-7$ for $x$ and $\\frac{4+y}{2} = 0$ for $y$ we find $(x,y)$ to be $\\boxed{(-16,-4)}$.",
    "idx": 306,
    "query": "The midpoint of the line segment between $(x,y)$ and $(2,4)$ is $(-7,0)$. Find $(x,y)$.",
    "gt": "(-16,-4)",
    "pred": [
        "(-16,-4)"
    ],
    "steps": [
        {
            "role": "language",
            "content": ""
            "content": "To find the coordinates of the point $(x,y)$, we can use the midpoint formula. The midpoint $M$ of a line segment with endpoints $(x_1, y_1)$ and $(x_2, y_2)$ is given by:\n\n\\[M = \\left(\\frac{x_1 + x_2}{2}, \\frac{y_1 + y_2}{2}\\right)\\]\n\nGiven that the midpoint is $(-7,0)$ and one endpoint is $(2,4)$, we can set up equations to solve for $x$ and $y$.\n\nLet's start by setting up the equations based on the midpoint formula:\n\n\\[\\frac{x + 2}{2} = -7\\]\n\\[\\frac{y + 4}{2} = 0\\]\n\nNow, let's solve these equations using Python to find the values of $x$ and $y$."
        },
        {
            "role": "tool",
            "content": {
                "name": "IPythonInteractive",
                "name": "AsyncIPythonInteractiveManager",
                "parameters": {
                    "command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n    x, y = symbols('x y')\n    equation = 3*x + 2*y - 12\n    b = solve(equation.subs(x, 4), y)[0]\n\n    return b\n\nresult = find_b()\nprint(result)\n```"
                    "command": "```python\nfrom sympy import symbols, Eq, solve\n\n# Define symbols\nx, y = symbols('x y')\n\n# Define equations based on the midpoint formula\neq1 = Eq((x + 2)/2, -7)\neq2 = Eq((y + 4)/2, 0)\n\n# Solve equations\nsolution = solve((eq1, eq2), (x, y))\n\nsolution\n```"
                }
            },
            "name": "interpreter"
        },
        {
            "role": "environment",
            "content": "0",
            "content": "{x: -16, y: -4}",
            "name": "interpreter"
        },
        {
            "role": "language",
            "content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
            "content": "After solving the equations, we find that the coordinates of the point $(x,y)$ are $(-16, -4)$. Therefore, the solution to the problem is:\n\n\\[\\boxed{(-16, -4)}\\]"
        }
    ],
    "error": null
}
```
````
If this file is already prepared, you can skip the inference stage and evaluate directly:


@@ -38,7 +38,7 @@ Then you can chat through the UI shown below
![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
## Run a ReAct agent with InternLM2-Chat
## Run a ReAct agent with InternLM2.5-Chat
**NOTE:** If you want to run a HuggingFace model, please run `pip install -e .[all]` first.


@@ -38,7 +38,7 @@ streamlit run examples/react_web_demo.py
![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
## Build a ReAct agent with InternLM-Chat
## Build a ReAct agent with InternLM2.5-Chat
**NOTE:** If you want to run a HuggingFace model, please run `pip install -e .[all]` first.
@@ -49,7 +49,7 @@ from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer
# Initialize the HFTransformer-based Language Model (llm) and provide the model name.
llm = HFTransformer('internlm/internlm-chat-7b-v1_1')
llm = HFTransformer('internlm/internlm2_5-7b-chat')
# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')
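For completeness, a minimal sketch of how this excerpt is typically wired into an agent. The `ReAct` import path, constructor arguments, and `chat`/`response` fields follow Lagent's examples, but treat the exact signatures as assumptions against your installed version:

```python
from lagent.agents import ReAct

# Assemble the LLM and tools into a ReAct-style agent (assumed API).
agent = ReAct(
    llm=llm,
    action_executor=ActionExecutor(actions=[search_tool, PythonInterpreter()]),
)

# One round of chat; the agent decides when to call search or Python.
ret = agent.chat('What is the weather like in Shanghai today?')
print(ret.response)
```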


@@ -1,10 +1,10 @@
lmdeploy>=0.2.2
antlr4-python3-runtime==4.11.0
datasets
tqdm
einops
jsonlines
lagent @ git+https://github.com/InternLM/lagent@main
lmdeploy>=0.2.2
numpy
pebble
jsonlines
sympy==1.12
antlr4-python3-runtime==4.11.0
lagent
einops
tqdm
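As context for the pins above: sympy 1.12's LaTeX parser is what needs the `antlr4-python3-runtime==4.11.0` line at evaluation time. A small illustrative check, not part of the repo:

```python
from sympy import simplify
from sympy.parsing.latex import parse_latex  # backed by antlr4-python3-runtime

# Two LaTeX expressions that differ in form but are symbolically equal.
a = parse_latex(r'\frac{2+x}{2}')
b = parse_latex(r'1 + \frac{x}{2}')
print(simplify(a - b) == 0)  # True
```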


@@ -46,13 +46,6 @@ from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm
# --------------------- modify the system prompt as needed ---------------------
# DEFAULT_PROMPT = (
#     'Integrate step-by-step reasoning and Python code to solve math problems '
#     'using the following guidelines:\n'
#     '- Just write jupyter code to solve the problem without giving your thought;\n'
#     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
#     'units. \n')
DEFAULT_PROMPT = (
    'Integrate step-by-step reasoning and Python code to solve math problems '
    'using the following guidelines:\n'
@@ -64,16 +57,15 @@ DEFAULT_PROMPT = (
def parse_args():
    parser = argparse.ArgumentParser(description='Math Code Interpreter')
    parser.add_argument(
        '--backend',
        type=str,
        default='lmdeploy',
        help='Which inference framework to use.',
        choices=['lmdeploy', 'hf'])
    parser.add_argument('--backend',
                        type=str,
                        default='lmdeploy',
                        help='Which inference framework to use.',
                        choices=['lmdeploy', 'hf'])
    parser.add_argument(
        '--model_path',
        type=str,
        default='internlm/internlm2_5-7b-chat',
        default='internlm/internlm2-chat-7b',
        help='Path or name to the model, could be HuggingFace model specifier.'
    )
    parser.add_argument(
@@ -81,21 +73,14 @@ def parse_args():
        type=str,
        required=True,
        help='Path to save inference results to, should be a `jsonl` file')
    parser.add_argument(
        '--dataset',
        type=str,
        default='math',
        choices=['gsm8k', 'math'],
        help='Dataset for inference')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='Agent inference batch size')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='Agent inference batch size')
    parser.add_argument(
        '--max_turn',
        type=int,
        default=3,
        default=5,
        help=
        'Maximum number of interaction rounds between the agent and environment'
    )
@@ -104,29 +89,27 @@ def parse_args():
        type=int,
        default=1,
        help='Number of tensor parallelism. It may be required in LMDeploy.')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='Temperature in next token prediction')
    parser.add_argument(
        '--top_p',
        type=float,
        default=0.8,
        help='Parameter for Top-P Sampling.')
    parser.add_argument(
        '--top_k', type=int, default=40, help='Parameter for Top-K Sampling.')
    parser.add_argument(
        '--stop_words',
        type=str,
        default=['<|action_end|>', '<|im_end|>'],
        action='append',
        help='Stop words')
    parser.add_argument(
        '--max_new_tokens',
        type=int,
        default=512,
        help='Number of maximum generated tokens.')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.1,
                        help='Temperature in next token prediction')
    parser.add_argument('--top_p',
                        type=float,
                        default=0.8,
                        help='Parameter for Top-P Sampling.')
    parser.add_argument('--top_k',
                        type=int,
                        default=40,
                        help='Parameter for Top-K Sampling.')
    parser.add_argument('--stop_words',
                        type=str,
                        default=['<|action_end|>', '<|im_end|>'],
                        action='append',
                        help='Stop words')
    parser.add_argument('--max_new_tokens',
                        type=int,
                        default=512,
                        help='Number of maximum generated tokens.')
    parser.add_argument(
        '--do_infer',
        default=True,
@@ -138,21 +121,14 @@ def parse_args():
    #     action='store_false',
    #     help='Disable the inference.'
    # )
    parser.add_argument(
        '--do_eval',
        default=False,
        action='store_true',
        help='Whether to evaluate the inference results.')
    parser.add_argument(
        '--overwrite',
        default=False,
        action='store_true',
        help='Whether to overwrite the existing result file')
    # parser.add_argument(
    #     '--debug',
    #     default=False,
    #     action='store_true',
    #     help='Only infer the first 50 samples')
    parser.add_argument('--do_eval',
                        default=False,
                        action='store_true',
                        help='Whether to evaluate the inference results.')
    parser.add_argument('--overwrite',
                        default=False,
                        action='store_true',
                        help='Whether to overwrite the existing result file')
    return parser.parse_args()
@@ -339,28 +315,41 @@ def last_boxed_only_string(string):
    return retval
def extract_answer(pred_str):
    if 'boxed' not in pred_str:
        return ''
    answer = pred_str.split('boxed')[-1]
    if len(answer) == 0:
        return ''
    elif (answer[0] == '{'):
        stack = 1
        a = ''
        for c in answer[1:]:
            if (c == '{'):
                stack += 1
                a += c
            elif (c == '}'):
                stack -= 1
                if (stack == 0): break
                a += c
            else:
                a += c
    else:
        a = answer.split('$')[0].strip()
def extract_answer(pred_str: str, execute: bool = False) -> str:
    if re.search('\boxed|boxed', pred_str):
        answer = re.split('\boxed|boxed', pred_str)[-1]
        if len(answer) == 0:
            return ''
        elif (answer[0] == '{'):
            stack = 1
            a = ''
            for c in answer[1:]:
                if (c == '{'):
                    stack += 1
                    a += c
                elif (c == '}'):
                    stack -= 1
                    if (stack == 0): break
                    a += c
                else:
                    a += c
        else:
            a = answer.split('$')[0].strip()
    elif re.search('[Tt]he (final )?answer is:?', pred_str):
        a = re.split('[Tt]he (final )?answer is:?',
                     pred_str)[-1].strip().rstrip('.')
    elif pred_str.startswith('```python') and execute:
        # fall back to program
        from lagent import get_tool
        a = get_tool('IPythonInteractive').exec(pred_str).value or ''
    else:  # use the last number
        pred = re.findall(r'-?\d*\.?\d+', pred_str.replace(',', ''))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ''
    # multiple lines
    pred = a.split('\n')[0]
    if pred != '' and pred[0] == ':':
        pred = pred[1:]
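As a standalone illustration of the brace-counting branch above, here is a small helper (the name `boxed_payload` is ours, not the script's) that pulls the payload out of the last `\boxed{...}`:

```python
def boxed_payload(pred_str: str) -> str:
    # Take the text after the last 'boxed' marker, then scan matching braces.
    answer = pred_str.split('boxed')[-1]
    if not answer or answer[0] != '{':
        return answer.split('$')[0].strip()
    stack, out = 1, ''
    for c in answer[1:]:
        if c == '{':
            stack += 1
        elif c == '}':
            stack -= 1
            if stack == 0:
                break
        out += c
    return out

print(boxed_payload(r'... we find $(x,y)$ to be $\boxed{(-16,-4)}$.'))  # (-16,-4)
```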
@@ -501,8 +490,9 @@ def symbolic_equal_process(a, b, output_queue):
def call_with_timeout(func, *args, timeout=1, **kwargs):
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue, )
    process = multiprocessing.Process(
        target=func, args=process_args, kwargs=kwargs)
    process = multiprocessing.Process(target=func,
                                      args=process_args,
                                      kwargs=kwargs)
    process.start()
    process.join(timeout)
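A self-contained usage sketch of this pattern; the worker function is hypothetical, and treating a timeout as `None` is our assumption mirroring how the evaluation loop counts timed-out comparisons:

```python
import multiprocessing


def slow_square(x, output_queue):
    # Workers report results through the queue, never via return value.
    output_queue.put(x * x)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=func,
                                      args=args + (output_queue, ),
                                      kwargs=kwargs)
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        return None  # assumption: a timeout yields "no result"
    return output_queue.get()


if __name__ == '__main__':
    print(call_with_timeout(slow_square, 7, timeout=2))  # 49
```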
@@ -525,65 +515,45 @@ def init_agent(backend: str, max_turn: int, model_path: str, tp: int,
            pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)),
            **kwargs)
    elif backend == 'hf':
        model = HFTransformer(
            path=model_path, meta_template=INTERNLM2_META, **kwargs)
        model = HFTransformer(path=model_path,
                              meta_template=INTERNLM2_META,
                              **kwargs)
    else:
        raise NotImplementedError
    agent = Internlm2Agent(
        llm=model,
        protocol=Internlm2Protocol(
            meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT),
        protocol=Internlm2Protocol(meta_prompt=None,
                                   interpreter_prompt=DEFAULT_PROMPT),
        interpreter_executor=ActionExecutor(actions=[
            IPythonInteractiveManager(
                max_workers=200,
                ci_lock=os.path.join(
                    os.path.dirname(__file__), '.ipython.lock'))
            IPythonInteractiveManager(max_workers=200,
                                      ci_lock=os.path.join(
                                          os.path.dirname(__file__),
                                          '.ipython.lock'))
        ]),
        max_turn=max_turn)
    return agent
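A hypothetical smoke test for `init_agent`. That `batch_chat` takes a list of queries and returns objects carrying `inner_steps` matches how `predict` consumes it below, but the exact Lagent return type is an assumption:

```python
if __name__ == '__main__':
    agent = init_agent(backend='lmdeploy',
                       max_turn=5,
                       model_path='internlm/internlm2_5-7b-chat',
                       tp=1)
    rets = agent.batch_chat(['Find $b$ if $3 \\cdot 4 + 2b = 12$.'])
    for step in rets[0].inner_steps:
        print(step['role'], str(step['content'])[:80])
```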
def predict(args):
    if args.dataset == 'gsm8k':
        def process(d, k):
            d['answer'] = re.sub(r'#### (.+)', r'The answer is \1',
                                 re.sub(r'<<.*?>>', '',
                                        d['answer'])).replace('$', '')
            d['idx'] = k
            d['query'] = d['question'].replace('$', '')
            d['gt'] = re.search('The answer is (.+)', d['answer'])[1]
            d['pred'], d['steps'], d['error'] = [], [], None
            return d
        dataset = load_dataset(
            'gsm8k', 'main', split='test').map(process, True)
    elif args.dataset == 'math':
        def process(d, k):
            d['idx'] = k
            d['query'] = d['problem']
            gt = extract_answer(d['solution'])
            if '\\boxed{90\\text{ square\nunits}}' in d['solution']:
                gt = '90'
            elif '$6$ is our answer' in d['solution']:
                gt = '6'
            elif gt.startswith('x\\in'):
                gt = gt[len('x\\in'):]
            gt = strip_string(gt)
            d['gt'] = gt
            d['pred'], d['steps'] = [], []
            d['error'] = None
            return d
        dataset = load_dataset(
            'lighteval/MATH', split='test').map(process, True)
    else:
        raise NotImplementedError
    def process(d, k):
        d['idx'] = k
        d['query'] = d['problem']
        gt = extract_answer(d['solution'])
        if '\\boxed{90\\text{ square\nunits}}' in d['solution']:
            gt = '90'
        elif '$6$ is our answer' in d['solution']:
            gt = '6'
        elif gt.startswith('x\\in'):
            gt = gt[len('x\\in'):]
        gt = strip_string(gt)
        d['gt'] = gt
        d['pred'], d['steps'] = [], []
        d['error'] = None
        return d
    dataset = load_dataset('lighteval/MATH', split='test').map(process, True)
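For reference, the removed GSM8K branch normalizes answers with two nested `re.sub` calls; a standalone run on a made-up sample shows what they do:

```python
import re

raw = 'Tom has 3 boxes with 4 apples each, so <<3*4=12>>12 apples.\n#### 12'
# Strip calculator annotations, then rewrite the '#### N' marker.
ans = re.sub(r'#### (.+)', r'The answer is \1',
             re.sub(r'<<.*?>>', '', raw)).replace('$', '')
print(ans)                                      # ... so 12 apples. / The answer is 12
print(re.search('The answer is (.+)', ans)[1])  # 12
```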
    agent = init_agent(
        backend=args.backend,
        max_turn=args.max_turn,
@@ -601,19 +571,14 @@ def predict(args):
        batch = dataset.select(
            range(i * args.batch_size,
                  min((i + 1) * args.batch_size, len(dataset))))
        # for item in tqdm(
        #         dataset if not args.debug else dataset.select(range(50))):
        try:
            rets = agent.batch_chat(batch['query'])
            for item, ret in zip(batch, rets):
                item['steps'] = ret.inner_steps
                lang = [
                    step for step in item['steps']
                    if step['role'] == 'language'
                ]
                item['pred'].append('😭' if not lang else extract_answer(
                    lang[-1]['content']) or '😭')
                last = item['steps'][-1]
                item['pred'].append(
                    extract_answer(last['content']) if last['role'] ==
                    'language' else '😭')
                f.write(item)
        except Exception as e:
            err = str(traceback.format_exc())
@@ -651,6 +616,7 @@ def evaluate(args):
                timeout_cnt += 1
            except Exception as error:
                print(error.__traceback__)
                scores.append(False)
                # sys.exit()
            progress_bar.update(1)