diff --git a/README.md b/README.md index c223e15..3227c57 100644 --- a/README.md +++ b/README.md @@ -261,7 +261,7 @@ To learn more about data contamination assessment, please check the [contaminati ### Agent Evaluation - To evaluate tool utilization, please refer to [T-Eval](https://github.com/open-compass/T-Eval). -- For code interpreter evaluation, use the [gsm-8k-agent](https://github.com/open-compass/opencompass/blob/main/configs/datasets/gsm8k/gsm8k_agent_gen_be1606.py) provided in the repository. Additionally, you need to install [Lagent](https://github.com/InternLM/lagent). +- For code interpreter evaluation, use the [Math Agent Evaluation](agent/README.md) provided in the repository. ### Subjective Evaluation diff --git a/agent/README.md b/agent/README.md index 693c841..1ed009c 100644 --- a/agent/README.md +++ b/agent/README.md @@ -10,13 +10,78 @@ InternLM2-Chat, open sourced on January 17, 2024, further enhances its capabilit The results of InternLM2-Chat-20B on math code interpreter is as below: -| | GSM8K | MATH | -| :--------------------------------------: | :---: | :--: | -| InternLM2-Chat-20B | 79.6 | 32.5 | -| InternLM2-Chat-20B with Code Interpreter | 84.5 | 51.2 | -| ChatGPT (GPT-3.5) | 78.2 | 28.0 | -| GPT-4 | 91.4 | 45.8 | +| | GSM8K | MATH | +| :--------------------------------------: | :---: | :---: | +| InternLM2-Chat-20B | 79.6 | 32.5 | +| InternLM2-Chat-20B with Code Interpreter | 84.5 | 51.2 | +| ChatGPT (GPT-3.5) | 78.2 | 28.0 | +| GPT-4 | 91.4 | 45.8 | ## Usages -We offer examples using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call code interpreter or search API. Additionally, we provide an example code using [PAL to evaluate GSM8K math problems](pal_inference.md) with InternLM-Chat-7B. +We offer an example using [Lagent](lagent.md) to build agents based on InternLM2-Chat to call the code interpreter. 
First, install the extra dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+Run the following script to perform inference and evaluation on the GSM8K and MATH test sets (pass `--backend=hf` to run HuggingFace models without LMDeploy):
+
+```bash
+python streaming_inference.py \
+    --backend=lmdeploy \
+    --model_path=internlm/internlm2-chat-20b \
+    --tp=2 \
+    --temperature=0.0 \
+    --dataset=math \
+    --output_path=math_lmdeploy.jsonl \
+    --do_eval
+```
+
+`output_path` is a jsonl-format file in which the inference results are saved. Each line looks like:
+
+```json
+{
+    "idx": 41,
+    "query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
+    "gt": "0",
+    "pred": ["0"],
+    "steps": [
+        {
+            "role": "language",
+            "content": ""
+        },
+        {
+            "role": "tool",
+            "content": {
+                "name": "IPythonInteractive",
+                "parameters": {
+                    "command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n    x, y = symbols('x y')\n    equation = 3*x + 2*y - 12\n    b = solve(equation.subs(x, 4), y)[0]\n\n    return b\n\nresult = find_b()\nprint(result)\n```"
+                }
+            },
+            "name": "interpreter"
+        },
+        {
+            "role": "environment",
+            "content": "0",
+            "name": "interpreter"
+        },
+        {
+            "role": "language",
+            "content": "The value of $b$ when $a = 4$ is $\\boxed{0}$."
+        }
+    ],
+    "error": null
+}
+```
+
+If the result file has already been generated, you can skip the inference stage and run evaluation directly:
+
+```bash
+python streaming_inference.py \
+    --output_path=math_lmdeploy.jsonl \
+    --no-do_infer \
+    --do_eval
+```
+
+Please refer to [`streaming_inference.py`](streaming_inference.py) for more information about the arguments. 
diff --git a/agent/README_zh-CN.md b/agent/README_zh-CN.md
index 3b198b1..a5d41ee 100644
--- a/agent/README_zh-CN.md
+++ b/agent/README_zh-CN.md
@@ -10,13 +10,78 @@ InternLM2-Chat 进一步提高了它在代码解释和通用工具调用方面
 以下是 InternLM2-Chat-20B 在数学代码解释器上的结果。
 
-|                                     | GSM8K | MATH |
-| :---------------------------------: | :---: | :--: |
-| InternLM2-Chat-20B 单纯依靠内在能力 | 79.6  | 32.5 |
-| InternLM2-Chat-20B 配合代码解释器   | 84.5  | 51.2 |
-| ChatGPT (GPT-3.5)                   | 78.2  | 28.0 |
-| GPT-4                               | 91.4  | 45.8 |
+|                                     | GSM8K | MATH  |
+| :---------------------------------: | :---: | :---: |
+| InternLM2-Chat-20B 单纯依靠内在能力 | 79.6  | 32.5  |
+| InternLM2-Chat-20B 配合代码解释器   | 84.5  | 51.2  |
+| ChatGPT (GPT-3.5)                   | 78.2  | 28.0  |
+| GPT-4                               | 91.4  | 45.8  |
 
 ## 体验
 
-我们提供了使用 [Lagent](lagent_zh-CN.md) 来基于 InternLM2-Chat 构建智能体调用代码解释器或者搜索等工具的例子。同时,我们也提供了采用 [PAL 评测 GSM8K 数学题](pal_inference_zh-CN.md) InternLM-Chat-7B 的样例。
+我们提供了使用 [Lagent](lagent_zh-CN.md) 来基于 InternLM2-Chat 构建智能体调用代码解释器的例子。首先安装额外依赖:
+
+```bash
+pip install -r requirements.txt
+```
+
+运行以下脚本在 GSM8K 和 MATH 测试集上进行推理和评估(如需运行 HuggingFace 模型,请将 `--backend` 设为 `hf`):
+
+```bash
+python streaming_inference.py \
+    --backend=lmdeploy \
+    --model_path=internlm/internlm2-chat-20b \
+    --tp=2 \
+    --temperature=0.0 \
+    --dataset=math \
+    --output_path=math_lmdeploy.jsonl \
+    --do_eval
+```
+
+`output_path` 是一个存储推理结果的 jsonl 格式文件,每行形如:
+
+```json
+{
+    "idx": 41,
+    "query": "The point $(a, b)$ lies on the line with the equation $3x + 2y = 12.$ When $a = 4$, what is the value of $b$?",
+    "gt": "0",
+    "pred": ["0"],
+    "steps": [
+        {
+            "role": "language",
+            "content": ""
+        },
+        {
+            "role": "tool",
+            "content": {
+                "name": "IPythonInteractive",
+                "parameters": {
+                    "command": "```python\nfrom sympy import symbols, solve\n\ndef find_b():\n    x, y = symbols('x y')\n    equation = 3*x + 2*y - 12\n    b = solve(equation.subs(x, 4), y)[0]\n\n    return b\n\nresult = find_b()\nprint(result)\n```"
+                }
+            },
+            "name": "interpreter"
+        },
+        {
+            "role": "environment",
+            "content": "0",
+            "name": "interpreter"
+        },
+        {
+            "role": 
"language", + "content": "The value of $b$ when $a = 4$ is $\\boxed{0}$." + } + ], + "error": null +} +``` + +如果已经准备好了该文件,可直接跳过推理阶段进行评估: + +```bash +python streaming_inference.py \ + --output_path=math_lmdeploy.jsonl \ + --no-do_infer \ + --do_eval +``` + +请参考 [`streaming_inference.py`](streaming_inference.py) 获取更多关于参数的信息。 diff --git a/agent/requirements.txt b/agent/requirements.txt new file mode 100644 index 0000000..6dbeb86 --- /dev/null +++ b/agent/requirements.txt @@ -0,0 +1,10 @@ +lmdeploy>=0.2.2 +datasets +tqdm +numpy +pebble +jsonlines +sympy==1.12 +antlr4-python3-runtime==4.11.0 +lagent +einops diff --git a/agent/streaming_inference.py b/agent/streaming_inference.py new file mode 100644 index 0000000..fc36b34 --- /dev/null +++ b/agent/streaming_inference.py @@ -0,0 +1,681 @@ +# flake8: noqa +# isort: skip_file + +# This logic is modified from ToRA: +# - https://github.com/microsoft/ToRA +# +# Copyright (c) Microsoft Corporation. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE + +import argparse +import multiprocessing +import os +import re +import sys +import traceback +from math import isclose +from typing import Union + +import jsonlines +import numpy as np +from datasets import load_dataset +from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer, + Internlm2Agent, Internlm2Protocol, LMDeployPipeline, + get_tool) +from pebble import ProcessPool +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr +from tqdm import tqdm + +DEFAULT_PROMPT = ( + 'Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Just write jupyter code to solve the problem without giving your thought;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. \n') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Math Code Interpreter') + parser.add_argument('--backend', + type=str, + default='lmdeploy', + help='Which inference framework to use.', + choices=['lmdeploy', 'hf']) + parser.add_argument( + '--model_path', + type=str, + default='internlm/internlm2-chat-7b', + help='Path or name to the model, could be HuggingFace model specifier.' + ) + parser.add_argument( + '--output_path', + type=str, + required=True, + help='Path to save inference results to, should be a `jsonl` file') + parser.add_argument('--dataset', + type=str, + default='math', + choices=['gsm8k', 'math'], + help='Dataset for inference') + parser.add_argument( + '--tp', + type=int, + default=1, + help='Number of tensor parallelism. 
It may be required in LMDeploy.')
+    parser.add_argument('--temperature',
+                        type=float,
+                        default=0.1,
+                        help='Temperature in next token prediction')
+    parser.add_argument('--top_p',
+                        type=float,
+                        default=0.8,
+                        help='Parameter for Top-P Sampling.')
+    parser.add_argument('--top_k',
+                        type=int,
+                        default=None,
+                        help='Parameter for Top-K Sampling.')
+    parser.add_argument('--stop_words',
+                        type=str,
+                        default=['<|action_end|>', '<|im_end|>'],
+                        action='append',
+                        help='Stop words')
+    parser.add_argument('--max_new_tokens',
+                        type=int,
+                        default=512,
+                        help='Number of maximum generated tokens.')
+    parser.add_argument(
+        '--do_infer',
+        default=True,
+        action=argparse.BooleanOptionalAction,  # requires Python >= 3.9
+        help='Whether to launch model inference.')
+    # parser.add_argument(
+    #     '--no-do_infer',
+    #     dest='do_infer',
+    #     action='store_false',
+    #     help='Disable the inference.'
+    # )
+    parser.add_argument('--do_eval',
+                        default=False,
+                        action='store_true',
+                        help='Whether to evaluate the inference results.')
+    parser.add_argument('--overwrite',
+                        default=False,
+                        action='store_true',
+                        help='Whether to overwrite the existing result file')
+    parser.add_argument('--debug',
+                        default=False,
+                        action='store_true',
+                        help='Only infer the first 50 samples')
+    return parser.parse_args()
+
+
+def _fix_fracs(string):
+    substrs = string.split('\\frac')
+    new_str = substrs[0]
+    if len(substrs) > 1:
+        substrs = substrs[1:]
+        for substr in substrs:
+            new_str += '\\frac'
+            if len(substr) > 0 and substr[0] == '{':
+                new_str += substr
+            else:
+                try:
+                    assert len(substr) >= 2
+                except Exception:
+                    return string
+                a = substr[0]
+                b = substr[1]
+                if b != '{':
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += '{' + a + '}{' + b + '}' + post_substr
+                    else:
+                        new_str += '{' + a + '}{' + b + '}'
+                else:
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += '{' + a + '}' + b + post_substr
+                    else:
+                        new_str += '{' + a + '}' + b
+    string = new_str
+    return string
+
+
+def 
_fix_a_slash_b(string): + if len(string.split('/')) != 2: + return string + a = string.split('/')[0] + b = string.split('/')[1] + try: + if 'sqrt' not in a: + a = int(a) + if 'sqrt' not in b: + b = int(b) + assert string == '{}/{}'.format(a, b) + new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' + return new_string + except Exception: + return string + + +def _fix_sqrt(string): + _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string) + return _string + + +def strip_string(string): + string = str(string).strip() + # linebreaks + string = string.replace('\n', '') + + # right "." + string = string.rstrip('.') + + # remove inverse spaces + string = string.replace('\\!', '') + string = string.replace('\\ ', '') + + # replace \\ with \ + string = string.replace('\\\\', '\\') + string = string.replace('\\\\', '\\') + + # replace tfrac and dfrac with frac + string = string.replace('tfrac', 'frac') + string = string.replace('dfrac', 'frac') + + # remove \left and \right + string = string.replace('\\left', '') + string = string.replace('\\right', '') + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r'\\text{.*?}$', '', string).strip() + if _string != '' and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + # Remove circ (degrees) + string = string.replace('^{\\circ}', '') + string = string.replace('^\\circ', '') + + # remove dollar signs + string = string.replace('\\$', '') + string = string.replace('$', '') + + string = string.replace('\\text', '') + string = string.replace('x\\in', '') + + # remove percentage + string = string.replace('\\%', '') + string = string.replace('\%', '') + string = string.replace('%', '') + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string
+    string = string.replace(' .', ' 0.')
+    string = string.replace('{.', '{0.')
+
+    # cdot
+    string = string.replace('\\cdot', '')
+
+    # inf
+    string = string.replace('infinity', '\\infty')
+    if '\\infty' not in string:
+        string = string.replace('inf', '\\infty')
+    string = string.replace('+\\infty', '\\infty')
+
+    # and
+    string = string.replace('and', '')
+    string = string.replace('\\mathbf', '')
+
+    # use regex to remove \mbox{...}
+    string = re.sub(r'\\mbox{.*?}', '', string)
+
+    # quote
+    string = string.replace("'", '')
+    string = string.replace('"', '')
+
+    # i, j
+    if 'j' in string and 'i' not in string:
+        string = string.replace('j', 'i')
+
+    # replace a.000b where b is not number or b is end, with ab, use regex
+    string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
+    string = re.sub(r'(\d+)\.0+$', r'\1', string)
+
+    # if empty, return empty string
+    if len(string) == 0:
+        return string
+    if string[0] == '.':
+        string = '0' + string
+
+    # to consider: get rid of e.g. "k = " or "q = " at beginning
+    if len(string.split('=')) == 2:
+        if len(string.split('=')[0]) <= 2:
+            string = string.split('=')[1]
+
+    string = _fix_sqrt(string)
+    string = string.replace(' ', '')
+
+    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def last_boxed_only_string(string): + idx = string.rfind('\\boxed') + if idx < 0: + idx = string.rfind('\\fbox') + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == '{': + num_left_braces_open += 1 + if string[i] == '}': + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx:right_brace_idx + 1] + + return retval + + +def extract_answer(pred_str): + if 'boxed' not in pred_str: + return '' + answer = pred_str.split('boxed')[-1] + if len(answer) == 0: + return '' + elif (answer[0] == '{'): + stack = 1 + a = '' + for c in answer[1:]: + if (c == '{'): + stack += 1 + a += c + elif (c == '}'): + stack -= 1 + if (stack == 0): break + a += c + else: + a += c + else: + a = answer.split('$')[0].strip() + + pred = a.split('\n')[0] + if pred != '' and pred[0] == ':': + pred = pred[1:] + if pred != '' and pred[-1] == '.': + pred = pred[:-1] + if pred != '' and pred[-1] == '/': + pred = pred[:-1] + pred = strip_string(pred) + return pred + + +def is_digit(s): + try: + float(str(s).replace(',', '')) + return True + except ValueError: + return False + + +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + is_close: bool = True, + tolerance: float = 1e-4, + timeout: bool = False, +) -> bool: + """Exact match of math if and only if: + + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + try: # 1. 
numerical equal + if is_digit(prediction) and is_digit(reference): + prediction = float(str(prediction).replace(',', '')) + reference = float(str(reference).replace(',', '')) + # number questions + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if is_close: + if isclose(item, prediction, rel_tol=tolerance): + return True + else: + if item == prediction: + return True + except Exception: + continue + return False + except Exception: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + ## deal with [], (), {} + pred_str, ref_str = prediction, reference + if (prediction.startswith('[') and prediction.endswith(']') + and not reference.startswith('(')) or ( + prediction.startswith('(') and prediction.endswith(')') + and not reference.startswith('[')): + pred_str = pred_str.strip('[]()') + ref_str = ref_str.strip('[]()') + for s in ['{', '}', '(', ')']: + ref_str = ref_str.replace(s, '') + pred_str = pred_str.replace(s, '') + if pred_str == ref_str: + return True + + ## [a, b] vs. 
[c, d], return a==c and b==d + if ((prediction.startswith('[') and prediction.endswith(']')) and + (reference.startswith('[') and reference.endswith(']')) + or (prediction.startswith('(') and prediction.endswith(')')) and + (reference.startswith('(') and reference.endswith(')'))): + pred_parts = prediction[1:-1].split(',') + ref_parts = reference[1:-1].split(',') + if len(pred_parts) == len(ref_parts): + if all([ + math_equal(pred_parts[i], ref_parts[i], include_percentage, + is_close) for i in range(len(pred_parts)) + ]): + return True + + # symbolic equal with sympy + if timeout: + if call_with_timeout(symbolic_equal_process, prediction, reference): + return True + else: + if symbolic_equal(prediction, reference): + return True + + return False + + +def math_equal_process(param): + return math_equal(param[-2], param[-1]) + + +def symbolic_equal(a, b): + + def _parse(s): + for f in [parse_latex, parse_expr]: + try: + return f(s) + except Exception: + pass + return s + + a = _parse(a) + b = _parse(b) + + try: + if simplify(a - b) == 0: + return True + except Exception: + pass + + try: + if isclose(N(a), N(b), rel_tol=1e-3): + return True + except Exception: + pass + return False + + +def symbolic_equal_process(a, b, output_queue): + result = symbolic_equal(a, b) + output_queue.put(result) + + +def call_with_timeout(func, *args, timeout=1, **kwargs): + output_queue = multiprocessing.Queue() + process_args = args + (output_queue, ) + process = multiprocessing.Process(target=func, + args=process_args, + kwargs=kwargs) + process.start() + process.join(timeout) + + if process.is_alive(): + process.terminate() + process.join() + return False + + return output_queue.get() + + +def init_agent(backend: str, model_path: str, tp: int, **kwargs): + if backend == 'lmdeploy': + model = LMDeployPipeline(path=model_path, + meta_template=INTERNLM2_META, + tp=tp, + **kwargs) + elif backend == 'hf': + model = HFTransformer(path=model_path, + meta_template=INTERNLM2_META, + **kwargs) 
+ else: + raise NotImplementedError + + agent = Internlm2Agent(llm=model, + protocol=Internlm2Protocol( + meta_prompt=None, + interpreter_prompt=DEFAULT_PROMPT), + interpreter_executor=ActionExecutor( + actions=[get_tool('IPythonInteractive')])) + return agent + + +def predict(args): + if args.dataset == 'gsm8k': + + def process(d, k): + d['answer'] = re.sub(r'#### (.+)', r'The answer is \1', + re.sub(r'<<.*?>>', '', + d['answer'])).replace('$', '') + d['idx'] = k + d['query'] = d['question'].replace('$', '') + d['gt'] = re.search('The answer is (.+)', d['answer'])[1] + d['pred'], d['steps'], d['error'] = [], [], None + return d + + dataset = load_dataset('gsm8k', 'main', + split='test').map(process, True) + + elif args.dataset == 'math': + + def process(d, k): + d['idx'] = k + d['query'] = d['problem'] + gt = extract_answer(d['solution']) + if '\\boxed{90\\text{ square\nunits}}' in d['solution']: + gt = '90' + elif '$6$ is our answer' in d['solution']: + gt = '6' + elif gt.startswith('x\\in'): + gt = gt[len('x\\in'):] + gt = strip_string(gt) + d['gt'] = gt + d['pred'], d['steps'] = [], [] + d['error'] = None + return d + + dataset = load_dataset('lighteval/MATH', + split='test').map(process, True) + + else: + raise NotImplementedError + + agent = init_agent( + backend=args.backend, + model_path=args.model_path, + tp=args.tp, + temperature=args.temperature, + stop_words=args.stop_words, + top_p=args.top_p, + top_k=args.top_k, + max_new_tokens=args.max_new_tokens, + ) + with jsonlines.open(args.output_path, 'w') as f: + for item in tqdm( + dataset if not args.debug else dataset.select(range(50))): + try: + ret = agent.chat(item['query']) + item['steps'] = ret.inner_steps + + lang = [ + step for step in item['steps'] + if step['role'] == 'language' + ] + item['pred'].append('😭' if not lang else + extract_answer(lang[-1]['content']) or '😭') + agent._interpreter_executor.actions[ + 'IPythonInteractive'].reset() + except Exception as e: + err = 
str(traceback.format_exc()) + print(f'Error processing index {item["idx"]}: {e}\n{err}') + item['error'] = err + f.write(item) + + +def evaluate(args): + samples = [sample for sample in jsonlines.open(args.output_path)] + scores = [] + timeout_cnt = 0 + with ProcessPool() as pool: + future = pool.map( + math_equal_process, + [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) + for pred in sample['pred']], + timeout=20, + ) + iterator = future.result() + with tqdm(total=len(samples), desc='Evaluate') as progress_bar: + while True: + try: + result = next(iterator) + scores.append(result) + except StopIteration: + break + except TimeoutError as error: + print(error) + scores.append(False) + timeout_cnt += 1 + except Exception as error: + print(error.__traceback__) + sys.exit() + progress_bar.update(1) + + idx = 0 + score_mat = [] + for sample in samples: + sample['score'] = scores[idx:idx + len(sample['pred'])] + assert len(sample['score']) == len(sample['pred']) + score_mat.append(sample['score']) + idx += len(sample['pred']) + + max_len = max([len(s) for s in score_mat]) + + for i, s in enumerate(score_mat): + if len(s) < max_len: + score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad + + # output mean of each column of scores + col_means = np.array(score_mat).mean(axis=0) + mean_score = list(np.round(col_means * 100, decimals=1)) + + result_str = f'Num samples: {len(samples)}\n' \ + f'Num scores: {len(scores)}\n' \ + f'Sum scores: {sum(scores)}\n' \ + f'Timeout samples: {timeout_cnt}\n' \ + f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \ + f'Mean score: {mean_score}\n' + + # each type score + if 'type' in samples[0]: + type_scores = {} + for sample in samples: + if sample['type'] not in type_scores: + type_scores[sample['type']] = [] + type_scores[sample['type']].append(sample['score'][-1]) + type_scores = { + k: np.round(np.array(v).mean() * 100, decimals=1) + for k, v in type_scores.items() + } + type_scores = { + k: v + for 
k, v in sorted(type_scores.items(), key=lambda item: item[0]) + } + result_str += f'Type scores: {type_scores}\n' + + print(result_str) + + +if __name__ == '__main__': + args = parse_args() + if args.do_infer and os.path.exists( + args.output_path) and not args.overwrite: + args.do_infer = False + print(f'File {args.output_path} already exists. ' + f'Please add the `--overwrite` flag if needed.') + if args.do_infer: + predict(args) + if args.do_eval: + if not args.do_infer: + evaluate(args) + else: + import subprocess + + res = subprocess.run( + [ + sys.executable, __file__, '--output_path', + args.output_path, '--no-do_infer', '--do_eval' + ], + capture_output=True, + text=True, + check=True, + ) + print(res.stdout)
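The numeric branch of `math_equal` above can be sanity-checked in isolation. The following is a simplified, hypothetical re-implementation of just that branch (comma stripping, the percentage candidates `reference/100` and `reference*100`, and `math.isclose` with a relative tolerance), not the script's full grading logic:

```python
from math import isclose


def numeric_equal(prediction: str, reference: str, tolerance: float = 1e-4) -> bool:
    """Simplified numeric comparison mirroring math_equal's first branch."""
    try:
        # Strip thousands separators before parsing, as is_digit does.
        pred = float(prediction.replace(',', ''))
        ref = float(reference.replace(',', ''))
    except ValueError:
        return False
    # Accept the reference itself or its percentage forms within tolerance.
    return any(
        isclose(candidate, pred, rel_tol=tolerance)
        for candidate in (ref / 100, ref, ref * 100))


print(numeric_equal('1,234', '1234'))   # True
print(numeric_equal('0.5', '50'))       # True (percentage form)
print(numeric_equal('3.1415', '2.71'))  # False
```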