Support inference and evaluation with Math Code Interpreter (#695)

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
pull/721/head
BraisedPork 2024-03-08 14:32:33 +08:00 committed by GitHub
parent 43b7582201
commit 2b221a9f17
5 changed files with 836 additions and 15 deletions

README.md

@@ -261,7 +261,7 @@ To learn more about data contamination assessment, please check the [contaminati
 ### Agent Evaluation
 - To evaluate tool utilization, please refer to [T-Eval](https://github.com/open-compass/T-Eval).
-- For code interpreter evaluation, use the [gsm-8k-agent](https://github.com/open-compass/opencompass/blob/main/configs/datasets/gsm8k/gsm8k_agent_gen_be1606.py) provided in the repository. Additionally, you need to install [Lagent](https://github.com/InternLM/lagent).
+- For code interpreter evaluation, use the [Math Agent Evaluation](agent/README.md) provided in the repository.
 ### Subjective Evaluation

agent/README.md

@@ -10,13 +10,78 @@ InternLM2-Chat, open sourced on January 17, 2024, further enhances its capabilit
The results of InternLM2-Chat-20B on the math code interpreter are shown below:

|                                           | GSM8K | MATH  |
| :---------------------------------------: | :---: | :---: |
| InternLM2-Chat-20B                        | 79.6  | 32.5  |
| InternLM2-Chat-20B with Code Interpreter  | 84.5  | 51.2  |
| ChatGPT (GPT-3.5)                         | 78.2  | 28.0  |
| GPT-4                                     | 91.4  | 45.8  |

## Usages

-We offer examples using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call code interpreter or search API. Additionally, we provide an example code using [PAL to evaluate GSM8K math problems](pal_inference.md) with InternLM-Chat-7B.
+We offer an example using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call the code interpreter. First, install the extra dependencies:
```bash
pip install -r requirements.txt
```
Run the following script to perform inference and evaluation on the GSM8K and MATH test sets (use `--backend=hf` for HuggingFace models):

```bash
python streaming_inference.py \
    --backend=lmdeploy \
    --model_path=internlm/internlm2-chat-20b \
    --tp=2 \
    --temperature=0.0 \
    --dataset=math \
    --output_path=math_lmdeploy.jsonl \
    --do_eval
```
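For a quick smoke test before a full run, a possible variant (illustrative only; `--debug` restricts inference to the first 50 samples per the script's argument list, and the output filename here is arbitrary):

```bash
# Quick smoke test: HuggingFace backend, GSM8K, only the first 50 samples.
python streaming_inference.py \
    --backend=hf \
    --model_path=internlm/internlm2-chat-7b \
    --dataset=gsm8k \
    --output_path=gsm8k_hf_debug.jsonl \
    --debug \
    --do_eval
```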
`output_path` is a JSON Lines (`jsonl`) file where the inference results are saved. Each line looks like:
```json
{
    "idx": 41,
    "query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
    "gt": "0",
    "pred": ["0"],
    "steps": [
        {
            "role": "language",
            "content": ""
        },
        {
            "role": "tool",
            "content": {
                "name": "IPythonInteractive",
                "parameters": {
                    "command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n x, y = symbols('x y')\n equation = 3*x + 2*y - 12\n b = solve(equation.subs(x, 4), y)[0]\n\n return b\n\nresult = find_b()\nprint(result)\n```"
                }
            },
            "name": "interpreter"
        },
        {
            "role": "environment",
            "content": "0",
            "name": "interpreter"
        },
        {
            "role": "language",
            "content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
        }
    ],
    "error": null
}
```
Once the result file is ready, you can skip the inference stage and run the evaluation only:

```bash
python streaming_inference.py \
    --output_path=math_lmdeploy.jsonl \
    --no-do_infer \
    --do_eval
```
Please refer to [`streaming_inference.py`](streaming_inference.py) for more information about the arguments.
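If you just want to inspect the saved results, a minimal sketch using the `jsonlines` package listed in `requirements.txt` (field names follow the example above; exact string matching is only a rough proxy for the numeric/symbolic comparison performed by `streaming_inference.py`):

```python
import jsonlines

# Rough sanity check over the saved predictions: count samples whose last
# prediction equals the ground-truth string exactly. The real evaluation in
# streaming_inference.py uses numeric and symbolic equivalence instead.
total, exact = 0, 0
with jsonlines.open('math_lmdeploy.jsonl') as reader:
    for sample in reader:
        total += 1
        if sample['pred'] and sample['pred'][-1] == sample['gt']:
            exact += 1
print(f'{exact}/{total} exact string matches')
```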

agent/README_zh-CN.md

@@ -10,13 +10,78 @@ InternLM2-Chat further improves its capabilities in code interpretation and general tool invocation
Below are the results of InternLM2-Chat-20B with the math code interpreter.

|                                             | GSM8K | MATH  |
| :-----------------------------------------: | :---: | :---: |
| InternLM2-Chat-20B (intrinsic ability only) | 79.6  | 32.5  |
| InternLM2-Chat-20B with code interpreter    | 84.5  | 51.2  |
| ChatGPT (GPT-3.5)                           | 78.2  | 28.0  |
| GPT-4                                       | 91.4  | 45.8  |

## Usage

-We provide an example of using [Lagent](lagent_zh-CN.md) to build an agent on top of InternLM2-Chat that calls tools such as the code interpreter or search. We also provide an example of evaluating InternLM-Chat-7B with [PAL on GSM8K math problems](pal_inference_zh-CN.md).
+We provide an example of using [Lagent](lagent_zh-CN.md) to build an agent on top of InternLM2-Chat that calls the code interpreter. First, install the extra dependencies:
```bash
pip install -r requirements.txt
```
Run the following script to perform inference and evaluation on the GSM8K and MATH test sets (use `--backend=hf` for HuggingFace models):

```bash
python streaming_inference.py \
    --backend=lmdeploy \
    --model_path=internlm/internlm2-chat-20b \
    --tp=2 \
    --temperature=0.0 \
    --dataset=math \
    --output_path=math_lmdeploy.jsonl \
    --do_eval
```
`output_path` is a JSON Lines (`jsonl`) file where the inference results are saved. Each line looks like:
```json
{
    "idx": 41,
    "query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
    "gt": "0",
    "pred": ["0"],
    "steps": [
        {
            "role": "language",
            "content": ""
        },
        {
            "role": "tool",
            "content": {
                "name": "IPythonInteractive",
                "parameters": {
                    "command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n x, y = symbols('x y')\n equation = 3*x + 2*y - 12\n b = solve(equation.subs(x, 4), y)[0]\n\n return b\n\nresult = find_b()\nprint(result)\n```"
                }
            },
            "name": "interpreter"
        },
        {
            "role": "environment",
            "content": "0",
            "name": "interpreter"
        },
        {
            "role": "language",
            "content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
        }
    ],
    "error": null
}
```
If the result file is already prepared, you can skip the inference stage and run the evaluation directly:

```bash
python streaming_inference.py \
    --output_path=math_lmdeploy.jsonl \
    --no-do_infer \
    --do_eval
```
Please refer to [`streaming_inference.py`](streaming_inference.py) for more information about the arguments.

agent/requirements.txt (new file)

@@ -0,0 +1,10 @@
lmdeploy>=0.2.2
datasets
tqdm
numpy
pebble
jsonlines
sympy==1.12
antlr4-python3-runtime==4.11.0
lagent
einops

agent/streaming_inference.py (new file)

@@ -0,0 +1,681 @@
# flake8: noqa
# isort: skip_file
# This logic is modified from ToRA:
# - https://github.com/microsoft/ToRA
#
# Copyright (c) Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE

import argparse
import multiprocessing
import os
import re
import sys
import traceback
from math import isclose
from typing import Union

import jsonlines
import numpy as np
from datasets import load_dataset
from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
                    Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
                    get_tool)
from pebble import ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm

DEFAULT_PROMPT = (
    'Integrate step-by-step reasoning and Python code to solve math problems '
    'using the following guidelines:\n'
    '- Just write jupyter code to solve the problem without giving your thought;\n'
    r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
    'units. \n')


def parse_args():
    parser = argparse.ArgumentParser(description='Math Code Interpreter')
    parser.add_argument('--backend',
                        type=str,
                        default='lmdeploy',
                        help='Which inference framework to use.',
                        choices=['lmdeploy', 'hf'])
    parser.add_argument(
        '--model_path',
        type=str,
        default='internlm/internlm2-chat-7b',
        help='Path or name to the model, could be HuggingFace model specifier.'
    )
    parser.add_argument(
        '--output_path',
        type=str,
        required=True,
        help='Path to save inference results to, should be a `jsonl` file')
    parser.add_argument('--dataset',
                        type=str,
                        default='math',
                        choices=['gsm8k', 'math'],
                        help='Dataset for inference')
    parser.add_argument(
        '--tp',
        type=int,
        default=1,
        help='Degree of tensor parallelism. Only used by the LMDeploy backend.')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.1,
                        help='Temperature in next token prediction')
    parser.add_argument('--top_p',
                        type=float,
                        default=0.8,
                        help='Parameter for Top-P Sampling.')
    parser.add_argument('--top_k',
                        type=int,
                        default=None,
                        help='Parameter for Top-K Sampling.')
    parser.add_argument('--stop_words',
                        type=str,
                        default=['<|action_end|>', '<|im_end|>'],
                        action='append',
                        help='Stop words')
    parser.add_argument('--max_new_tokens',
                        type=int,
                        default=512,
                        help='Number of maximum generated tokens.')
    parser.add_argument(
        '--do_infer',
        default=True,
        action=argparse.BooleanOptionalAction,  # requires Python >= 3.9
        help='Whether to launch model inference.')
    # parser.add_argument(
    #     '--no-do_infer',
    #     dest='do_infer',
    #     action='store_false',
    #     help='Disable the inference.'
    # )
    parser.add_argument('--do_eval',
                        default=False,
                        action='store_true',
                        help='Whether to evaluate the inference results.')
    parser.add_argument('--overwrite',
                        default=False,
                        action='store_true',
                        help='Whether to overwrite the existing result file')
    parser.add_argument('--debug',
                        default=False,
                        action='store_true',
                        help='Only infer the first 50 samples')
    return parser.parse_args()


def _fix_fracs(string):
    # Rewrite shorthand fractions such as \frac12 or \frac1b into \frac{1}{2} / \frac{1}{b}.
    substrs = string.split('\\frac')
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += '\\frac'
            if len(substr) > 0 and substr[0] == '{':
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except Exception:
                    return string
                a = substr[0]
                b = substr[1]
                if b != '{':
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += '{' + a + '}{' + b + '}' + post_substr
                    else:
                        new_str += '{' + a + '}{' + b + '}'
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += '{' + a + '}' + b + post_substr
                    else:
                        new_str += '{' + a + '}' + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    # Turn simple 'a/b' answers into LaTeX fractions, e.g. '3/4' -> '\frac{3}{4}'.
    if len(string.split('/')) != 2:
        return string
    a = string.split('/')[0]
    b = string.split('/')[1]
    try:
        if 'sqrt' not in a:
            a = int(a)
        if 'sqrt' not in b:
            b = int(b)
        assert string == '{}/{}'.format(a, b)
        new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
        return new_string
    except Exception:
        return string


def _fix_sqrt(string):
    # Wrap bare \sqrt arguments in braces, e.g. '\sqrt2' -> '\sqrt{2}'.
    _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
    return _string


def strip_string(string):
    # Normalize a LaTeX/plain-text answer string before comparison.
    string = str(string).strip()
    # linebreaks
    string = string.replace('\n', '')
    # right "."
    string = string.rstrip('.')
    # remove inverse spaces
    string = string.replace('\\!', '')
    string = string.replace('\\ ', '')
    # replace \\ with \
    string = string.replace('\\\\', '\\')
    string = string.replace('\\\\', '\\')
    # replace tfrac and dfrac with frac
    string = string.replace('tfrac', 'frac')
    string = string.replace('dfrac', 'frac')
    # remove \left and \right
    string = string.replace('\\left', '')
    string = string.replace('\\right', '')
    # Remove unit: miles, dollars if after is not none
    _string = re.sub(r'\\text{.*?}$', '', string).strip()
    if _string != '' and _string != string:
        # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
        string = _string
    # Remove circ (degrees)
    string = string.replace('^{\\circ}', '')
    string = string.replace('^\\circ', '')
    # remove dollar signs
    string = string.replace('\\$', '')
    string = string.replace('$', '')
    string = string.replace('\\text', '')
    string = string.replace('x\\in', '')
    # remove percentage
    string = string.replace('\\%', '')
    string = string.replace('\%', '')
    string = string.replace('%', '')
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')
    # cdot
    string = string.replace('\\cdot', '')
    # inf
    string = string.replace('infinity', '\\infty')
    if '\\infty' not in string:
        string = string.replace('inf', '\\infty')
    string = string.replace('+\\inity', '\\infty')
    # and
    string = string.replace('and', '')
    string = string.replace('\\mathbf', '')
    # use regex to remove \mbox{...}
    string = re.sub(r'\\mbox{.*?}', '', string)
    # quote
    string = string.replace("'", '')
    string = string.replace('"', '')
    # i, j
    if 'j' in string and 'i' not in string:
        string = string.replace('j', 'i')
    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
    string = re.sub(r'(\d+)\.0+$', r'\1', string)
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == '.':
        string = '0' + string
    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split('=')) == 2:
        if len(string.split('=')[0]) <= 2:
            string = string.split('=')[1]
    string = _fix_sqrt(string)
    string = string.replace(' ', '')
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)
    return string


def last_boxed_only_string(string):
    # Return the last '\boxed{...}' (or '\fbox{...}') substring, braces included.
    idx = string.rfind('\\boxed')
    if idx < 0:
        idx = string.rfind('\\fbox')
        if idx < 0:
            return None
    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == '{':
            num_left_braces_open += 1
        if string[i] == '}':
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1
    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx:right_brace_idx + 1]
    return retval


def extract_answer(pred_str):
    # Extract the final answer from the last '\boxed{...}' expression and normalize it.
    if 'boxed' not in pred_str:
        return ''
    answer = pred_str.split('boxed')[-1]
    if len(answer) == 0:
        return ''
    elif (answer[0] == '{'):
        stack = 1
        a = ''
        for c in answer[1:]:
            if (c == '{'):
                stack += 1
                a += c
            elif (c == '}'):
                stack -= 1
                if (stack == 0): break
                a += c
            else:
                a += c
    else:
        a = answer.split('$')[0].strip()
    pred = a.split('\n')[0]
    if pred != '' and pred[0] == ':':
        pred = pred[1:]
    if pred != '' and pred[-1] == '.':
        pred = pred[:-1]
    if pred != '' and pred[-1] == '/':
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred


def is_digit(s):
    # True if the string can be parsed as a float (commas allowed, e.g. '1,234').
    try:
        float(str(s).replace(',', ''))
        return True
    except ValueError:
        return False


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    tolerance: float = 1e-4,
    timeout: bool = False,
) -> bool:
    """Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal
    """
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(',', ''))
            reference = float(str(reference).replace(',', ''))
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, rel_tol=tolerance):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (prediction.startswith('[') and prediction.endswith(']')
            and not reference.startswith('(')) or (
                prediction.startswith('(') and prediction.endswith(')')
                and not reference.startswith('[')):
        pred_str = pred_str.strip('[]()')
        ref_str = ref_str.strip('[]()')
    for s in ['{', '}', '(', ')']:
        ref_str = ref_str.replace(s, '')
        pred_str = pred_str.replace(s, '')
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if ((prediction.startswith('[') and prediction.endswith(']')) and
            (reference.startswith('[') and reference.endswith(']'))
            or (prediction.startswith('(') and prediction.endswith(')')) and
            (reference.startswith('(') and reference.endswith(')'))):
        pred_parts = prediction[1:-1].split(',')
        ref_parts = reference[1:-1].split(',')
        if len(pred_parts) == len(ref_parts):
            if all([
                    math_equal(pred_parts[i], ref_parts[i], include_percentage,
                               is_close) for i in range(len(pred_parts))
            ]):
                return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False


def math_equal_process(param):
    # Helper for ProcessPool.map: param is (idx, prediction, ground_truth).
    return math_equal(param[-2], param[-1])


def symbolic_equal(a, b):
    # Check equivalence symbolically (simplify(a - b) == 0) or numerically
    # after sympy evaluation with N().

    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    try:
        if isclose(N(a), N(b), rel_tol=1e-3):
            return True
    except Exception:
        pass
    return False


def symbolic_equal_process(a, b, output_queue):
    # Wrapper so symbolic_equal can run in a separate process with a timeout.
    result = symbolic_equal(a, b)
    output_queue.put(result)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    # Run `func` in a child process; return False if it does not finish in time.
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue, )
    process = multiprocessing.Process(target=func,
                                      args=process_args,
                                      kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    return output_queue.get()


def init_agent(backend: str, model_path: str, tp: int, **kwargs):
    # Build an Internlm2Agent whose interpreter executor exposes the
    # IPythonInteractive tool (the Python code interpreter).
    if backend == 'lmdeploy':
        model = LMDeployPipeline(path=model_path,
                                 meta_template=INTERNLM2_META,
                                 tp=tp,
                                 **kwargs)
    elif backend == 'hf':
        model = HFTransformer(path=model_path,
                              meta_template=INTERNLM2_META,
                              **kwargs)
    else:
        raise NotImplementedError

    agent = Internlm2Agent(llm=model,
                           protocol=Internlm2Protocol(
                               meta_prompt=None,
                               interpreter_prompt=DEFAULT_PROMPT),
                           interpreter_executor=ActionExecutor(
                               actions=[get_tool('IPythonInteractive')]))
    return agent


def predict(args):
    # Load the dataset, build the agent and run inference, writing one JSON
    # line per sample to args.output_path.
    if args.dataset == 'gsm8k':

        def process(d, k):
            d['answer'] = re.sub(r'#### (.+)', r'The answer is \1',
                                 re.sub(r'<<.*?>>', '',
                                        d['answer'])).replace('$', '')
            d['idx'] = k
            d['query'] = d['question'].replace('$', '')
            d['gt'] = re.search('The answer is (.+)', d['answer'])[1]
            d['pred'], d['steps'], d['error'] = [], [], None
            return d

        dataset = load_dataset('gsm8k', 'main',
                               split='test').map(process, True)
    elif args.dataset == 'math':

        def process(d, k):
            d['idx'] = k
            d['query'] = d['problem']
            gt = extract_answer(d['solution'])
            if '\\boxed{90\\text{ square\nunits}}' in d['solution']:
                gt = '90'
            elif '$6$ is our answer' in d['solution']:
                gt = '6'
            elif gt.startswith('x\\in'):
                gt = gt[len('x\\in'):]
            gt = strip_string(gt)
            d['gt'] = gt
            d['pred'], d['steps'] = [], []
            d['error'] = None
            return d

        dataset = load_dataset('lighteval/MATH',
                               split='test').map(process, True)
    else:
        raise NotImplementedError

    agent = init_agent(
        backend=args.backend,
        model_path=args.model_path,
        tp=args.tp,
        temperature=args.temperature,
        stop_words=args.stop_words,
        top_p=args.top_p,
        top_k=args.top_k,
        max_new_tokens=args.max_new_tokens,
    )
    with jsonlines.open(args.output_path, 'w') as f:
        for item in tqdm(
                dataset if not args.debug else dataset.select(range(50))):
            try:
                ret = agent.chat(item['query'])
                item['steps'] = ret.inner_steps
                lang = [
                    step for step in item['steps']
                    if step['role'] == 'language'
                ]
                item['pred'].append('😭' if not lang else
                                    extract_answer(lang[-1]['content']) or '😭')
                # Reset the code interpreter state between samples.
                agent._interpreter_executor.actions[
                    'IPythonInteractive'].reset()
            except Exception as e:
                err = str(traceback.format_exc())
                print(f'Error processing index {item["idx"]}: {e}\n{err}')
                item['error'] = err
            f.write(item)


def evaluate(args):
    # Score saved predictions against the ground truth with math_equal,
    # running comparisons in a process pool with a per-sample timeout.
    samples = [sample for sample in jsonlines.open(args.output_path)]
    scores = []
    timeout_cnt = 0

    with ProcessPool() as pool:
        future = pool.map(
            math_equal_process,
            [(idx, pred, sample['gt']) for idx, sample in enumerate(samples)
             for pred in sample['pred']],
            timeout=20,
        )
        iterator = future.result()
        with tqdm(total=len(samples), desc='Evaluate') as progress_bar:
            while True:
                try:
                    result = next(iterator)
                    scores.append(result)
                except StopIteration:
                    break
                except TimeoutError as error:
                    print(error)
                    scores.append(False)
                    timeout_cnt += 1
                except Exception as error:
                    print(error.__traceback__)
                    sys.exit()
                progress_bar.update(1)

    idx = 0
    score_mat = []
    for sample in samples:
        sample['score'] = scores[idx:idx + len(sample['pred'])]
        assert len(sample['score']) == len(sample['pred'])
        score_mat.append(sample['score'])
        idx += len(sample['pred'])

    max_len = max([len(s) for s in score_mat])
    for i, s in enumerate(score_mat):
        if len(s) < max_len:
            score_mat[i] = s + [s[-1]] * (max_len - len(s))  # pad

    # output mean of each column of scores
    col_means = np.array(score_mat).mean(axis=0)
    mean_score = list(np.round(col_means * 100, decimals=1))

    result_str = f'Num samples: {len(samples)}\n' \
                 f'Num scores: {len(scores)}\n' \
                 f'Sum scores: {sum(scores)}\n' \
                 f'Timeout samples: {timeout_cnt}\n' \
                 f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
                 f'Mean score: {mean_score}\n'

    # each type score
    if 'type' in samples[0]:
        type_scores = {}
        for sample in samples:
            if sample['type'] not in type_scores:
                type_scores[sample['type']] = []
            type_scores[sample['type']].append(sample['score'][-1])
        type_scores = {
            k: np.round(np.array(v).mean() * 100, decimals=1)
            for k, v in type_scores.items()
        }
        type_scores = {
            k: v
            for k, v in sorted(type_scores.items(), key=lambda item: item[0])
        }
        result_str += f'Type scores: {type_scores}\n'

    print(result_str)


if __name__ == '__main__':
    args = parse_args()
    if args.do_infer and os.path.exists(
            args.output_path) and not args.overwrite:
        args.do_infer = False
        print(f'File {args.output_path} already exists. '
              f'Please add the `--overwrite` flag if needed.')
    if args.do_infer:
        predict(args)
    if args.do_eval:
        if not args.do_infer:
            evaluate(args)
        else:
            # After inference, re-invoke this script in a subprocess with
            # `--no-do_infer --do_eval` to evaluate the saved results.
            import subprocess
            res = subprocess.run(
                [
                    sys.executable, __file__, '--output_path',
                    args.output_path, '--no-do_infer', '--do_eval'
                ],
                capture_output=True,
                text=True,
                check=True,
            )
            print(res.stdout)