remove chat args

pull/695/head
wangzy 2024-02-05 14:15:24 +08:00
parent b6f669b1ca
commit b81132283d
1 changed files with 165 additions and 181 deletions

View File

@ -35,15 +35,9 @@ from typing import Union
import jsonlines import jsonlines
import numpy as np import numpy as np
from datasets import load_dataset from datasets import load_dataset
from lagent import ( from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
INTERNLM2_META, Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
ActionExecutor, get_tool)
HFTransformer,
Internlm2Agent,
Internlm2Protocol,
LMDeployPipeline,
get_tool,
)
from pebble import ProcessPool from pebble import ProcessPool
from sympy import N, simplify from sympy import N, simplify
from sympy.parsing.latex import parse_latex from sympy.parsing.latex import parse_latex
@ -60,12 +54,11 @@ DEFAULT_PROMPT = (
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument('--backend',
'--backend', type=str,
type=str, default='lmdeploy',
default='lmdeploy', help='Which inference framework to use.',
help='Which inference framework to use.', choices=['lmdeploy', 'hf'])
choices=['lmdeploy', 'hf'])
parser.add_argument( parser.add_argument(
'--model_path', '--model_path',
type=str, type=str,
@ -77,43 +70,37 @@ def parse_args():
type=str, type=str,
required=True, required=True,
help='Path to save inference results to, should be a `jsonl` file') help='Path to save inference results to, should be a `jsonl` file')
parser.add_argument( parser.add_argument('--dataset',
'--dataset', type=str,
type=str, default='math',
default='math', choices=['gsm8k', 'math'],
choices=['gsm8k', 'math'], help='Dataset for inference')
help='Dataset for inference')
parser.add_argument( parser.add_argument(
'--tp', '--tp',
type=int, type=int,
default=1, default=1,
help='Number of tensor parallelism. It may be required in LMDelpoy.') help='Number of tensor parallelism. It may be required in LMDelpoy.')
parser.add_argument( parser.add_argument('--temperature',
'--temperature', type=float,
type=float, default=0.1,
default=0.1, help='Temperature in next token prediction')
help='Temperature in next token prediction') parser.add_argument('--top_p',
parser.add_argument( type=float,
'--top_p', default=0.8,
type=float, help='Parameter for Top-P Sampling.')
default=0.8, parser.add_argument('--top_k',
help='Parameter for Top-P Sampling.') type=int,
parser.add_argument( default=None,
'--top_k', help='Parameter for Top-K Sampling.')
type=int, parser.add_argument('--stop_words',
default=None, type=list,
help='Parameter for Top-K Sampling.') default=['<|action_end|>', '<|im_end|>'],
parser.add_argument( action='append',
'--stop_words', help='Stop words')
type=list, parser.add_argument('--max_tokens',
default=['<|action_end|>', '<|im_end|>'], type=int,
action='append', default=512,
help='Stop words') help='Number of maximum generated tokens.')
parser.add_argument(
'--max_tokens',
type=int,
default=512,
help='Number of maximum generated tokens.')
parser.add_argument( parser.add_argument(
'--do_infer', '--do_infer',
default=True, default=True,
@ -125,32 +112,29 @@ def parse_args():
# action='store_false', # action='store_false',
# help='Disable the inference.' # help='Disable the inference.'
# ) # )
parser.add_argument( parser.add_argument('--do_eval',
'--do_eval', default=False,
default=False, action='store_true',
action='store_true', help='Whether to evaluate the inference results.')
help='Whether to evaluate the inference results.') parser.add_argument('--overwrite',
parser.add_argument( default=False,
'--overwrite', action='store_true',
default=False, help='Whether to overwrite the existing result file')
action='store_true', parser.add_argument('--debug',
help='Whether to overwrite the existing result file') default=False,
parser.add_argument( action='store_true',
'--debug', help='Only infer the first 50 samples')
default=False,
action='store_true',
help='Only infer the first 50 samples')
return parser.parse_args() return parser.parse_args()
def _fix_fracs(string): def _fix_fracs(string):
substrs = string.split("\\frac") substrs = string.split('\\frac')
new_str = substrs[0] new_str = substrs[0]
if len(substrs) > 1: if len(substrs) > 1:
substrs = substrs[1:] substrs = substrs[1:]
for substr in substrs: for substr in substrs:
new_str += "\\frac" new_str += '\\frac'
if len(substr) > 0 and substr[0] == "{": if len(substr) > 0 and substr[0] == '{':
new_str += substr new_str += substr
else: else:
try: try:
@ -159,135 +143,135 @@ def _fix_fracs(string):
return string return string
a = substr[0] a = substr[0]
b = substr[1] b = substr[1]
if b != "{": if b != '{':
if len(substr) > 2: if len(substr) > 2:
post_substr = substr[2:] post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr new_str += '{' + a + '}{' + b + '}' + post_substr
else: else:
new_str += "{" + a + "}{" + b + "}" new_str += '{' + a + '}{' + b + '}'
else: else:
if len(substr) > 2: if len(substr) > 2:
post_substr = substr[2:] post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr new_str += '{' + a + '}' + b + post_substr
else: else:
new_str += "{" + a + "}" + b new_str += '{' + a + '}' + b
string = new_str string = new_str
return string return string
def _fix_a_slash_b(string): def _fix_a_slash_b(string):
if len(string.split("/")) != 2: if len(string.split('/')) != 2:
return string return string
a = string.split("/")[0] a = string.split('/')[0]
b = string.split("/")[1] b = string.split('/')[1]
try: try:
if "sqrt" not in a: if 'sqrt' not in a:
a = int(a) a = int(a)
if "sqrt" not in b: if 'sqrt' not in b:
b = int(b) b = int(b)
assert string == "{}/{}".format(a, b) assert string == '{}/{}'.format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string return new_string
except Exception: except Exception:
return string return string
def _fix_sqrt(string): def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
return _string return _string
def strip_string(string): def strip_string(string):
string = str(string).strip() string = str(string).strip()
# linebreaks # linebreaks
string = string.replace("\n", "") string = string.replace('\n', '')
# right "." # right "."
string = string.rstrip(".") string = string.rstrip('.')
# remove inverse spaces # remove inverse spaces
string = string.replace("\\!", "") string = string.replace('\\!', '')
string = string.replace("\\ ", "") string = string.replace('\\ ', '')
# replace \\ with \ # replace \\ with \
string = string.replace("\\\\", "\\") string = string.replace('\\\\', '\\')
string = string.replace("\\\\", "\\") string = string.replace('\\\\', '\\')
# replace tfrac and dfrac with frac # replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac") string = string.replace('tfrac', 'frac')
string = string.replace("dfrac", "frac") string = string.replace('dfrac', 'frac')
# remove \left and \right # remove \left and \right
string = string.replace("\\left", "") string = string.replace('\\left', '')
string = string.replace("\\right", "") string = string.replace('\\right', '')
# Remove unit: miles, dollars if after is not none # Remove unit: miles, dollars if after is not none
_string = re.sub(r"\\text{.*?}$", "", string).strip() _string = re.sub(r'\\text{.*?}$', '', string).strip()
if _string != "" and _string != string: if _string != '' and _string != string:
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
string = _string string = _string
# Remove circ (degrees) # Remove circ (degrees)
string = string.replace("^{\\circ}", "") string = string.replace('^{\\circ}', '')
string = string.replace("^\\circ", "") string = string.replace('^\\circ', '')
# remove dollar signs # remove dollar signs
string = string.replace("\\$", "") string = string.replace('\\$', '')
string = string.replace("$", "") string = string.replace('$', '')
string = string.replace("\\text", "") string = string.replace('\\text', '')
string = string.replace("x\\in", "") string = string.replace('x\\in', '')
# remove percentage # remove percentage
string = string.replace("\\%", "") string = string.replace('\\%', '')
string = string.replace("\%", "") string = string.replace('\%', '')
string = string.replace("%", "") string = string.replace('%', '')
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.") string = string.replace(' .', ' 0.')
string = string.replace("{.", "{0.") string = string.replace('{.', '{0.')
# cdot # cdot
string = string.replace("\\cdot", "") string = string.replace('\\cdot', '')
# inf # inf
string = string.replace("infinity", "\\infty") string = string.replace('infinity', '\\infty')
if "\\infty" not in string: if '\\infty' not in string:
string = string.replace("inf", "\\infty") string = string.replace('inf', '\\infty')
string = string.replace("+\\inity", "\\infty") string = string.replace('+\\inity', '\\infty')
# and # and
string = string.replace("and", "") string = string.replace('and', '')
string = string.replace("\\mathbf", "") string = string.replace('\\mathbf', '')
# use regex to remove \mbox{...} # use regex to remove \mbox{...}
string = re.sub(r"\\mbox{.*?}", "", string) string = re.sub(r'\\mbox{.*?}', '', string)
# quote # quote
string.replace("'", "") string.replace("'", '')
string.replace('"', "") string.replace('"', '')
# i, j # i, j
if "j" in string and "i" not in string: if 'j' in string and 'i' not in string:
string = string.replace("j", "i") string = string.replace('j', 'i')
# replace a.000b where b is not number or b is end, with ab, use regex # replace a.000b where b is not number or b is end, with ab, use regex
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
string = re.sub(r"(\d+)\.0+$", r"\1", string) string = re.sub(r'(\d+)\.0+$', r'\1', string)
# if empty, return empty string # if empty, return empty string
if len(string) == 0: if len(string) == 0:
return string return string
if string[0] == ".": if string[0] == '.':
string = "0" + string string = '0' + string
# to consider: get rid of e.g. "k = " or "q = " at beginning # to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2: if len(string.split('=')) == 2:
if len(string.split("=")[0]) <= 2: if len(string.split('=')[0]) <= 2:
string = string.split("=")[1] string = string.split('=')[1]
string = _fix_sqrt(string) string = _fix_sqrt(string)
string = string.replace(" ", "") string = string.replace(' ', '')
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string) string = _fix_fracs(string)
@ -299,9 +283,9 @@ def strip_string(string):
def last_boxed_only_string(string): def last_boxed_only_string(string):
idx = string.rfind("\\boxed") idx = string.rfind('\\boxed')
if idx < 0: if idx < 0:
idx = string.rfind("\\fbox") idx = string.rfind('\\fbox')
if idx < 0: if idx < 0:
return None return None
@ -309,9 +293,9 @@ def last_boxed_only_string(string):
right_brace_idx = None right_brace_idx = None
num_left_braces_open = 0 num_left_braces_open = 0
while i < len(string): while i < len(string):
if string[i] == "{": if string[i] == '{':
num_left_braces_open += 1 num_left_braces_open += 1
if string[i] == "}": if string[i] == '}':
num_left_braces_open -= 1 num_left_braces_open -= 1
if num_left_braces_open == 0: if num_left_braces_open == 0:
right_brace_idx = i right_brace_idx = i
@ -329,13 +313,13 @@ def last_boxed_only_string(string):
def extract_answer(pred_str): def extract_answer(pred_str):
if 'boxed' not in pred_str: if 'boxed' not in pred_str:
return '' return ''
ans = pred_str.split('boxed')[-1] answer = pred_str.split('boxed')[-1]
if len(ans) == 0: if len(answer) == 0:
return '' return ''
elif (ans[0] == '{'): elif (answer[0] == '{'):
stack = 1 stack = 1
a = '' a = ''
for c in ans[1:]: for c in answer[1:]:
if (c == '{'): if (c == '{'):
stack += 1 stack += 1
a += c a += c
@ -346,9 +330,9 @@ def extract_answer(pred_str):
else: else:
a += c a += c
else: else:
a = ans.split('$')[0].strip() a = answer.split('$')[0].strip()
pred = a.split("\n")[0] pred = a.split('\n')[0]
if pred != '' and pred[0] == ':': if pred != '' and pred[0] == ':':
pred = pred[1:] pred = pred[1:]
if pred != '' and pred[-1] == '.': if pred != '' and pred[-1] == '.':
@ -361,7 +345,7 @@ def extract_answer(pred_str):
def is_digit(s):
    """Tell whether ``s`` reads as a number once commas are removed.

    ``s`` is converted with ``str()``, all ``,`` characters are dropped and
    the remainder is handed to ``float()``.

    Returns:
        ``True`` if ``float()`` accepts the cleaned text, ``False`` if it
        raises ``ValueError``.
    """
    try:
        cleaned = str(s).replace(',', '')
        float(cleaned)
    except ValueError:
        return False
    else:
        return True
@ -375,15 +359,15 @@ def math_equal(
tolerance: float = 1e-4, tolerance: float = 1e-4,
timeout: bool = False, timeout: bool = False,
) -> bool: ) -> bool:
""" """Exact match of math if and only if:
Exact match of math if and only if:
1. numerical equal: both can convert to float and are equal 1. numerical equal: both can convert to float and are equal
2. symbolic equal: both can convert to sympy expression and are equal 2. symbolic equal: both can convert to sympy expression and are equal
""" """
try: # 1. numerical equal try: # 1. numerical equal
if is_digit(prediction) and is_digit(reference): if is_digit(prediction) and is_digit(reference):
prediction = float(str(prediction).replace(",", "")) prediction = float(str(prediction).replace(',', ''))
reference = float(str(reference).replace(",", "")) reference = float(str(reference).replace(',', ''))
# number questions # number questions
if include_percentage: if include_percentage:
gt_result = [reference / 100, reference, reference * 100] gt_result = [reference / 100, reference, reference * 100]
@ -412,25 +396,25 @@ def math_equal(
## deal with [], (), {} ## deal with [], (), {}
pred_str, ref_str = prediction, reference pred_str, ref_str = prediction, reference
if (prediction.startswith("[") and prediction.endswith("]") if (prediction.startswith('[') and prediction.endswith(']')
and not reference.startswith("(")) or ( and not reference.startswith('(')) or (
prediction.startswith("(") and prediction.endswith(")") prediction.startswith('(') and prediction.endswith(')')
and not reference.startswith("[")): and not reference.startswith('[')):
pred_str = pred_str.strip("[]()") pred_str = pred_str.strip('[]()')
ref_str = ref_str.strip("[]()") ref_str = ref_str.strip('[]()')
for s in ["{", "}", "(", ")"]: for s in ['{', '}', '(', ')']:
ref_str = ref_str.replace(s, "") ref_str = ref_str.replace(s, '')
pred_str = pred_str.replace(s, "") pred_str = pred_str.replace(s, '')
if pred_str == ref_str: if pred_str == ref_str:
return True return True
## [a, b] vs. [c, d], return a==c and b==d ## [a, b] vs. [c, d], return a==c and b==d
if ((prediction.startswith("[") and prediction.endswith("]")) and if ((prediction.startswith('[') and prediction.endswith(']')) and
(reference.startswith("[") and reference.endswith("]")) (reference.startswith('[') and reference.endswith(']'))
or (prediction.startswith("(") and prediction.endswith(")")) and or (prediction.startswith('(') and prediction.endswith(')')) and
(reference.startswith("(") and reference.endswith(")"))): (reference.startswith('(') and reference.endswith(')'))):
pred_parts = prediction[1:-1].split(",") pred_parts = prediction[1:-1].split(',')
ref_parts = reference[1:-1].split(",") ref_parts = reference[1:-1].split(',')
if len(pred_parts) == len(ref_parts): if len(pred_parts) == len(ref_parts):
if all([ if all([
math_equal(pred_parts[i], ref_parts[i], include_percentage, math_equal(pred_parts[i], ref_parts[i], include_percentage,
@ -488,8 +472,9 @@ def symbolic_equal_process(a, b, output_queue):
def call_with_timeout(func, *args, timeout=1, **kwargs): def call_with_timeout(func, *args, timeout=1, **kwargs):
output_queue = multiprocessing.Queue() output_queue = multiprocessing.Queue()
process_args = args + (output_queue, ) process_args = args + (output_queue, )
process = multiprocessing.Process( process = multiprocessing.Process(target=func,
target=func, args=process_args, kwargs=kwargs) args=process_args,
kwargs=kwargs)
process.start() process.start()
process.join(timeout) process.join(timeout)
@ -501,22 +486,25 @@ def call_with_timeout(func, *args, timeout=1, **kwargs):
return output_queue.get() return output_queue.get()
def init_agent(backend: str, model_path: str, tp: int, **kwargs):
    """Build the InternLM2 agent on top of the selected inference backend.

    Args:
        backend: ``'lmdeploy'`` or ``'hf'``; selects the model wrapper.
        model_path: Passed as ``path`` to the model wrapper.
        tp: Forwarded as ``tp`` to ``LMDeployPipeline``; not used by the
            ``'hf'`` backend.
        **kwargs: Extra keyword arguments forwarded unchanged to the model
            wrapper constructor.

    Returns:
        An ``Internlm2Agent`` wired to the created model, an
        ``Internlm2Protocol`` using ``DEFAULT_PROMPT`` as interpreter prompt,
        and an ``ActionExecutor`` holding the ``IPythonInteractive`` tool.

    Raises:
        NotImplementedError: If ``backend`` is neither ``'lmdeploy'`` nor
            ``'hf'``.
    """
    if backend == 'lmdeploy':
        model = LMDeployPipeline(path=model_path,
                                 meta_template=INTERNLM2_META,
                                 tp=tp,
                                 **kwargs)
    elif backend == 'hf':
        model = HFTransformer(path=model_path,
                              meta_template=INTERNLM2_META,
                              **kwargs)
    else:
        # Name the offending value so a mistyped backend is easy to spot.
        raise NotImplementedError(
            f'Unsupported backend {backend!r}; expected one of '
            "'lmdeploy' or 'hf'.")

    protocol = Internlm2Protocol(meta_prompt=None,
                                 interpreter_prompt=DEFAULT_PROMPT)
    executor = ActionExecutor(actions=[get_tool('IPythonInteractive')])
    agent = Internlm2Agent(llm=model,
                           protocol=protocol,
                           interpreter_executor=executor)
    return agent
@ -533,8 +521,8 @@ def predict(args):
d['pred'], d['steps'], d['error'] = None, [], [] d['pred'], d['steps'], d['error'] = None, [], []
return d return d
dataset = load_dataset( dataset = load_dataset('gsm8k', 'main',
'gsm8k', 'main', split='test').map(process, True) split='test').map(process, True)
elif args.dataset == 'math': elif args.dataset == 'math':
@ -554,8 +542,8 @@ def predict(args):
d['error'] = None d['error'] = None
return d return d
dataset = load_dataset( dataset = load_dataset('lighteval/MATH',
"lighteval/MATH", split='test').map(process, True) split='test').map(process, True)
else: else:
raise NotImplementedError raise NotImplementedError
@ -570,15 +558,11 @@ def predict(args):
top_k=args.top_k, top_k=args.top_k,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
) )
gen_kwargs = {
'max_new_tokens': args.max_tokens
} if args.backend == 'hf' else {}
with jsonlines.open(args.output_path, 'w') as f: with jsonlines.open(args.output_path, 'w') as f:
for item in tqdm( for item in tqdm(
dataset if not args.debug else dataset.select(range(50))): dataset if not args.debug else dataset.select(range(50))):
try: try:
ret = agent.chat(item['query'], **gen_kwargs) ret = agent.chat(item['query'])
item['steps'] = ret.inner_steps item['steps'] = ret.inner_steps
lang = [ lang = [
@ -607,7 +591,7 @@ def evaluate(args):
timeout=20, timeout=20,
) )
iterator = future.result() iterator = future.result()
with tqdm(total=len(samples), desc="Evaluate") as progress_bar: with tqdm(total=len(samples), desc='Evaluate') as progress_bar:
while True: while True:
try: try:
result = next(iterator) result = next(iterator)
@ -641,15 +625,15 @@ def evaluate(args):
col_means = np.array(score_mat).mean(axis=0) col_means = np.array(score_mat).mean(axis=0)
mean_score = list(np.round(col_means * 100, decimals=1)) mean_score = list(np.round(col_means * 100, decimals=1))
result_str = f"Num samples: {len(samples)}\n" \ result_str = f'Num samples: {len(samples)}\n' \
f"Num scores: {len(scores)}\n" \ f'Num scores: {len(scores)}\n' \
f"Sum scores: {sum(scores)}\n" \ f'Sum scores: {sum(scores)}\n' \
f"Timeout samples: {timeout_cnt}\n" \ f'Timeout samples: {timeout_cnt}\n' \
f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \ f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
f"Mean score: {mean_score}\n" f'Mean score: {mean_score}\n'
# each type score # each type score
if "type" in samples[0]: if 'type' in samples[0]:
type_scores = {} type_scores = {}
for sample in samples: for sample in samples:
if sample['type'] not in type_scores: if sample['type'] not in type_scores:
@ -663,7 +647,7 @@ def evaluate(args):
k: v k: v
for k, v in sorted(type_scores.items(), key=lambda item: item[0]) for k, v in sorted(type_scores.items(), key=lambda item: item[0])
} }
result_str += f"Type scores: {type_scores}\n" result_str += f'Type scores: {type_scores}\n'
print(result_str) print(result_str)