# flake8: noqa # isort: skip_file # This logic is modified from ToRA: # - https://github.com/microsoft/ToRA # # Copyright (c) Microsoft Corporation. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE import argparse import multiprocessing import os import re import sys import traceback from math import isclose, ceil from typing import Union import jsonlines import numpy as np from datasets import load_dataset from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer, Internlm2Agent, Internlm2Protocol, LMDeployPipeline, IPythonInteractiveManager) from pebble import ProcessPool from sympy import N, simplify from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr from tqdm import tqdm # --------------------- modify the system prompt as needed --------------------- DEFAULT_PROMPT = ( 'Integrate step-by-step reasoning and Python code to solve math problems ' 'using the following guidelines:\n' '- Analyze the question and write jupyter code to solve the problem;\n' r"- Present the final result in LaTeX using a '\boxed{{}}' without any " 'units. \n') # ------------------------------------------------------------------------------ def parse_args(): parser = argparse.ArgumentParser(description='Math Code Interpreter') parser.add_argument('--backend', type=str, default='lmdeploy', help='Which inference framework to use.', choices=['lmdeploy', 'hf']) parser.add_argument( '--model_path', type=str, default='internlm/internlm2-chat-7b', help='Path or name to the model, could be HuggingFace model specifier.' ) parser.add_argument( '--output_path', type=str, required=True, help='Path to save inference results to, should be a `jsonl` file') parser.add_argument('--batch_size', type=int, default=100, help='Agent inference batch size') parser.add_argument( '--max_turn', type=int, default=5, help= 'Maximum number of interaction rounds between the agent and environment' ) parser.add_argument( '--tp', type=int, default=1, help='Number of tensor parallelism. It may be required in LMDelpoy.') parser.add_argument('--temperature', type=float, default=0.1, help='Temperature in next token prediction') parser.add_argument('--top_p', type=float, default=0.8, help='Parameter for Top-P Sampling.') parser.add_argument('--top_k', type=int, default=40, help='Parameter for Top-K Sampling.') parser.add_argument('--stop_words', type=str, default=['<|action_end|>', '<|im_end|>'], action='append', help='Stop words') parser.add_argument('--max_new_tokens', type=int, default=512, help='Number of maximum generated tokens.') parser.add_argument( '--do_infer', default=True, action=argparse.BooleanOptionalAction, # python > 3.8 help='Whether to launch model inference.') # parser.add_argument( # '--no-do_infer', # dest='do_infer', # action='store_false', # help='Disable the inference.' # ) parser.add_argument('--do_eval', default=False, action='store_true', help='Whether to evaluate the inference results.') parser.add_argument('--overwrite', default=False, action='store_true', help='Whether to overwrite the existing result file') return parser.parse_args() def _fix_fracs(string): substrs = string.split('\\frac') new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: new_str += '\\frac' if len(substr) > 0 and substr[0] == '{': new_str += substr else: try: assert len(substr) >= 2 except Exception: return string a = substr[0] b = substr[1] if b != '{': if len(substr) > 2: post_substr = substr[2:] new_str += '{' + a + '}{' + b + '}' + post_substr else: new_str += '{' + a + '}{' + b + '}' else: if len(substr) > 2: post_substr = substr[2:] new_str += '{' + a + '}' + b + post_substr else: new_str += '{' + a + '}' + b string = new_str return string def _fix_a_slash_b(string): if len(string.split('/')) != 2: return string a = string.split('/')[0] b = string.split('/')[1] try: if 'sqrt' not in a: a = int(a) if 'sqrt' not in b: b = int(b) assert string == '{}/{}'.format(a, b) new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' return new_string except Exception: return string def _fix_sqrt(string): _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string) return _string def strip_string(string): string = str(string).strip() # linebreaks string = string.replace('\n', '') # right "." string = string.rstrip('.') # remove inverse spaces string = string.replace('\\!', '') string = string.replace('\\ ', '') # replace \\ with \ string = string.replace('\\\\', '\\') string = string.replace('\\\\', '\\') # replace tfrac and dfrac with frac string = string.replace('tfrac', 'frac') string = string.replace('dfrac', 'frac') # remove \left and \right string = string.replace('\\left', '') string = string.replace('\\right', '') # Remove unit: miles, dollars if after is not none _string = re.sub(r'\\text{.*?}$', '', string).strip() if _string != '' and _string != string: # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) string = _string # Remove circ (degrees) string = string.replace('^{\\circ}', '') string = string.replace('^\\circ', '') # remove dollar signs string = string.replace('\\$', '') string = string.replace('$', '') string = string.replace('\\text', '') string = string.replace('x\\in', '') # remove percentage string = string.replace('\\%', '') string = string.replace('\%', '') string = string.replace('%', '') # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string string = string.replace(' .', ' 0.') string = string.replace('{.', '{0.') # cdot string = string.replace('\\cdot', '') # inf string = string.replace('infinity', '\\infty') if '\\infty' not in string: string = string.replace('inf', '\\infty') string = string.replace('+\\inity', '\\infty') # and string = string.replace('and', '') string = string.replace('\\mathbf', '') # use regex to remove \mbox{...} string = re.sub(r'\\mbox{.*?}', '', string) # quote string.replace("'", '') string.replace('"', '') # i, j if 'j' in string and 'i' not in string: string = string.replace('j', 'i') # replace a.000b where b is not number or b is end, with ab, use regex string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string) string = re.sub(r'(\d+)\.0+$', r'\1', string) # if empty, return empty string if len(string) == 0: return string if string[0] == '.': string = '0' + string # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split('=')) == 2: if len(string.split('=')[0]) <= 2: string = string.split('=')[1] string = _fix_sqrt(string) string = string.replace(' ', '') # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = _fix_fracs(string) # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y string = _fix_a_slash_b(string) return string def last_boxed_only_string(string): idx = string.rfind('\\boxed') if idx < 0: idx = string.rfind('\\fbox') if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == '{': num_left_braces_open += 1 if string[i] == '}': num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx is None: retval = None else: retval = string[idx:right_brace_idx + 1] return retval def extract_answer(pred_str: str, execute: bool = False) -> str: if re.search('\boxed|boxed', pred_str): answer = re.split('\boxed|boxed', pred_str)[-1] if len(answer) == 0: return '' elif (answer[0] == '{'): stack = 1 a = '' for c in answer[1:]: if (c == '{'): stack += 1 a += c elif (c == '}'): stack -= 1 if (stack == 0): break a += c else: a += c else: a = answer.split('$')[0].strip() elif re.search('[Tt]he (final )?answer is:?', pred_str): a = re.split('[Tt]he (final )?answer is:?', pred_str)[-1].strip().rstrip('.') elif pred_str.startswith('```python') and execute: # fall back to program from lagent import get_tool a = get_tool('IPythonInteractive').exec(pred_str).value or '' else: # use the last number pred = re.findall(r'-?\d*\.?\d+', pred_str.replace(',', '')) if len(pred) >= 1: a = pred[-1] else: a = '' # multiple lines pred = a.split('\n')[0] if pred != '' and pred[0] == ':': pred = pred[1:] if pred != '' and pred[-1] == '.': pred = pred[:-1] if pred != '' and pred[-1] == '/': pred = pred[:-1] pred = strip_string(pred) return pred def is_digit(s): try: float(str(s).replace(',', '')) return True except ValueError: return False def math_equal( prediction: Union[bool, float, str], reference: Union[float, str], include_percentage: bool = True, is_close: bool = True, tolerance: float = 1e-4, timeout: bool = False, ) -> bool: """Exact match of math if and only if: 1. numerical equal: both can convert to float and are equal 2. symbolic equal: both can convert to sympy expression and are equal """ try: # 1. numerical equal if is_digit(prediction) and is_digit(reference): prediction = float(str(prediction).replace(',', '')) reference = float(str(reference).replace(',', '')) # number questions if include_percentage: gt_result = [reference / 100, reference, reference * 100] else: gt_result = [reference] for item in gt_result: try: if is_close: if isclose(item, prediction, rel_tol=tolerance): return True else: if item == prediction: return True except Exception: continue return False except Exception: pass if not prediction and prediction not in [0, False]: return False # 2. symbolic equal reference = str(reference).strip() prediction = str(prediction).strip() ## deal with [], (), {} pred_str, ref_str = prediction, reference if (prediction.startswith('[') and prediction.endswith(']') and not reference.startswith('(')) or ( prediction.startswith('(') and prediction.endswith(')') and not reference.startswith('[')): pred_str = pred_str.strip('[]()') ref_str = ref_str.strip('[]()') for s in ['{', '}', '(', ')']: ref_str = ref_str.replace(s, '') pred_str = pred_str.replace(s, '') if pred_str == ref_str: return True ## [a, b] vs. [c, d], return a==c and b==d if ((prediction.startswith('[') and prediction.endswith(']')) and (reference.startswith('[') and reference.endswith(']')) or (prediction.startswith('(') and prediction.endswith(')')) and (reference.startswith('(') and reference.endswith(')'))): pred_parts = prediction[1:-1].split(',') ref_parts = reference[1:-1].split(',') if len(pred_parts) == len(ref_parts): if all([ math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts)) ]): return True # symbolic equal with sympy if timeout: if call_with_timeout(symbolic_equal_process, prediction, reference): return True else: if symbolic_equal(prediction, reference): return True return False def math_equal_process(param): return math_equal(param[-2], param[-1]) def symbolic_equal(a, b): def _parse(s): for f in [parse_latex, parse_expr]: try: return f(s) except Exception: pass return s a = _parse(a) b = _parse(b) try: if simplify(a - b) == 0: return True except Exception: pass try: if isclose(N(a), N(b), rel_tol=1e-3): return True except Exception: pass return False def symbolic_equal_process(a, b, output_queue): result = symbolic_equal(a, b) output_queue.put(result) def call_with_timeout(func, *args, timeout=1, **kwargs): output_queue = multiprocessing.Queue() process_args = args + (output_queue, ) process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) process.start() process.join(timeout) if process.is_alive(): process.terminate() process.join() return False return output_queue.get() def init_agent(backend: str, max_turn: int, model_path: str, tp: int, **kwargs): if backend == 'lmdeploy': from lmdeploy import TurbomindEngineConfig model = LMDeployPipeline( path=model_path, model_name='internlm2-chat', meta_template=INTERNLM2_META, pipeline_cfg=dict(backend_config=TurbomindEngineConfig(tp=tp)), **kwargs) elif backend == 'hf': model = HFTransformer(path=model_path, meta_template=INTERNLM2_META, **kwargs) else: raise NotImplementedError agent = Internlm2Agent( llm=model, protocol=Internlm2Protocol(meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT), interpreter_executor=ActionExecutor(actions=[ IPythonInteractiveManager(max_workers=200, ci_lock=os.path.join( os.path.dirname(__file__), '.ipython.lock')) ]), max_turn=max_turn) return agent def predict(args): def process(d, k): d['idx'] = k d['query'] = d['problem'] gt = extract_answer(d['solution']) if '\\boxed{90\\text{ square\nunits}}' in d['solution']: gt = '90' elif '$6$ is our answer' in d['solution']: gt = '6' elif gt.startswith('x\\in'): gt = gt[len('x\\in'):] gt = strip_string(gt) d['gt'] = gt d['pred'], d['steps'] = [], [] d['error'] = None return d dataset = load_dataset('lighteval/MATH', split='test').map(process, True) agent = init_agent( backend=args.backend, max_turn=args.max_turn, model_path=args.model_path, tp=args.tp, temperature=args.temperature, stop_words=args.stop_words, top_p=args.top_p, top_k=args.top_k, max_new_tokens=args.max_new_tokens, ) num_batches = ceil(len(dataset) / args.batch_size) with jsonlines.open(args.output_path, 'w', flush=True) as f: for i in tqdm(range(num_batches)): batch = dataset.select( range(i * args.batch_size, min((i + 1) * args.batch_size, len(dataset)))) try: rets = agent.batch_chat(batch['query']) for item, ret in zip(batch, rets): item['steps'] = ret.inner_steps last = item['steps'][-1] item['pred'].append( extract_answer(last['content']) if last['role'] == 'language' else '😭') f.write(item) except Exception as e: err = str(traceback.format_exc()) print(f'Processing batch data error: {e}\n{err}') for item in batch: item['error'] = err f.write(item) finally: agent._interpreter_executor.actions[ 'IPythonInteractiveManager'].reset() def evaluate(args): samples = [sample for sample in jsonlines.open(args.output_path)] scores = [] timeout_cnt = 0 with ProcessPool() as pool: future = pool.map( math_equal_process, [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']], timeout=20, ) iterator = future.result() with tqdm(total=len(samples), desc='Evaluate') as progress_bar: while True: try: result = next(iterator) scores.append(result) except StopIteration: break except TimeoutError as error: print(error) scores.append(False) timeout_cnt += 1 except Exception as error: print(error.__traceback__) scores.append(False) # sys.exit() progress_bar.update(1) idx = 0 score_mat = [] for sample in samples: sample['score'] = scores[idx:idx + len(sample['pred'])] assert len(sample['score']) == len(sample['pred']) score_mat.append(sample['score']) idx += len(sample['pred']) max_len = max([len(s) for s in score_mat]) for i, s in enumerate(score_mat): if len(s) < max_len: score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad # output mean of each column of scores col_means = np.array(score_mat).mean(axis=0) mean_score = list(np.round(col_means * 100, decimals=1)) result_str = f'Num samples: {len(samples)}\n' \ f'Num scores: {len(scores)}\n' \ f'Sum scores: {sum(scores)}\n' \ f'Timeout samples: {timeout_cnt}\n' \ f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \ f'Mean score: {mean_score}\n' # each type score if 'type' in samples[0]: type_scores = {} for sample in samples: if sample['type'] not in type_scores: type_scores[sample['type']] = [] type_scores[sample['type']].append(sample['score'][-1]) type_scores = { k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items() } type_scores = { k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0]) } result_str += f'Type scores: {type_scores}\n' print(result_str) if __name__ == '__main__': args = parse_args() if args.do_infer and os.path.exists( args.output_path) and not args.overwrite: args.do_infer = False print(f'File {args.output_path} already exists. ' f'Please add the `--overwrite` flag if needed.') if args.do_infer: predict(args) if args.do_eval: if not args.do_infer: evaluate(args) else: import subprocess res = subprocess.run( [ sys.executable, __file__, '--output_path', args.output_path, '--no-do_infer', '--do_eval' ], capture_output=True, text=True, check=True, ) print(res.stdout)