From b81132283d74ec87164ed5dc53d5c6a0689dc6ad Mon Sep 17 00:00:00 2001 From: wangzy Date: Mon, 5 Feb 2024 14:15:24 +0800 Subject: [PATCH] remove chat args --- agent/streaming_inference.py | 346 +++++++++++++++++------------------ 1 file changed, 165 insertions(+), 181 deletions(-) diff --git a/agent/streaming_inference.py b/agent/streaming_inference.py index ea7d802..17033aa 100644 --- a/agent/streaming_inference.py +++ b/agent/streaming_inference.py @@ -35,15 +35,9 @@ from typing import Union import jsonlines import numpy as np from datasets import load_dataset -from lagent import ( - INTERNLM2_META, - ActionExecutor, - HFTransformer, - Internlm2Agent, - Internlm2Protocol, - LMDeployPipeline, - get_tool, -) +from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer, + Internlm2Agent, Internlm2Protocol, LMDeployPipeline, + get_tool) from pebble import ProcessPool from sympy import N, simplify from sympy.parsing.latex import parse_latex @@ -60,12 +54,11 @@ DEFAULT_PROMPT = ( def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument( - '--backend', - type=str, - default='lmdeploy', - help='Which inference framework to use.', - choices=['lmdeploy', 'hf']) + parser.add_argument('--backend', + type=str, + default='lmdeploy', + help='Which inference framework to use.', + choices=['lmdeploy', 'hf']) parser.add_argument( '--model_path', type=str, @@ -77,43 +70,37 @@ def parse_args(): type=str, required=True, help='Path to save inference results to, should be a `jsonl` file') - parser.add_argument( - '--dataset', - type=str, - default='math', - choices=['gsm8k', 'math'], - help='Dataset for inference') + parser.add_argument('--dataset', + type=str, + default='math', + choices=['gsm8k', 'math'], + help='Dataset for inference') parser.add_argument( '--tp', type=int, default=1, help='Number of tensor parallelism. It may be required in LMDelpoy.') - parser.add_argument( - '--temperature', - type=float, - default=0.1, - help='Temperature in next token prediction') - parser.add_argument( - '--top_p', - type=float, - default=0.8, - help='Parameter for Top-P Sampling.') - parser.add_argument( - '--top_k', - type=int, - default=None, - help='Parameter for Top-K Sampling.') - parser.add_argument( - '--stop_words', - type=list, - default=['<|action_end|>', '<|im_end|>'], - action='append', - help='Stop words') - parser.add_argument( - '--max_tokens', - type=int, - default=512, - help='Number of maximum generated tokens.') + parser.add_argument('--temperature', + type=float, + default=0.1, + help='Temperature in next token prediction') + parser.add_argument('--top_p', + type=float, + default=0.8, + help='Parameter for Top-P Sampling.') + parser.add_argument('--top_k', + type=int, + default=None, + help='Parameter for Top-K Sampling.') + parser.add_argument('--stop_words', + type=list, + default=['<|action_end|>', '<|im_end|>'], + action='append', + help='Stop words') + parser.add_argument('--max_tokens', + type=int, + default=512, + help='Number of maximum generated tokens.') parser.add_argument( '--do_infer', default=True, @@ -125,32 +112,29 @@ def parse_args(): # action='store_false', # help='Disable the inference.' # ) - parser.add_argument( - '--do_eval', - default=False, - action='store_true', - help='Whether to evaluate the inference results.') - parser.add_argument( - '--overwrite', - default=False, - action='store_true', - help='Whether to overwrite the existing result file') - parser.add_argument( - '--debug', - default=False, - action='store_true', - help='Only infer the first 50 samples') + parser.add_argument('--do_eval', + default=False, + action='store_true', + help='Whether to evaluate the inference results.') + parser.add_argument('--overwrite', + default=False, + action='store_true', + help='Whether to overwrite the existing result file') + parser.add_argument('--debug', + default=False, + action='store_true', + help='Only infer the first 50 samples') return parser.parse_args() def _fix_fracs(string): - substrs = string.split("\\frac") + substrs = string.split('\\frac') new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: - new_str += "\\frac" - if len(substr) > 0 and substr[0] == "{": + new_str += '\\frac' + if len(substr) > 0 and substr[0] == '{': new_str += substr else: try: @@ -159,135 +143,135 @@ def _fix_fracs(string): return string a = substr[0] b = substr[1] - if b != "{": + if b != '{': if len(substr) > 2: post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr + new_str += '{' + a + '}{' + b + '}' + post_substr else: - new_str += "{" + a + "}{" + b + "}" + new_str += '{' + a + '}{' + b + '}' else: if len(substr) > 2: post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr + new_str += '{' + a + '}' + b + post_substr else: - new_str += "{" + a + "}" + b + new_str += '{' + a + '}' + b string = new_str return string def _fix_a_slash_b(string): - if len(string.split("/")) != 2: + if len(string.split('/')) != 2: return string - a = string.split("/")[0] - b = string.split("/")[1] + a = string.split('/')[0] + b = string.split('/')[1] try: - if "sqrt" not in a: + if 'sqrt' not in a: a = int(a) - if "sqrt" not in b: + if 'sqrt' not in b: b = int(b) - assert string == "{}/{}".format(a, b) - new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + assert string == '{}/{}'.format(a, b) + new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' return new_string except Exception: return string def _fix_sqrt(string): - _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) + _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string) return _string def strip_string(string): string = str(string).strip() # linebreaks - string = string.replace("\n", "") + string = string.replace('\n', '') # right "." - string = string.rstrip(".") + string = string.rstrip('.') # remove inverse spaces - string = string.replace("\\!", "") - string = string.replace("\\ ", "") + string = string.replace('\\!', '') + string = string.replace('\\ ', '') # replace \\ with \ - string = string.replace("\\\\", "\\") - string = string.replace("\\\\", "\\") + string = string.replace('\\\\', '\\') + string = string.replace('\\\\', '\\') # replace tfrac and dfrac with frac - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") + string = string.replace('tfrac', 'frac') + string = string.replace('dfrac', 'frac') # remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\right", "") + string = string.replace('\\left', '') + string = string.replace('\\right', '') # Remove unit: miles, dollars if after is not none - _string = re.sub(r"\\text{.*?}$", "", string).strip() - if _string != "" and _string != string: + _string = re.sub(r'\\text{.*?}$', '', string).strip() + if _string != '' and _string != string: # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) string = _string # Remove circ (degrees) - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") + string = string.replace('^{\\circ}', '') + string = string.replace('^\\circ', '') # remove dollar signs - string = string.replace("\\$", "") - string = string.replace("$", "") + string = string.replace('\\$', '') + string = string.replace('$', '') - string = string.replace("\\text", "") - string = string.replace("x\\in", "") + string = string.replace('\\text', '') + string = string.replace('x\\in', '') # remove percentage - string = string.replace("\\%", "") - string = string.replace("\%", "") - string = string.replace("%", "") + string = string.replace('\\%', '') + string = string.replace('\%', '') + string = string.replace('%', '') # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") + string = string.replace(' .', ' 0.') + string = string.replace('{.', '{0.') # cdot - string = string.replace("\\cdot", "") + string = string.replace('\\cdot', '') # inf - string = string.replace("infinity", "\\infty") - if "\\infty" not in string: - string = string.replace("inf", "\\infty") - string = string.replace("+\\inity", "\\infty") + string = string.replace('infinity', '\\infty') + if '\\infty' not in string: + string = string.replace('inf', '\\infty') + string = string.replace('+\\inity', '\\infty') # and - string = string.replace("and", "") - string = string.replace("\\mathbf", "") + string = string.replace('and', '') + string = string.replace('\\mathbf', '') # use regex to remove \mbox{...} - string = re.sub(r"\\mbox{.*?}", "", string) + string = re.sub(r'\\mbox{.*?}', '', string) # quote - string.replace("'", "") - string.replace('"', "") + string.replace("'", '') + string.replace('"', '') # i, j - if "j" in string and "i" not in string: - string = string.replace("j", "i") + if 'j' in string and 'i' not in string: + string = string.replace('j', 'i') # replace a.000b where b is not number or b is end, with ab, use regex - string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) - string = re.sub(r"(\d+)\.0+$", r"\1", string) + string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string) + string = re.sub(r'(\d+)\.0+$', r'\1', string) # if empty, return empty string if len(string) == 0: return string - if string[0] == ".": - string = "0" + string + if string[0] == '.': + string = '0' + string # to consider: get rid of e.g. "k = " or "q = " at beginning - if len(string.split("=")) == 2: - if len(string.split("=")[0]) <= 2: - string = string.split("=")[1] + if len(string.split('=')) == 2: + if len(string.split('=')[0]) <= 2: + string = string.split('=')[1] string = _fix_sqrt(string) - string = string.replace(" ", "") + string = string.replace(' ', '') # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = _fix_fracs(string) @@ -299,9 +283,9 @@ def strip_string(string): def last_boxed_only_string(string): - idx = string.rfind("\\boxed") + idx = string.rfind('\\boxed') if idx < 0: - idx = string.rfind("\\fbox") + idx = string.rfind('\\fbox') if idx < 0: return None @@ -309,9 +293,9 @@ def last_boxed_only_string(string): right_brace_idx = None num_left_braces_open = 0 while i < len(string): - if string[i] == "{": + if string[i] == '{': num_left_braces_open += 1 - if string[i] == "}": + if string[i] == '}': num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i @@ -329,13 +313,13 @@ def last_boxed_only_string(string): def extract_answer(pred_str): if 'boxed' not in pred_str: return '' - ans = pred_str.split('boxed')[-1] - if len(ans) == 0: + answer = pred_str.split('boxed')[-1] + if len(answer) == 0: return '' - elif (ans[0] == '{'): + elif (answer[0] == '{'): stack = 1 a = '' - for c in ans[1:]: + for c in answer[1:]: if (c == '{'): stack += 1 a += c @@ -346,9 +330,9 @@ def extract_answer(pred_str): else: a += c else: - a = ans.split('$')[0].strip() + a = answer.split('$')[0].strip() - pred = a.split("\n")[0] + pred = a.split('\n')[0] if pred != '' and pred[0] == ':': pred = pred[1:] if pred != '' and pred[-1] == '.': @@ -361,7 +345,7 @@ def extract_answer(pred_str): def is_digit(s): try: - float(str(s).replace(",", "")) + float(str(s).replace(',', '')) return True except ValueError: return False @@ -375,15 +359,15 @@ def math_equal( tolerance: float = 1e-4, timeout: bool = False, ) -> bool: - """ - Exact match of math if and only if: + """Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal 2. symbolic equal: both can convert to sympy expression and are equal """ try: # 1. numerical equal if is_digit(prediction) and is_digit(reference): - prediction = float(str(prediction).replace(",", "")) - reference = float(str(reference).replace(",", "")) + prediction = float(str(prediction).replace(',', '')) + reference = float(str(reference).replace(',', '')) # number questions if include_percentage: gt_result = [reference / 100, reference, reference * 100] @@ -412,25 +396,25 @@ def math_equal( ## deal with [], (), {} pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") - and not reference.startswith("(")) or ( - prediction.startswith("(") and prediction.endswith(")") - and not reference.startswith("[")): - pred_str = pred_str.strip("[]()") - ref_str = ref_str.strip("[]()") - for s in ["{", "}", "(", ")"]: - ref_str = ref_str.replace(s, "") - pred_str = pred_str.replace(s, "") + if (prediction.startswith('[') and prediction.endswith(']') + and not reference.startswith('(')) or ( + prediction.startswith('(') and prediction.endswith(')') + and not reference.startswith('[')): + pred_str = pred_str.strip('[]()') + ref_str = ref_str.strip('[]()') + for s in ['{', '}', '(', ')']: + ref_str = ref_str.replace(s, '') + pred_str = pred_str.replace(s, '') if pred_str == ref_str: return True ## [a, b] vs. [c, d], return a==c and b==d - if ((prediction.startswith("[") and prediction.endswith("]")) and - (reference.startswith("[") and reference.endswith("]")) - or (prediction.startswith("(") and prediction.endswith(")")) and - (reference.startswith("(") and reference.endswith(")"))): - pred_parts = prediction[1:-1].split(",") - ref_parts = reference[1:-1].split(",") + if ((prediction.startswith('[') and prediction.endswith(']')) and + (reference.startswith('[') and reference.endswith(']')) + or (prediction.startswith('(') and prediction.endswith(')')) and + (reference.startswith('(') and reference.endswith(')'))): + pred_parts = prediction[1:-1].split(',') + ref_parts = reference[1:-1].split(',') if len(pred_parts) == len(ref_parts): if all([ math_equal(pred_parts[i], ref_parts[i], include_percentage, @@ -488,8 +472,9 @@ def symbolic_equal_process(a, b, output_queue): def call_with_timeout(func, *args, timeout=1, **kwargs): output_queue = multiprocessing.Queue() process_args = args + (output_queue, ) - process = multiprocessing.Process( - target=func, args=process_args, kwargs=kwargs) + process = multiprocessing.Process(target=func, + args=process_args, + kwargs=kwargs) process.start() process.join(timeout) @@ -501,22 +486,25 @@ def call_with_timeout(func, *args, timeout=1, **kwargs): return output_queue.get() -def init_agent(backend: str, model_path: str, tp=1, **kwargs): +def init_agent(backend: str, model_path: str, tp: int, **kwargs): if backend == 'lmdeploy': - model = LMDeployPipeline( - path=model_path, meta_template=INTERNLM2_META, tp=tp, **kwargs) + model = LMDeployPipeline(path=model_path, + meta_template=INTERNLM2_META, + tp=tp, + **kwargs) elif backend == 'hf': - model = HFTransformer( - path=model_path, meta_template=INTERNLM2_META, **kwargs) + model = HFTransformer(path=model_path, + meta_template=INTERNLM2_META, + **kwargs) else: raise NotImplementedError - agent = Internlm2Agent( - llm=model, - protocol=Internlm2Protocol( - meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT), - interpreter_executor=ActionExecutor( - actions=[get_tool('IPythonInteractive')])) + agent = Internlm2Agent(llm=model, + protocol=Internlm2Protocol( + meta_prompt=None, + interpreter_prompt=DEFAULT_PROMPT), + interpreter_executor=ActionExecutor( + actions=[get_tool('IPythonInteractive')])) return agent @@ -533,8 +521,8 @@ def predict(args): d['pred'], d['steps'], d['error'] = None, [], [] return d - dataset = load_dataset( - 'gsm8k', 'main', split='test').map(process, True) + dataset = load_dataset('gsm8k', 'main', + split='test').map(process, True) elif args.dataset == 'math': @@ -554,8 +542,8 @@ def predict(args): d['error'] = None return d - dataset = load_dataset( - "lighteval/MATH", split='test').map(process, True) + dataset = load_dataset('lighteval/MATH', + split='test').map(process, True) else: raise NotImplementedError @@ -570,15 +558,11 @@ def predict(args): top_k=args.top_k, max_tokens=args.max_tokens, ) - gen_kwargs = { - 'max_new_tokens': args.max_tokens - } if args.backend == 'hf' else {} - with jsonlines.open(args.output_path, 'w') as f: for item in tqdm( dataset if not args.debug else dataset.select(range(50))): try: - ret = agent.chat(item['query'], **gen_kwargs) + ret = agent.chat(item['query']) item['steps'] = ret.inner_steps lang = [ @@ -607,7 +591,7 @@ def evaluate(args): timeout=20, ) iterator = future.result() - with tqdm(total=len(samples), desc="Evaluate") as progress_bar: + with tqdm(total=len(samples), desc='Evaluate') as progress_bar: while True: try: result = next(iterator) @@ -641,15 +625,15 @@ def evaluate(args): col_means = np.array(score_mat).mean(axis=0) mean_score = list(np.round(col_means * 100, decimals=1)) - result_str = f"Num samples: {len(samples)}\n" \ - f"Num scores: {len(scores)}\n" \ - f"Sum scores: {sum(scores)}\n" \ - f"Timeout samples: {timeout_cnt}\n" \ + result_str = f'Num samples: {len(samples)}\n' \ + f'Num scores: {len(scores)}\n' \ + f'Sum scores: {sum(scores)}\n' \ + f'Timeout samples: {timeout_cnt}\n' \ f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \ - f"Mean score: {mean_score}\n" + f'Mean score: {mean_score}\n' # each type score - if "type" in samples[0]: + if 'type' in samples[0]: type_scores = {} for sample in samples: if sample['type'] not in type_scores: @@ -663,7 +647,7 @@ def evaluate(args): k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0]) } - result_str += f"Type scores: {type_scores}\n" + result_str += f'Type scores: {type_scores}\n' print(result_str)