mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			693 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			693 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
| # flake8: noqa
 | |
| # isort: skip_file
 | |
| 
 | |
| # This logic is modified from ToRA:
 | |
| #   - https://github.com/microsoft/ToRA
 | |
| #
 | |
| # Copyright (c) Microsoft Corporation.
 | |
| #
 | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy
 | |
| # of this software and associated documentation files (the "Software"), to deal
 | |
| # in the Software without restriction, including without limitation the rights
 | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | |
| # copies of the Software, and to permit persons to whom the Software is
 | |
| # furnished to do so, subject to the following conditions:
 | |
| #
 | |
| # The above copyright notice and this permission notice shall be included in all
 | |
| # copies or substantial portions of the Software.
 | |
| #
 | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | |
| # SOFTWARE
 | |
| 
 | |
| import argparse
 | |
| import multiprocessing
 | |
| import os
 | |
| import re
 | |
| import sys
 | |
| import traceback
 | |
| from math import isclose, ceil
 | |
| from typing import Union
 | |
| 
 | |
| import jsonlines
 | |
| import numpy as np
 | |
| from datasets import load_dataset
 | |
| from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
 | |
|                     Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
 | |
|                     IPythonInteractiveManager)
 | |
| from pebble import ProcessPool
 | |
| from sympy import N, simplify
 | |
| from sympy.parsing.latex import parse_latex
 | |
| from sympy.parsing.sympy_parser import parse_expr
 | |
| from tqdm import tqdm
 | |
| 
 | |
| # --------------------- modify the system prompt as needed ---------------------
 | |
| DEFAULT_PROMPT = (
 | |
|     'Integrate step-by-step reasoning and Python code to solve math problems '
 | |
|     'using the following guidelines:\n'
 | |
|     '- Analyze the question and write jupyter code to solve the problem;\n'
 | |
|     r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
 | |
|     'units. \n')
 | |
| # ------------------------------------------------------------------------------
 | |
| 
 | |
| 
 | |
| def parse_args():
 | |
|     parser = argparse.ArgumentParser(description='Math Code Interpreter')
 | |
|     parser.add_argument('--backend',
 | |
|                         type=str,
 | |
|                         default='lmdeploy',
 | |
|                         help='Which inference framework to use.',
 | |
|                         choices=['lmdeploy', 'hf'])
 | |
|     parser.add_argument(
 | |
|         '--model_path',
 | |
|         type=str,
 | |
|         default='internlm/internlm2-chat-7b',
 | |
|         help='Path or name to the model, could be HuggingFace model specifier.'
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         '--output_path',
 | |
|         type=str,
 | |
|         required=True,
 | |
|         help='Path to save inference results to, should be a `jsonl` file')
 | |
|     parser.add_argument('--batch_size',
 | |
|                         type=int,
 | |
|                         default=100,
 | |
|                         help='Agent inference batch size')
 | |
|     parser.add_argument(
 | |
|         '--max_turn',
 | |
|         type=int,
 | |
|         default=5,
 | |
|         help=
 | |
|         'Maximum number of interaction rounds between the agent and environment'
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         '--tp',
 | |
|         type=int,
 | |
|         default=1,
 | |
|         help='Number of tensor parallelism. It may be required in LMDelpoy.')
 | |
|     parser.add_argument('--temperature',
 | |
|                         type=float,
 | |
|                         default=0.1,
 | |
|                         help='Temperature in next token prediction')
 | |
|     parser.add_argument('--top_p',
 | |
|                         type=float,
 | |
|                         default=0.8,
 | |
|                         help='Parameter for Top-P Sampling.')
 | |
|     parser.add_argument('--top_k',
 | |
|                         type=int,
 | |
|                         default=40,
 | |
|                         help='Parameter for Top-K Sampling.')
 | |
|     parser.add_argument('--stop_words',
 | |
|                         type=str,
 | |
|                         default=['<|action_end|>', '<|im_end|>'],
 | |
|                         action='append',
 | |
|                         help='Stop words')
 | |
|     parser.add_argument('--max_new_tokens',
 | |
|                         type=int,
 | |
|                         default=512,
 | |
|                         help='Number of maximum generated tokens.')
 | |
|     parser.add_argument(
 | |
|         '--do_infer',
 | |
|         default=True,
 | |
|         action=argparse.BooleanOptionalAction,  # python > 3.8
 | |
|         help='Whether to launch model inference.')
 | |
|     # parser.add_argument(
 | |
|     #     '--no-do_infer',
 | |
|     #     dest='do_infer',
 | |
|     #     action='store_false',
 | |
|     #     help='Disable the inference.'
 | |
|     # )
 | |
|     parser.add_argument('--do_eval',
 | |
|                         default=False,
 | |
|                         action='store_true',
 | |
|                         help='Whether to evaluate the inference results.')
 | |
|     parser.add_argument('--overwrite',
 | |
|                         default=False,
 | |
|                         action='store_true',
 | |
|                         help='Whether to overwrite the existing result file')
 | |
|     return parser.parse_args()
 | |
| 
 | |
| 
 | |
| def _fix_fracs(string):
 | |
|     substrs = string.split('\\frac')
 | |
|     new_str = substrs[0]
 | |
|     if len(substrs) > 1:
 | |
|         substrs = substrs[1:]
 | |
|         for substr in substrs:
 | |
|             new_str += '\\frac'
 | |
|             if len(substr) > 0 and substr[0] == '{':
 | |
|                 new_str += substr
 | |
|             else:
 | |
|                 try:
 | |
|                     assert len(substr) >= 2
 | |
|                 except Exception:
 | |
|                     return string
 | |
|                 a = substr[0]
 | |
|                 b = substr[1]
 | |
|                 if b != '{':
 | |
|                     if len(substr) > 2:
 | |
|                         post_substr = substr[2:]
 | |
|                         new_str += '{' + a + '}{' + b + '}' + post_substr
 | |
|                     else:
 | |
|                         new_str += '{' + a + '}{' + b + '}'
 | |
|                 else:
 | |
|                     if len(substr) > 2:
 | |
|                         post_substr = substr[2:]
 | |
|                         new_str += '{' + a + '}' + b + post_substr
 | |
|                     else:
 | |
|                         new_str += '{' + a + '}' + b
 | |
|     string = new_str
 | |
|     return string
 | |
| 
 | |
| 
 | |
| def _fix_a_slash_b(string):
 | |
|     if len(string.split('/')) != 2:
 | |
|         return string
 | |
|     a = string.split('/')[0]
 | |
|     b = string.split('/')[1]
 | |
|     try:
 | |
|         if 'sqrt' not in a:
 | |
|             a = int(a)
 | |
|         if 'sqrt' not in b:
 | |
|             b = int(b)
 | |
|         assert string == '{}/{}'.format(a, b)
 | |
|         new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
 | |
|         return new_string
 | |
|     except Exception:
 | |
|         return string
 | |
| 
 | |
| 
 | |
| def _fix_sqrt(string):
 | |
|     _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
 | |
|     return _string
 | |
| 
 | |
| 
 | |
| def strip_string(string):
 | |
|     string = str(string).strip()
 | |
|     # linebreaks
 | |
|     string = string.replace('\n', '')
 | |
| 
 | |
|     # right "."
 | |
|     string = string.rstrip('.')
 | |
| 
 | |
|     # remove inverse spaces
 | |
|     string = string.replace('\\!', '')
 | |
|     string = string.replace('\\ ', '')
 | |
| 
 | |
|     # replace \\ with \
 | |
|     string = string.replace('\\\\', '\\')
 | |
|     string = string.replace('\\\\', '\\')
 | |
| 
 | |
|     # replace tfrac and dfrac with frac
 | |
|     string = string.replace('tfrac', 'frac')
 | |
|     string = string.replace('dfrac', 'frac')
 | |
| 
 | |
|     # remove \left and \right
 | |
|     string = string.replace('\\left', '')
 | |
|     string = string.replace('\\right', '')
 | |
| 
 | |
|     # Remove unit: miles, dollars if after is not none
 | |
|     _string = re.sub(r'\\text{.*?}$', '', string).strip()
 | |
|     if _string != '' and _string != string:
 | |
|         # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
 | |
|         string = _string
 | |
| 
 | |
|     # Remove circ (degrees)
 | |
|     string = string.replace('^{\\circ}', '')
 | |
|     string = string.replace('^\\circ', '')
 | |
| 
 | |
|     # remove dollar signs
 | |
|     string = string.replace('\\$', '')
 | |
|     string = string.replace('$', '')
 | |
| 
 | |
|     string = string.replace('\\text', '')
 | |
|     string = string.replace('x\\in', '')
 | |
| 
 | |
|     # remove percentage
 | |
|     string = string.replace('\\%', '')
 | |
|     string = string.replace('\%', '')
 | |
|     string = string.replace('%', '')
 | |
| 
 | |
|     # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
 | |
|     string = string.replace(' .', ' 0.')
 | |
|     string = string.replace('{.', '{0.')
 | |
| 
 | |
|     # cdot
 | |
|     string = string.replace('\\cdot', '')
 | |
| 
 | |
|     # inf
 | |
|     string = string.replace('infinity', '\\infty')
 | |
|     if '\\infty' not in string:
 | |
|         string = string.replace('inf', '\\infty')
 | |
|     string = string.replace('+\\inity', '\\infty')
 | |
| 
 | |
|     # and
 | |
|     string = string.replace('and', '')
 | |
|     string = string.replace('\\mathbf', '')
 | |
| 
 | |
|     # use regex to remove \mbox{...}
 | |
|     string = re.sub(r'\\mbox{.*?}', '', string)
 | |
| 
 | |
|     # quote
 | |
|     string.replace("'", '')
 | |
|     string.replace('"', '')
 | |
| 
 | |
|     # i, j
 | |
|     if 'j' in string and 'i' not in string:
 | |
|         string = string.replace('j', 'i')
 | |
| 
 | |
|     # replace a.000b where b is not number or b is end, with ab, use regex
 | |
|     string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
 | |
|     string = re.sub(r'(\d+)\.0+$', r'\1', string)
 | |
| 
 | |
|     # if empty, return empty string
 | |
|     if len(string) == 0:
 | |
|         return string
 | |
|     if string[0] == '.':
 | |
|         string = '0' + string
 | |
| 
 | |
|     # to consider: get rid of e.g. "k = " or "q = " at beginning
 | |
|     if len(string.split('=')) == 2:
 | |
|         if len(string.split('=')[0]) <= 2:
 | |
|             string = string.split('=')[1]
 | |
| 
 | |
|     string = _fix_sqrt(string)
 | |
|     string = string.replace(' ', '')
 | |
| 
 | |
|     # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
 | |
|     string = _fix_fracs(string)
 | |
| 
 | |
|     # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
 | |
|     string = _fix_a_slash_b(string)
 | |
| 
 | |
|     return string
 | |
| 
 | |
| 
 | |
| def last_boxed_only_string(string):
 | |
|     idx = string.rfind('\\boxed')
 | |
|     if idx < 0:
 | |
|         idx = string.rfind('\\fbox')
 | |
|         if idx < 0:
 | |
|             return None
 | |
| 
 | |
|     i = idx
 | |
|     right_brace_idx = None
 | |
|     num_left_braces_open = 0
 | |
|     while i < len(string):
 | |
|         if string[i] == '{':
 | |
|             num_left_braces_open += 1
 | |
|         if string[i] == '}':
 | |
|             num_left_braces_open -= 1
 | |
|             if num_left_braces_open == 0:
 | |
|                 right_brace_idx = i
 | |
|                 break
 | |
|         i += 1
 | |
| 
 | |
|     if right_brace_idx is None:
 | |
|         retval = None
 | |
|     else:
 | |
|         retval = string[idx:right_brace_idx + 1]
 | |
| 
 | |
|     return retval
 | |
| 
 | |
| 
 | |
| def extract_answer(pred_str: str, execute: bool = False) -> str:
 | |
|     if re.search('\boxed|boxed', pred_str):
 | |
|         answer = re.split('\boxed|boxed', pred_str)[-1]
 | |
|         if len(answer) == 0:
 | |
|             return ''
 | |
|         elif (answer[0] == '{'):
 | |
|             stack = 1
 | |
|             a = ''
 | |
|             for c in answer[1:]:
 | |
|                 if (c == '{'):
 | |
|                     stack += 1
 | |
|                     a += c
 | |
|                 elif (c == '}'):
 | |
|                     stack -= 1
 | |
|                     if (stack == 0): break
 | |
|                     a += c
 | |
|                 else:
 | |
|                     a += c
 | |
|         else:
 | |
|             a = answer.split('$')[0].strip()
 | |
|     elif re.search('[Tt]he (final )?answer is:?', pred_str):
 | |
|         a = re.split('[Tt]he (final )?answer is:?',
 | |
|                      pred_str)[-1].strip().rstrip('.')
 | |
|     elif pred_str.startswith('```python') and execute:
 | |
|         # fall back to program
 | |
|         from lagent import get_tool
 | |
| 
 | |
|         a = get_tool('IPythonInteractive').exec(pred_str).value or ''
 | |
|     else:  # use the last number
 | |
|         pred = re.findall(r'-?\d*\.?\d+', pred_str.replace(',', ''))
 | |
|         if len(pred) >= 1:
 | |
|             a = pred[-1]
 | |
|         else:
 | |
|             a = ''
 | |
|     # multiple lines
 | |
|     pred = a.split('\n')[0]
 | |
|     if pred != '' and pred[0] == ':':
 | |
|         pred = pred[1:]
 | |
|     if pred != '' and pred[-1] == '.':
 | |
|         pred = pred[:-1]
 | |
|     if pred != '' and pred[-1] == '/':
 | |
|         pred = pred[:-1]
 | |
|     pred = strip_string(pred)
 | |
|     return pred
 | |
| 
 | |
| 
 | |
| def is_digit(s):
 | |
|     try:
 | |
|         float(str(s).replace(',', ''))
 | |
|         return True
 | |
|     except ValueError:
 | |
|         return False
 | |
| 
 | |
| 
 | |
| def math_equal(
 | |
|     prediction: Union[bool, float, str],
 | |
|     reference: Union[float, str],
 | |
|     include_percentage: bool = True,
 | |
|     is_close: bool = True,
 | |
|     tolerance: float = 1e-4,
 | |
|     timeout: bool = False,
 | |
| ) -> bool:
 | |
|     """Exact match of math if and only if:
 | |
| 
 | |
|     1. numerical equal: both can convert to float and are equal
 | |
|     2. symbolic equal: both can convert to sympy expression and are equal
 | |
|     """
 | |
|     try:  # 1. numerical equal
 | |
|         if is_digit(prediction) and is_digit(reference):
 | |
|             prediction = float(str(prediction).replace(',', ''))
 | |
|             reference = float(str(reference).replace(',', ''))
 | |
|             # number questions
 | |
|             if include_percentage:
 | |
|                 gt_result = [reference / 100, reference, reference * 100]
 | |
|             else:
 | |
|                 gt_result = [reference]
 | |
|             for item in gt_result:
 | |
|                 try:
 | |
|                     if is_close:
 | |
|                         if isclose(item, prediction, rel_tol=tolerance):
 | |
|                             return True
 | |
|                     else:
 | |
|                         if item == prediction:
 | |
|                             return True
 | |
|                 except Exception:
 | |
|                     continue
 | |
|             return False
 | |
|     except Exception:
 | |
|         pass
 | |
| 
 | |
|     if not prediction and prediction not in [0, False]:
 | |
|         return False
 | |
| 
 | |
|     # 2. symbolic equal
 | |
|     reference = str(reference).strip()
 | |
|     prediction = str(prediction).strip()
 | |
| 
 | |
|     ## deal with [], (), {}
 | |
|     pred_str, ref_str = prediction, reference
 | |
|     if (prediction.startswith('[') and prediction.endswith(']')
 | |
|             and not reference.startswith('(')) or (
 | |
|                 prediction.startswith('(') and prediction.endswith(')')
 | |
|                 and not reference.startswith('[')):
 | |
|         pred_str = pred_str.strip('[]()')
 | |
|         ref_str = ref_str.strip('[]()')
 | |
|     for s in ['{', '}', '(', ')']:
 | |
|         ref_str = ref_str.replace(s, '')
 | |
|         pred_str = pred_str.replace(s, '')
 | |
|     if pred_str == ref_str:
 | |
|         return True
 | |
| 
 | |
|     ## [a, b] vs. [c, d], return a==c and b==d
 | |
|     if ((prediction.startswith('[') and prediction.endswith(']')) and
 | |
|         (reference.startswith('[') and reference.endswith(']'))
 | |
|             or (prediction.startswith('(') and prediction.endswith(')')) and
 | |
|         (reference.startswith('(') and reference.endswith(')'))):
 | |
|         pred_parts = prediction[1:-1].split(',')
 | |
|         ref_parts = reference[1:-1].split(',')
 | |
|         if len(pred_parts) == len(ref_parts):
 | |
|             if all([
 | |
|                     math_equal(pred_parts[i], ref_parts[i], include_percentage,
 | |
|                                is_close) for i in range(len(pred_parts))
 | |
|             ]):
 | |
|                 return True
 | |
| 
 | |
|     # symbolic equal with sympy
 | |
|     if timeout:
 | |
|         if call_with_timeout(symbolic_equal_process, prediction, reference):
 | |
|             return True
 | |
|     else:
 | |
|         if symbolic_equal(prediction, reference):
 | |
|             return True
 | |
| 
 | |
|     return False
 | |
| 
 | |
| 
 | |
| def math_equal_process(param):
 | |
|     return math_equal(param[-2], param[-1])
 | |
| 
 | |
| 
 | |
| def symbolic_equal(a, b):
 | |
| 
 | |
|     def _parse(s):
 | |
|         for f in [parse_latex, parse_expr]:
 | |
|             try:
 | |
|                 return f(s)
 | |
|             except Exception:
 | |
|                 pass
 | |
|         return s
 | |
| 
 | |
|     a = _parse(a)
 | |
|     b = _parse(b)
 | |
| 
 | |
|     try:
 | |
|         if simplify(a - b) == 0:
 | |
|             return True
 | |
|     except Exception:
 | |
|         pass
 | |
| 
 | |
|     try:
 | |
|         if isclose(N(a), N(b), rel_tol=1e-3):
 | |
|             return True
 | |
|     except Exception:
 | |
|         pass
 | |
|     return False
 | |
| 
 | |
| 
 | |
| def symbolic_equal_process(a, b, output_queue):
 | |
|     result = symbolic_equal(a, b)
 | |
|     output_queue.put(result)
 | |
| 
 | |
| 
 | |
| def call_with_timeout(func, *args, timeout=1, **kwargs):
 | |
|     output_queue = multiprocessing.Queue()
 | |
|     process_args = args + (output_queue, )
 | |
|     process = multiprocessing.Process(target=func,
 | |
|                                       args=process_args,
 | |
|                                       kwargs=kwargs)
 | |
|     process.start()
 | |
|     process.join(timeout)
 | |
| 
 | |
|     if process.is_alive():
 | |
|         process.terminate()
 | |
|         process.join()
 | |
|         return False
 | |
| 
 | |
|     return output_queue.get()
 | |
| 
 | |
| 
 | |
| def init_agent(backend: str, max_turn: int, model_path: str, tp: int,
 | |
|                **kwargs):
 | |
|     if backend == 'lmdeploy':
 | |
|         from lmdeploy import TurbomindEngineConfig
 | |
|         model = LMDeployPipeline(
 | |
|             path=model_path,
 | |
|             model_name='internlm2-chat',
 | |
|             meta_template=INTERNLM2_META,
 | |
|             pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)),
 | |
|             **kwargs)
 | |
|     elif backend == 'hf':
 | |
|         model = HFTransformer(path=model_path,
 | |
|                               meta_template=INTERNLM2_META,
 | |
|                               **kwargs)
 | |
|     else:
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     agent = Internlm2Agent(
 | |
|         llm=model,
 | |
|         protocol=Internlm2Protocol(meta_prompt=None,
 | |
|                                    interpreter_prompt=DEFAULT_PROMPT),
 | |
|         interpreter_executor=ActionExecutor(actions=[
 | |
|             IPythonInteractiveManager(max_workers=200,
 | |
|                                       ci_lock=os.path.join(
 | |
|                                           os.path.dirname(__file__),
 | |
|                                           '.ipython.lock'))
 | |
|         ]),
 | |
|         max_turn=max_turn)
 | |
|     return agent
 | |
| 
 | |
| 
 | |
| def predict(args):
 | |
| 
 | |
|     def process(d, k):
 | |
|         d['idx'] = k
 | |
|         d['query'] = d['problem']
 | |
|         gt = extract_answer(d['solution'])
 | |
|         if '\\boxed{90\\text{ square\nunits}}' in d['solution']:
 | |
|             gt = '90'
 | |
|         elif '$6$ is our answer' in d['solution']:
 | |
|             gt = '6'
 | |
|         elif gt.startswith('x\\in'):
 | |
|             gt = gt[len('x\\in'):]
 | |
|         gt = strip_string(gt)
 | |
|         d['gt'] = gt
 | |
|         d['pred'], d['steps'] = [], []
 | |
|         d['error'] = None
 | |
|         return d
 | |
| 
 | |
|     dataset = load_dataset('lighteval/MATH', split='test').map(process, True)
 | |
|     agent = init_agent(
 | |
|         backend=args.backend,
 | |
|         max_turn=args.max_turn,
 | |
|         model_path=args.model_path,
 | |
|         tp=args.tp,
 | |
|         temperature=args.temperature,
 | |
|         stop_words=args.stop_words,
 | |
|         top_p=args.top_p,
 | |
|         top_k=args.top_k,
 | |
|         max_new_tokens=args.max_new_tokens,
 | |
|     )
 | |
|     num_batches = ceil(len(dataset) / args.batch_size)
 | |
|     with jsonlines.open(args.output_path, 'w', flush=True) as f:
 | |
|         for i in tqdm(range(num_batches)):
 | |
|             batch = dataset.select(
 | |
|                 range(i * args.batch_size,
 | |
|                       min((i + 1) * args.batch_size, len(dataset))))
 | |
|             try:
 | |
|                 rets = agent.batch_chat(batch['query'])
 | |
|                 for item, ret in zip(batch, rets):
 | |
|                     item['steps'] = ret.inner_steps
 | |
|                     last = item['steps'][-1]
 | |
|                     item['pred'].append(
 | |
|                         extract_answer(last['content']) if last['role'] ==
 | |
|                         'language' else '😭')
 | |
|                     f.write(item)
 | |
|             except Exception as e:
 | |
|                 err = str(traceback.format_exc())
 | |
|                 print(f'Processing batch data error: {e}\n{err}')
 | |
|                 for item in batch:
 | |
|                     item['error'] = err
 | |
|                     f.write(item)
 | |
|             finally:
 | |
|                 agent._interpreter_executor.actions[
 | |
|                     'IPythonInteractiveManager'].reset()
 | |
| 
 | |
| 
 | |
| def evaluate(args):
 | |
|     samples = [sample for sample in jsonlines.open(args.output_path)]
 | |
|     scores = []
 | |
|     timeout_cnt = 0
 | |
|     with ProcessPool() as pool:
 | |
|         future = pool.map(
 | |
|             math_equal_process,
 | |
|             [(idx, pred, sample['gt']) for idx, sample in enumerate(samples)
 | |
|              for pred in sample['pred']],
 | |
|             timeout=20,
 | |
|         )
 | |
|         iterator = future.result()
 | |
|         with tqdm(total=len(samples), desc='Evaluate') as progress_bar:
 | |
|             while True:
 | |
|                 try:
 | |
|                     result = next(iterator)
 | |
|                     scores.append(result)
 | |
|                 except StopIteration:
 | |
|                     break
 | |
|                 except TimeoutError as error:
 | |
|                     print(error)
 | |
|                     scores.append(False)
 | |
|                     timeout_cnt += 1
 | |
|                 except Exception as error:
 | |
|                     print(error.__traceback__)
 | |
|                     scores.append(False)
 | |
|                     # sys.exit()
 | |
|                 progress_bar.update(1)
 | |
| 
 | |
|     idx = 0
 | |
|     score_mat = []
 | |
|     for sample in samples:
 | |
|         sample['score'] = scores[idx:idx + len(sample['pred'])]
 | |
|         assert len(sample['score']) == len(sample['pred'])
 | |
|         score_mat.append(sample['score'])
 | |
|         idx += len(sample['pred'])
 | |
| 
 | |
|     max_len = max([len(s) for s in score_mat])
 | |
| 
 | |
|     for i, s in enumerate(score_mat):
 | |
|         if len(s) < max_len:
 | |
|             score_mat[i] = s + [s[-1]] * (max_len - len(s))  # pad
 | |
| 
 | |
|     # output mean of each column of scores
 | |
|     col_means = np.array(score_mat).mean(axis=0)
 | |
|     mean_score = list(np.round(col_means * 100, decimals=1))
 | |
| 
 | |
|     result_str = f'Num samples: {len(samples)}\n' \
 | |
|         f'Num scores: {len(scores)}\n' \
 | |
|         f'Sum scores: {sum(scores)}\n' \
 | |
|         f'Timeout samples: {timeout_cnt}\n' \
 | |
|         f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
 | |
|         f'Mean score: {mean_score}\n'
 | |
| 
 | |
|     # each type score
 | |
|     if 'type' in samples[0]:
 | |
|         type_scores = {}
 | |
|         for sample in samples:
 | |
|             if sample['type'] not in type_scores:
 | |
|                 type_scores[sample['type']] = []
 | |
|             type_scores[sample['type']].append(sample['score'][-1])
 | |
|         type_scores = {
 | |
|             k: np.round(np.array(v).mean() * 100, decimals=1)
 | |
|             for k, v in type_scores.items()
 | |
|         }
 | |
|         type_scores = {
 | |
|             k: v
 | |
|             for k, v in sorted(type_scores.items(), key=lambda item: item[0])
 | |
|         }
 | |
|         result_str += f'Type scores: {type_scores}\n'
 | |
| 
 | |
|     print(result_str)
 | |
| 
 | |
| 
 | |
| if __name__ == '__main__':
 | |
|     args = parse_args()
 | |
|     if args.do_infer and os.path.exists(
 | |
|             args.output_path) and not args.overwrite:
 | |
|         args.do_infer = False
 | |
|         print(f'File {args.output_path} already exists. '
 | |
|               f'Please add the `--overwrite` flag if needed.')
 | |
|     if args.do_infer:
 | |
|         predict(args)
 | |
|     if args.do_eval:
 | |
|         if not args.do_infer:
 | |
|             evaluate(args)
 | |
|         else:
 | |
|             import subprocess
 | |
| 
 | |
|             res = subprocess.run(
 | |
|                 [
 | |
|                     sys.executable, __file__, '--output_path',
 | |
|                     args.output_path, '--no-do_infer', '--do_eval'
 | |
|                 ],
 | |
|                 capture_output=True,
 | |
|                 text=True,
 | |
|                 check=True,
 | |
|             )
 | |
|             print(res.stdout)
 |