mirror of https://github.com/InternLM/InternLM
Support inference and evaluation with Math Code Interpreter (#695)
Co-authored-by: wangzy <wangziyi@pjlab.org.cn>pull/721/head
parent
43b7582201
commit
2b221a9f17
|
@ -261,7 +261,7 @@ To learn more about data contamination assessment, please check the [contaminati
|
||||||
### Agent Evaluation
|
### Agent Evaluation
|
||||||
|
|
||||||
- To evaluate tool utilization, please refer to [T-Eval](https://github.com/open-compass/T-Eval).
|
- To evaluate tool utilization, please refer to [T-Eval](https://github.com/open-compass/T-Eval).
|
||||||
- For code interpreter evaluation, use the [gsm-8k-agent](https://github.com/open-compass/opencompass/blob/main/configs/datasets/gsm8k/gsm8k_agent_gen_be1606.py) provided in the repository. Additionally, you need to install [Lagent](https://github.com/InternLM/lagent).
|
- For code interpreter evaluation, use the [Math Agent Evaluation](agent/README.md) provided in the repository.
|
||||||
|
|
||||||
### Subjective Evaluation
|
### Subjective Evaluation
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ InternLM2-Chat, open sourced on January 17, 2024, further enhances its capabilit
|
||||||
The results of InternLM2-Chat-20B on math code interpreter is as below:
|
The results of InternLM2-Chat-20B on math code interpreter is as below:
|
||||||
|
|
||||||
| | GSM8K | MATH |
|
| | GSM8K | MATH |
|
||||||
| :--------------------------------------: | :---: | :--: |
|
| :--------------------------------------: | :---: | :---: |
|
||||||
| InternLM2-Chat-20B | 79.6 | 32.5 |
|
| InternLM2-Chat-20B | 79.6 | 32.5 |
|
||||||
| InternLM2-Chat-20B with Code Interpreter | 84.5 | 51.2 |
|
| InternLM2-Chat-20B with Code Interpreter | 84.5 | 51.2 |
|
||||||
| ChatGPT (GPT-3.5) | 78.2 | 28.0 |
|
| ChatGPT (GPT-3.5) | 78.2 | 28.0 |
|
||||||
|
@ -19,4 +19,69 @@ The results of InternLM2-Chat-20B on math code interpreter is as below:
|
||||||
|
|
||||||
## Usages
|
## Usages
|
||||||
|
|
||||||
We offer examples using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call code interpreter or search API. Additionally, we provide an example code using [PAL to evaluate GSM8K math problems](pal_inference.md) with InternLM-Chat-7B.
|
We offer an example using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call the code interpreter. Firstly install the extra dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Run the following script to perform inference and evaluation on GSM8K and MATH test.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python streaming_inference.py \
|
||||||
|
--backend=lmdeploy \ # For HuggingFace models: hf
|
||||||
|
--model_path=internlm/internlm2-chat-20b \
|
||||||
|
--tp=2 \
|
||||||
|
--temperature=0.0 \
|
||||||
|
--dataset=math \
|
||||||
|
--output_path=math_lmdeploy.jsonl \
|
||||||
|
--do_eval
|
||||||
|
```
|
||||||
|
|
||||||
|
`output_path` is a jsonl format file to save the inference results. Each line is like
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"idx": 41,
|
||||||
|
"query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
|
||||||
|
"gt": "0",
|
||||||
|
"pred": ["0"],
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"role": "language",
|
||||||
|
"content": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": {
|
||||||
|
"name": "IPythonInteractive",
|
||||||
|
"parameters": {
|
||||||
|
"command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n x, y = symbols('x y')\n equation = 3*x + 2*y - 12\n b = solve(equation.subs(x, 4), y)[0]\n\n return b\n\nresult = find_b()\nprint(result)\n```"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"name": "interpreter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "environment",
|
||||||
|
"content": "0",
|
||||||
|
"name": "interpreter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "language",
|
||||||
|
"content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"error": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Once it is prepared, just skip the inference stage as follows.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python streaming_inference.py \
|
||||||
|
--output_path=math_lmdeploy.jsonl \
|
||||||
|
--no-do_infer \
|
||||||
|
--do_eval
|
||||||
|
```
|
||||||
|
|
||||||
|
Please refer to [`streaming_inference.py`](streaming_inference.py) for more information about the arguments.
|
||||||
|
|
|
@ -11,7 +11,7 @@ InternLM2-Chat 进一步提高了它在代码解释和通用工具调用方面
|
||||||
以下是 InternLM2-Chat-20B 在数学代码解释器上的结果。
|
以下是 InternLM2-Chat-20B 在数学代码解释器上的结果。
|
||||||
|
|
||||||
| | GSM8K | MATH |
|
| | GSM8K | MATH |
|
||||||
| :---------------------------------: | :---: | :--: |
|
| :---------------------------------: | :---: | :---: |
|
||||||
| InternLM2-Chat-20B 单纯依靠内在能力 | 79.6 | 32.5 |
|
| InternLM2-Chat-20B 单纯依靠内在能力 | 79.6 | 32.5 |
|
||||||
| InternLM2-Chat-20B 配合代码解释器 | 84.5 | 51.2 |
|
| InternLM2-Chat-20B 配合代码解释器 | 84.5 | 51.2 |
|
||||||
| ChatGPT (GPT-3.5) | 78.2 | 28.0 |
|
| ChatGPT (GPT-3.5) | 78.2 | 28.0 |
|
||||||
|
@ -19,4 +19,69 @@ InternLM2-Chat 进一步提高了它在代码解释和通用工具调用方面
|
||||||
|
|
||||||
## 体验
|
## 体验
|
||||||
|
|
||||||
我们提供了使用 [Lagent](lagent_zh-CN.md) 来基于 InternLM2-Chat 构建智能体调用代码解释器或者搜索等工具的例子。同时,我们也提供了采用 [PAL 评测 GSM8K 数学题](pal_inference_zh-CN.md) InternLM-Chat-7B 的样例。
|
我们提供了使用 [Lagent](lagent_zh-CN.md) 来基于 InternLM2-Chat 构建智能体调用代码解释器的例子。首先安装额外依赖:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
运行以下脚本在 GSM8K 和 MATH 测试集上进行推理和评估:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python streaming_inference.py \
|
||||||
|
--backend=lmdeploy \ # For HuggingFace models: hf
|
||||||
|
--model_path=internlm/internlm2-chat-20b \
|
||||||
|
--tp=2 \
|
||||||
|
--temperature=0.0 \
|
||||||
|
--dataset=math \
|
||||||
|
--output_path=math_lmdeploy.jsonl \
|
||||||
|
--do_eval
|
||||||
|
```
|
||||||
|
|
||||||
|
`output_path` 是一个存储推理结果的 jsonl 格式文件,每行形如:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"idx": 41,
|
||||||
|
"query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
|
||||||
|
"gt": "0",
|
||||||
|
"pred": ["0"],
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"role": "language",
|
||||||
|
"content": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": {
|
||||||
|
"name": "IPythonInteractive",
|
||||||
|
"parameters": {
|
||||||
|
"command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n x, y = symbols('x y')\n equation = 3*x + 2*y - 12\n b = solve(equation.subs(x, 4), y)[0]\n\n return b\n\nresult = find_b()\nprint(result)\n```"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"name": "interpreter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "environment",
|
||||||
|
"content": "0",
|
||||||
|
"name": "interpreter"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "language",
|
||||||
|
"content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"error": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
如果已经准备好了该文件,可直接跳过推理阶段进行评估:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python streaming_inference.py \
|
||||||
|
--output_path=math_lmdeploy.jsonl \
|
||||||
|
--no-do_infer \
|
||||||
|
--do_eval
|
||||||
|
```
|
||||||
|
|
||||||
|
请参考 [`streaming_inference.py`](streaming_inference.py) 获取更多关于参数的信息。
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
lmdeploy>=0.2.2
|
||||||
|
datasets
|
||||||
|
tqdm
|
||||||
|
numpy
|
||||||
|
pebble
|
||||||
|
jsonlines
|
||||||
|
sympy==1.12
|
||||||
|
antlr4-python3-runtime==4.11.0
|
||||||
|
lagent
|
||||||
|
einops
|
|
@ -0,0 +1,681 @@
|
||||||
|
# flake8: noqa
|
||||||
|
# isort: skip_file
|
||||||
|
|
||||||
|
# This logic is modified from ToRA:
|
||||||
|
# - https://github.com/microsoft/ToRA
|
||||||
|
#
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from math import isclose
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
|
||||||
|
Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
|
||||||
|
get_tool)
|
||||||
|
from pebble import ProcessPool
|
||||||
|
from sympy import N, simplify
|
||||||
|
from sympy.parsing.latex import parse_latex
|
||||||
|
from sympy.parsing.sympy_parser import parse_expr
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
DEFAULT_PROMPT = (
|
||||||
|
'Integrate step-by-step reasoning and Python code to solve math problems '
|
||||||
|
'using the following guidelines:\n'
|
||||||
|
'- Just write jupyter code to solve the problem without giving your thought;\n'
|
||||||
|
r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
|
||||||
|
'units. \n')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='Math Code Interpreter')
|
||||||
|
parser.add_argument('--backend',
|
||||||
|
type=str,
|
||||||
|
default='lmdeploy',
|
||||||
|
help='Which inference framework to use.',
|
||||||
|
choices=['lmdeploy', 'hf'])
|
||||||
|
parser.add_argument(
|
||||||
|
'--model_path',
|
||||||
|
type=str,
|
||||||
|
default='internlm/internlm2-chat-7b',
|
||||||
|
help='Path or name to the model, could be HuggingFace model specifier.'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--output_path',
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help='Path to save inference results to, should be a `jsonl` file')
|
||||||
|
parser.add_argument('--dataset',
|
||||||
|
type=str,
|
||||||
|
default='math',
|
||||||
|
choices=['gsm8k', 'math'],
|
||||||
|
help='Dataset for inference')
|
||||||
|
parser.add_argument(
|
||||||
|
'--tp',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='Number of tensor parallelism. It may be required in LMDelpoy.')
|
||||||
|
parser.add_argument('--temperature',
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help='Temperature in next token prediction')
|
||||||
|
parser.add_argument('--top_p',
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help='Parameter for Top-P Sampling.')
|
||||||
|
parser.add_argument('--top_k',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='Parameter for Top-K Sampling.')
|
||||||
|
parser.add_argument('--stop_words',
|
||||||
|
type=str,
|
||||||
|
default=['<|action_end|>', '<|im_end|>'],
|
||||||
|
action='append',
|
||||||
|
help='Stop words')
|
||||||
|
parser.add_argument('--max_new_tokens',
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help='Number of maximum generated tokens.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--do_infer',
|
||||||
|
default=True,
|
||||||
|
action=argparse.BooleanOptionalAction, # python > 3.8
|
||||||
|
help='Whether to launch model inference.')
|
||||||
|
# parser.add_argument(
|
||||||
|
# '--no-do_infer',
|
||||||
|
# dest='do_infer',
|
||||||
|
# action='store_false',
|
||||||
|
# help='Disable the inference.'
|
||||||
|
# )
|
||||||
|
parser.add_argument('--do_eval',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='Whether to evaluate the inference results.')
|
||||||
|
parser.add_argument('--overwrite',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='Whether to overwrite the existing result file')
|
||||||
|
parser.add_argument('--debug',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='Only infer the first 50 samples')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_fracs(string):
|
||||||
|
substrs = string.split('\\frac')
|
||||||
|
new_str = substrs[0]
|
||||||
|
if len(substrs) > 1:
|
||||||
|
substrs = substrs[1:]
|
||||||
|
for substr in substrs:
|
||||||
|
new_str += '\\frac'
|
||||||
|
if len(substr) > 0 and substr[0] == '{':
|
||||||
|
new_str += substr
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
assert len(substr) >= 2
|
||||||
|
except Exception:
|
||||||
|
return string
|
||||||
|
a = substr[0]
|
||||||
|
b = substr[1]
|
||||||
|
if b != '{':
|
||||||
|
if len(substr) > 2:
|
||||||
|
post_substr = substr[2:]
|
||||||
|
new_str += '{' + a + '}{' + b + '}' + post_substr
|
||||||
|
else:
|
||||||
|
new_str += '{' + a + '}{' + b + '}'
|
||||||
|
else:
|
||||||
|
if len(substr) > 2:
|
||||||
|
post_substr = substr[2:]
|
||||||
|
new_str += '{' + a + '}' + b + post_substr
|
||||||
|
else:
|
||||||
|
new_str += '{' + a + '}' + b
|
||||||
|
string = new_str
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_a_slash_b(string):
|
||||||
|
if len(string.split('/')) != 2:
|
||||||
|
return string
|
||||||
|
a = string.split('/')[0]
|
||||||
|
b = string.split('/')[1]
|
||||||
|
try:
|
||||||
|
if 'sqrt' not in a:
|
||||||
|
a = int(a)
|
||||||
|
if 'sqrt' not in b:
|
||||||
|
b = int(b)
|
||||||
|
assert string == '{}/{}'.format(a, b)
|
||||||
|
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
|
||||||
|
return new_string
|
||||||
|
except Exception:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_sqrt(string):
|
||||||
|
_string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
|
||||||
|
return _string
|
||||||
|
|
||||||
|
|
||||||
|
def strip_string(string):
|
||||||
|
string = str(string).strip()
|
||||||
|
# linebreaks
|
||||||
|
string = string.replace('\n', '')
|
||||||
|
|
||||||
|
# right "."
|
||||||
|
string = string.rstrip('.')
|
||||||
|
|
||||||
|
# remove inverse spaces
|
||||||
|
string = string.replace('\\!', '')
|
||||||
|
string = string.replace('\\ ', '')
|
||||||
|
|
||||||
|
# replace \\ with \
|
||||||
|
string = string.replace('\\\\', '\\')
|
||||||
|
string = string.replace('\\\\', '\\')
|
||||||
|
|
||||||
|
# replace tfrac and dfrac with frac
|
||||||
|
string = string.replace('tfrac', 'frac')
|
||||||
|
string = string.replace('dfrac', 'frac')
|
||||||
|
|
||||||
|
# remove \left and \right
|
||||||
|
string = string.replace('\\left', '')
|
||||||
|
string = string.replace('\\right', '')
|
||||||
|
|
||||||
|
# Remove unit: miles, dollars if after is not none
|
||||||
|
_string = re.sub(r'\\text{.*?}$', '', string).strip()
|
||||||
|
if _string != '' and _string != string:
|
||||||
|
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||||
|
string = _string
|
||||||
|
|
||||||
|
# Remove circ (degrees)
|
||||||
|
string = string.replace('^{\\circ}', '')
|
||||||
|
string = string.replace('^\\circ', '')
|
||||||
|
|
||||||
|
# remove dollar signs
|
||||||
|
string = string.replace('\\$', '')
|
||||||
|
string = string.replace('$', '')
|
||||||
|
|
||||||
|
string = string.replace('\\text', '')
|
||||||
|
string = string.replace('x\\in', '')
|
||||||
|
|
||||||
|
# remove percentage
|
||||||
|
string = string.replace('\\%', '')
|
||||||
|
string = string.replace('\%', '')
|
||||||
|
string = string.replace('%', '')
|
||||||
|
|
||||||
|
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||||
|
string = string.replace(' .', ' 0.')
|
||||||
|
string = string.replace('{.', '{0.')
|
||||||
|
|
||||||
|
# cdot
|
||||||
|
string = string.replace('\\cdot', '')
|
||||||
|
|
||||||
|
# inf
|
||||||
|
string = string.replace('infinity', '\\infty')
|
||||||
|
if '\\infty' not in string:
|
||||||
|
string = string.replace('inf', '\\infty')
|
||||||
|
string = string.replace('+\\inity', '\\infty')
|
||||||
|
|
||||||
|
# and
|
||||||
|
string = string.replace('and', '')
|
||||||
|
string = string.replace('\\mathbf', '')
|
||||||
|
|
||||||
|
# use regex to remove \mbox{...}
|
||||||
|
string = re.sub(r'\\mbox{.*?}', '', string)
|
||||||
|
|
||||||
|
# quote
|
||||||
|
string.replace("'", '')
|
||||||
|
string.replace('"', '')
|
||||||
|
|
||||||
|
# i, j
|
||||||
|
if 'j' in string and 'i' not in string:
|
||||||
|
string = string.replace('j', 'i')
|
||||||
|
|
||||||
|
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||||
|
string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
|
||||||
|
string = re.sub(r'(\d+)\.0+$', r'\1', string)
|
||||||
|
|
||||||
|
# if empty, return empty string
|
||||||
|
if len(string) == 0:
|
||||||
|
return string
|
||||||
|
if string[0] == '.':
|
||||||
|
string = '0' + string
|
||||||
|
|
||||||
|
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||||
|
if len(string.split('=')) == 2:
|
||||||
|
if len(string.split('=')[0]) <= 2:
|
||||||
|
string = string.split('=')[1]
|
||||||
|
|
||||||
|
string = _fix_sqrt(string)
|
||||||
|
string = string.replace(' ', '')
|
||||||
|
|
||||||
|
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||||
|
string = _fix_fracs(string)
|
||||||
|
|
||||||
|
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||||
|
string = _fix_a_slash_b(string)
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def last_boxed_only_string(string):
|
||||||
|
idx = string.rfind('\\boxed')
|
||||||
|
if idx < 0:
|
||||||
|
idx = string.rfind('\\fbox')
|
||||||
|
if idx < 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
i = idx
|
||||||
|
right_brace_idx = None
|
||||||
|
num_left_braces_open = 0
|
||||||
|
while i < len(string):
|
||||||
|
if string[i] == '{':
|
||||||
|
num_left_braces_open += 1
|
||||||
|
if string[i] == '}':
|
||||||
|
num_left_braces_open -= 1
|
||||||
|
if num_left_braces_open == 0:
|
||||||
|
right_brace_idx = i
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if right_brace_idx is None:
|
||||||
|
retval = None
|
||||||
|
else:
|
||||||
|
retval = string[idx:right_brace_idx + 1]
|
||||||
|
|
||||||
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
def extract_answer(pred_str):
|
||||||
|
if 'boxed' not in pred_str:
|
||||||
|
return ''
|
||||||
|
answer = pred_str.split('boxed')[-1]
|
||||||
|
if len(answer) == 0:
|
||||||
|
return ''
|
||||||
|
elif (answer[0] == '{'):
|
||||||
|
stack = 1
|
||||||
|
a = ''
|
||||||
|
for c in answer[1:]:
|
||||||
|
if (c == '{'):
|
||||||
|
stack += 1
|
||||||
|
a += c
|
||||||
|
elif (c == '}'):
|
||||||
|
stack -= 1
|
||||||
|
if (stack == 0): break
|
||||||
|
a += c
|
||||||
|
else:
|
||||||
|
a += c
|
||||||
|
else:
|
||||||
|
a = answer.split('$')[0].strip()
|
||||||
|
|
||||||
|
pred = a.split('\n')[0]
|
||||||
|
if pred != '' and pred[0] == ':':
|
||||||
|
pred = pred[1:]
|
||||||
|
if pred != '' and pred[-1] == '.':
|
||||||
|
pred = pred[:-1]
|
||||||
|
if pred != '' and pred[-1] == '/':
|
||||||
|
pred = pred[:-1]
|
||||||
|
pred = strip_string(pred)
|
||||||
|
return pred
|
||||||
|
|
||||||
|
|
||||||
|
def is_digit(s):
|
||||||
|
try:
|
||||||
|
float(str(s).replace(',', ''))
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def math_equal(
|
||||||
|
prediction: Union[bool, float, str],
|
||||||
|
reference: Union[float, str],
|
||||||
|
include_percentage: bool = True,
|
||||||
|
is_close: bool = True,
|
||||||
|
tolerance: float = 1e-4,
|
||||||
|
timeout: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Exact match of math if and only if:
|
||||||
|
|
||||||
|
1. numerical equal: both can convert to float and are equal
|
||||||
|
2. symbolic equal: both can convert to sympy expression and are equal
|
||||||
|
"""
|
||||||
|
try: # 1. numerical equal
|
||||||
|
if is_digit(prediction) and is_digit(reference):
|
||||||
|
prediction = float(str(prediction).replace(',', ''))
|
||||||
|
reference = float(str(reference).replace(',', ''))
|
||||||
|
# number questions
|
||||||
|
if include_percentage:
|
||||||
|
gt_result = [reference / 100, reference, reference * 100]
|
||||||
|
else:
|
||||||
|
gt_result = [reference]
|
||||||
|
for item in gt_result:
|
||||||
|
try:
|
||||||
|
if is_close:
|
||||||
|
if isclose(item, prediction, rel_tol=tolerance):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if item == prediction:
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not prediction and prediction not in [0, False]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 2. symbolic equal
|
||||||
|
reference = str(reference).strip()
|
||||||
|
prediction = str(prediction).strip()
|
||||||
|
|
||||||
|
## deal with [], (), {}
|
||||||
|
pred_str, ref_str = prediction, reference
|
||||||
|
if (prediction.startswith('[') and prediction.endswith(']')
|
||||||
|
and not reference.startswith('(')) or (
|
||||||
|
prediction.startswith('(') and prediction.endswith(')')
|
||||||
|
and not reference.startswith('[')):
|
||||||
|
pred_str = pred_str.strip('[]()')
|
||||||
|
ref_str = ref_str.strip('[]()')
|
||||||
|
for s in ['{', '}', '(', ')']:
|
||||||
|
ref_str = ref_str.replace(s, '')
|
||||||
|
pred_str = pred_str.replace(s, '')
|
||||||
|
if pred_str == ref_str:
|
||||||
|
return True
|
||||||
|
|
||||||
|
## [a, b] vs. [c, d], return a==c and b==d
|
||||||
|
if ((prediction.startswith('[') and prediction.endswith(']')) and
|
||||||
|
(reference.startswith('[') and reference.endswith(']'))
|
||||||
|
or (prediction.startswith('(') and prediction.endswith(')')) and
|
||||||
|
(reference.startswith('(') and reference.endswith(')'))):
|
||||||
|
pred_parts = prediction[1:-1].split(',')
|
||||||
|
ref_parts = reference[1:-1].split(',')
|
||||||
|
if len(pred_parts) == len(ref_parts):
|
||||||
|
if all([
|
||||||
|
math_equal(pred_parts[i], ref_parts[i], include_percentage,
|
||||||
|
is_close) for i in range(len(pred_parts))
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# symbolic equal with sympy
|
||||||
|
if timeout:
|
||||||
|
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if symbolic_equal(prediction, reference):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def math_equal_process(param):
|
||||||
|
return math_equal(param[-2], param[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def symbolic_equal(a, b):
|
||||||
|
|
||||||
|
def _parse(s):
|
||||||
|
for f in [parse_latex, parse_expr]:
|
||||||
|
try:
|
||||||
|
return f(s)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return s
|
||||||
|
|
||||||
|
a = _parse(a)
|
||||||
|
b = _parse(b)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if simplify(a - b) == 0:
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isclose(N(a), N(b), rel_tol=1e-3):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def symbolic_equal_process(a, b, output_queue):
|
||||||
|
result = symbolic_equal(a, b)
|
||||||
|
output_queue.put(result)
|
||||||
|
|
||||||
|
|
||||||
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
||||||
|
output_queue = multiprocessing.Queue()
|
||||||
|
process_args = args + (output_queue, )
|
||||||
|
process = multiprocessing.Process(target=func,
|
||||||
|
args=process_args,
|
||||||
|
kwargs=kwargs)
|
||||||
|
process.start()
|
||||||
|
process.join(timeout)
|
||||||
|
|
||||||
|
if process.is_alive():
|
||||||
|
process.terminate()
|
||||||
|
process.join()
|
||||||
|
return False
|
||||||
|
|
||||||
|
return output_queue.get()
|
||||||
|
|
||||||
|
|
||||||
|
def init_agent(backend: str, model_path: str, tp: int, **kwargs):
|
||||||
|
if backend == 'lmdeploy':
|
||||||
|
model = LMDeployPipeline(path=model_path,
|
||||||
|
meta_template=INTERNLM2_META,
|
||||||
|
tp=tp,
|
||||||
|
**kwargs)
|
||||||
|
elif backend == 'hf':
|
||||||
|
model = HFTransformer(path=model_path,
|
||||||
|
meta_template=INTERNLM2_META,
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
agent = Internlm2Agent(llm=model,
|
||||||
|
protocol=Internlm2Protocol(
|
||||||
|
meta_prompt=None,
|
||||||
|
interpreter_prompt=DEFAULT_PROMPT),
|
||||||
|
interpreter_executor=ActionExecutor(
|
||||||
|
actions=[get_tool('IPythonInteractive')]))
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def predict(args):
|
||||||
|
if args.dataset == 'gsm8k':
|
||||||
|
|
||||||
|
def process(d, k):
|
||||||
|
d['answer'] = re.sub(r'#### (.+)', r'The answer is \1',
|
||||||
|
re.sub(r'<<.*?>>', '',
|
||||||
|
d['answer'])).replace('$', '')
|
||||||
|
d['idx'] = k
|
||||||
|
d['query'] = d['question'].replace('$', '')
|
||||||
|
d['gt'] = re.search('The answer is (.+)', d['answer'])[1]
|
||||||
|
d['pred'], d['steps'], d['error'] = [], [], None
|
||||||
|
return d
|
||||||
|
|
||||||
|
dataset = load_dataset('gsm8k', 'main',
|
||||||
|
split='test').map(process, True)
|
||||||
|
|
||||||
|
elif args.dataset == 'math':
|
||||||
|
|
||||||
|
def process(d, k):
|
||||||
|
d['idx'] = k
|
||||||
|
d['query'] = d['problem']
|
||||||
|
gt = extract_answer(d['solution'])
|
||||||
|
if '\\boxed{90\\text{ square\nunits}}' in d['solution']:
|
||||||
|
gt = '90'
|
||||||
|
elif '$6$ is our answer' in d['solution']:
|
||||||
|
gt = '6'
|
||||||
|
elif gt.startswith('x\\in'):
|
||||||
|
gt = gt[len('x\\in'):]
|
||||||
|
gt = strip_string(gt)
|
||||||
|
d['gt'] = gt
|
||||||
|
d['pred'], d['steps'] = [], []
|
||||||
|
d['error'] = None
|
||||||
|
return d
|
||||||
|
|
||||||
|
dataset = load_dataset('lighteval/MATH',
|
||||||
|
split='test').map(process, True)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
agent = init_agent(
|
||||||
|
backend=args.backend,
|
||||||
|
model_path=args.model_path,
|
||||||
|
tp=args.tp,
|
||||||
|
temperature=args.temperature,
|
||||||
|
stop_words=args.stop_words,
|
||||||
|
top_p=args.top_p,
|
||||||
|
top_k=args.top_k,
|
||||||
|
max_new_tokens=args.max_new_tokens,
|
||||||
|
)
|
||||||
|
with jsonlines.open(args.output_path, 'w') as f:
|
||||||
|
for item in tqdm(
|
||||||
|
dataset if not args.debug else dataset.select(range(50))):
|
||||||
|
try:
|
||||||
|
ret = agent.chat(item['query'])
|
||||||
|
item['steps'] = ret.inner_steps
|
||||||
|
|
||||||
|
lang = [
|
||||||
|
step for step in item['steps']
|
||||||
|
if step['role'] == 'language'
|
||||||
|
]
|
||||||
|
item['pred'].append('😭' if not lang else
|
||||||
|
extract_answer(lang[-1]['content']) or '😭')
|
||||||
|
agent._interpreter_executor.actions[
|
||||||
|
'IPythonInteractive'].reset()
|
||||||
|
except Exception as e:
|
||||||
|
err = str(traceback.format_exc())
|
||||||
|
print(f'Error processing index {item["idx"]}: {e}\n{err}')
|
||||||
|
item['error'] = err
|
||||||
|
f.write(item)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(args):
|
||||||
|
samples = [sample for sample in jsonlines.open(args.output_path)]
|
||||||
|
scores = []
|
||||||
|
timeout_cnt = 0
|
||||||
|
with ProcessPool() as pool:
|
||||||
|
future = pool.map(
|
||||||
|
math_equal_process,
|
||||||
|
[(idx, pred, sample['gt']) for idx, sample in enumerate(samples)
|
||||||
|
for pred in sample['pred']],
|
||||||
|
timeout=20,
|
||||||
|
)
|
||||||
|
iterator = future.result()
|
||||||
|
with tqdm(total=len(samples), desc='Evaluate') as progress_bar:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = next(iterator)
|
||||||
|
scores.append(result)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
except TimeoutError as error:
|
||||||
|
print(error)
|
||||||
|
scores.append(False)
|
||||||
|
timeout_cnt += 1
|
||||||
|
except Exception as error:
|
||||||
|
print(error.__traceback__)
|
||||||
|
sys.exit()
|
||||||
|
progress_bar.update(1)
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
score_mat = []
|
||||||
|
for sample in samples:
|
||||||
|
sample['score'] = scores[idx:idx + len(sample['pred'])]
|
||||||
|
assert len(sample['score']) == len(sample['pred'])
|
||||||
|
score_mat.append(sample['score'])
|
||||||
|
idx += len(sample['pred'])
|
||||||
|
|
||||||
|
max_len = max([len(s) for s in score_mat])
|
||||||
|
|
||||||
|
for i, s in enumerate(score_mat):
|
||||||
|
if len(s) < max_len:
|
||||||
|
score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad
|
||||||
|
|
||||||
|
# output mean of each column of scores
|
||||||
|
col_means = np.array(score_mat).mean(axis=0)
|
||||||
|
mean_score = list(np.round(col_means * 100, decimals=1))
|
||||||
|
|
||||||
|
result_str = f'Num samples: {len(samples)}\n' \
|
||||||
|
f'Num scores: {len(scores)}\n' \
|
||||||
|
f'Sum scores: {sum(scores)}\n' \
|
||||||
|
f'Timeout samples: {timeout_cnt}\n' \
|
||||||
|
f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
|
||||||
|
f'Mean score: {mean_score}\n'
|
||||||
|
|
||||||
|
# each type score
|
||||||
|
if 'type' in samples[0]:
|
||||||
|
type_scores = {}
|
||||||
|
for sample in samples:
|
||||||
|
if sample['type'] not in type_scores:
|
||||||
|
type_scores[sample['type']] = []
|
||||||
|
type_scores[sample['type']].append(sample['score'][-1])
|
||||||
|
type_scores = {
|
||||||
|
k: np.round(np.array(v).mean() * 100, decimals=1)
|
||||||
|
for k, v in type_scores.items()
|
||||||
|
}
|
||||||
|
type_scores = {
|
||||||
|
k: v
|
||||||
|
for k, v in sorted(type_scores.items(), key=lambda item: item[0])
|
||||||
|
}
|
||||||
|
result_str += f'Type scores: {type_scores}\n'
|
||||||
|
|
||||||
|
print(result_str)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
if args.do_infer and os.path.exists(
|
||||||
|
args.output_path) and not args.overwrite:
|
||||||
|
args.do_infer = False
|
||||||
|
print(f'File {args.output_path} already exists. '
|
||||||
|
f'Please add the `--overwrite` flag if needed.')
|
||||||
|
if args.do_infer:
|
||||||
|
predict(args)
|
||||||
|
if args.do_eval:
|
||||||
|
if not args.do_infer:
|
||||||
|
evaluate(args)
|
||||||
|
else:
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
res = subprocess.run(
|
||||||
|
[
|
||||||
|
sys.executable, __file__, '--output_path',
|
||||||
|
args.output_path, '--no-do_infer', '--do_eval'
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
print(res.stdout)
|
Loading…
Reference in New Issue