mirror of https://github.com/InternLM/InternLM
remove chat args
parent
b6f669b1ca
commit
b81132283d
|
@ -35,15 +35,9 @@ from typing import Union
|
||||||
import jsonlines
|
import jsonlines
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from lagent import (
|
from lagent import (INTERNLM2_META, ActionExecutor, HFTransformer,
|
||||||
INTERNLM2_META,
|
Internlm2Agent, Internlm2Protocol, LMDeployPipeline,
|
||||||
ActionExecutor,
|
get_tool)
|
||||||
HFTransformer,
|
|
||||||
Internlm2Agent,
|
|
||||||
Internlm2Protocol,
|
|
||||||
LMDeployPipeline,
|
|
||||||
get_tool,
|
|
||||||
)
|
|
||||||
from pebble import ProcessPool
|
from pebble import ProcessPool
|
||||||
from sympy import N, simplify
|
from sympy import N, simplify
|
||||||
from sympy.parsing.latex import parse_latex
|
from sympy.parsing.latex import parse_latex
|
||||||
|
@ -60,12 +54,11 @@ DEFAULT_PROMPT = (
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument('--backend',
|
||||||
'--backend',
|
type=str,
|
||||||
type=str,
|
default='lmdeploy',
|
||||||
default='lmdeploy',
|
help='Which inference framework to use.',
|
||||||
help='Which inference framework to use.',
|
choices=['lmdeploy', 'hf'])
|
||||||
choices=['lmdeploy', 'hf'])
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--model_path',
|
'--model_path',
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -77,43 +70,37 @@ def parse_args():
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help='Path to save inference results to, should be a `jsonl` file')
|
help='Path to save inference results to, should be a `jsonl` file')
|
||||||
parser.add_argument(
|
parser.add_argument('--dataset',
|
||||||
'--dataset',
|
type=str,
|
||||||
type=str,
|
default='math',
|
||||||
default='math',
|
choices=['gsm8k', 'math'],
|
||||||
choices=['gsm8k', 'math'],
|
help='Dataset for inference')
|
||||||
help='Dataset for inference')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--tp',
|
'--tp',
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help='Number of tensor parallelism. It may be required in LMDelpoy.')
|
help='Number of tensor parallelism. It may be required in LMDelpoy.')
|
||||||
parser.add_argument(
|
parser.add_argument('--temperature',
|
||||||
'--temperature',
|
type=float,
|
||||||
type=float,
|
default=0.1,
|
||||||
default=0.1,
|
help='Temperature in next token prediction')
|
||||||
help='Temperature in next token prediction')
|
parser.add_argument('--top_p',
|
||||||
parser.add_argument(
|
type=float,
|
||||||
'--top_p',
|
default=0.8,
|
||||||
type=float,
|
help='Parameter for Top-P Sampling.')
|
||||||
default=0.8,
|
parser.add_argument('--top_k',
|
||||||
help='Parameter for Top-P Sampling.')
|
type=int,
|
||||||
parser.add_argument(
|
default=None,
|
||||||
'--top_k',
|
help='Parameter for Top-K Sampling.')
|
||||||
type=int,
|
parser.add_argument('--stop_words',
|
||||||
default=None,
|
type=list,
|
||||||
help='Parameter for Top-K Sampling.')
|
default=['<|action_end|>', '<|im_end|>'],
|
||||||
parser.add_argument(
|
action='append',
|
||||||
'--stop_words',
|
help='Stop words')
|
||||||
type=list,
|
parser.add_argument('--max_tokens',
|
||||||
default=['<|action_end|>', '<|im_end|>'],
|
type=int,
|
||||||
action='append',
|
default=512,
|
||||||
help='Stop words')
|
help='Number of maximum generated tokens.')
|
||||||
parser.add_argument(
|
|
||||||
'--max_tokens',
|
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help='Number of maximum generated tokens.')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--do_infer',
|
'--do_infer',
|
||||||
default=True,
|
default=True,
|
||||||
|
@ -125,32 +112,29 @@ def parse_args():
|
||||||
# action='store_false',
|
# action='store_false',
|
||||||
# help='Disable the inference.'
|
# help='Disable the inference.'
|
||||||
# )
|
# )
|
||||||
parser.add_argument(
|
parser.add_argument('--do_eval',
|
||||||
'--do_eval',
|
default=False,
|
||||||
default=False,
|
action='store_true',
|
||||||
action='store_true',
|
help='Whether to evaluate the inference results.')
|
||||||
help='Whether to evaluate the inference results.')
|
parser.add_argument('--overwrite',
|
||||||
parser.add_argument(
|
default=False,
|
||||||
'--overwrite',
|
action='store_true',
|
||||||
default=False,
|
help='Whether to overwrite the existing result file')
|
||||||
action='store_true',
|
parser.add_argument('--debug',
|
||||||
help='Whether to overwrite the existing result file')
|
default=False,
|
||||||
parser.add_argument(
|
action='store_true',
|
||||||
'--debug',
|
help='Only infer the first 50 samples')
|
||||||
default=False,
|
|
||||||
action='store_true',
|
|
||||||
help='Only infer the first 50 samples')
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def _fix_fracs(string):
|
def _fix_fracs(string):
|
||||||
substrs = string.split("\\frac")
|
substrs = string.split('\\frac')
|
||||||
new_str = substrs[0]
|
new_str = substrs[0]
|
||||||
if len(substrs) > 1:
|
if len(substrs) > 1:
|
||||||
substrs = substrs[1:]
|
substrs = substrs[1:]
|
||||||
for substr in substrs:
|
for substr in substrs:
|
||||||
new_str += "\\frac"
|
new_str += '\\frac'
|
||||||
if len(substr) > 0 and substr[0] == "{":
|
if len(substr) > 0 and substr[0] == '{':
|
||||||
new_str += substr
|
new_str += substr
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
@ -159,135 +143,135 @@ def _fix_fracs(string):
|
||||||
return string
|
return string
|
||||||
a = substr[0]
|
a = substr[0]
|
||||||
b = substr[1]
|
b = substr[1]
|
||||||
if b != "{":
|
if b != '{':
|
||||||
if len(substr) > 2:
|
if len(substr) > 2:
|
||||||
post_substr = substr[2:]
|
post_substr = substr[2:]
|
||||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
new_str += '{' + a + '}{' + b + '}' + post_substr
|
||||||
else:
|
else:
|
||||||
new_str += "{" + a + "}{" + b + "}"
|
new_str += '{' + a + '}{' + b + '}'
|
||||||
else:
|
else:
|
||||||
if len(substr) > 2:
|
if len(substr) > 2:
|
||||||
post_substr = substr[2:]
|
post_substr = substr[2:]
|
||||||
new_str += "{" + a + "}" + b + post_substr
|
new_str += '{' + a + '}' + b + post_substr
|
||||||
else:
|
else:
|
||||||
new_str += "{" + a + "}" + b
|
new_str += '{' + a + '}' + b
|
||||||
string = new_str
|
string = new_str
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def _fix_a_slash_b(string):
|
def _fix_a_slash_b(string):
|
||||||
if len(string.split("/")) != 2:
|
if len(string.split('/')) != 2:
|
||||||
return string
|
return string
|
||||||
a = string.split("/")[0]
|
a = string.split('/')[0]
|
||||||
b = string.split("/")[1]
|
b = string.split('/')[1]
|
||||||
try:
|
try:
|
||||||
if "sqrt" not in a:
|
if 'sqrt' not in a:
|
||||||
a = int(a)
|
a = int(a)
|
||||||
if "sqrt" not in b:
|
if 'sqrt' not in b:
|
||||||
b = int(b)
|
b = int(b)
|
||||||
assert string == "{}/{}".format(a, b)
|
assert string == '{}/{}'.format(a, b)
|
||||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
|
||||||
return new_string
|
return new_string
|
||||||
except Exception:
|
except Exception:
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def _fix_sqrt(string):
|
def _fix_sqrt(string):
|
||||||
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
_string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
|
||||||
return _string
|
return _string
|
||||||
|
|
||||||
|
|
||||||
def strip_string(string):
|
def strip_string(string):
|
||||||
string = str(string).strip()
|
string = str(string).strip()
|
||||||
# linebreaks
|
# linebreaks
|
||||||
string = string.replace("\n", "")
|
string = string.replace('\n', '')
|
||||||
|
|
||||||
# right "."
|
# right "."
|
||||||
string = string.rstrip(".")
|
string = string.rstrip('.')
|
||||||
|
|
||||||
# remove inverse spaces
|
# remove inverse spaces
|
||||||
string = string.replace("\\!", "")
|
string = string.replace('\\!', '')
|
||||||
string = string.replace("\\ ", "")
|
string = string.replace('\\ ', '')
|
||||||
|
|
||||||
# replace \\ with \
|
# replace \\ with \
|
||||||
string = string.replace("\\\\", "\\")
|
string = string.replace('\\\\', '\\')
|
||||||
string = string.replace("\\\\", "\\")
|
string = string.replace('\\\\', '\\')
|
||||||
|
|
||||||
# replace tfrac and dfrac with frac
|
# replace tfrac and dfrac with frac
|
||||||
string = string.replace("tfrac", "frac")
|
string = string.replace('tfrac', 'frac')
|
||||||
string = string.replace("dfrac", "frac")
|
string = string.replace('dfrac', 'frac')
|
||||||
|
|
||||||
# remove \left and \right
|
# remove \left and \right
|
||||||
string = string.replace("\\left", "")
|
string = string.replace('\\left', '')
|
||||||
string = string.replace("\\right", "")
|
string = string.replace('\\right', '')
|
||||||
|
|
||||||
# Remove unit: miles, dollars if after is not none
|
# Remove unit: miles, dollars if after is not none
|
||||||
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
_string = re.sub(r'\\text{.*?}$', '', string).strip()
|
||||||
if _string != "" and _string != string:
|
if _string != '' and _string != string:
|
||||||
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||||
string = _string
|
string = _string
|
||||||
|
|
||||||
# Remove circ (degrees)
|
# Remove circ (degrees)
|
||||||
string = string.replace("^{\\circ}", "")
|
string = string.replace('^{\\circ}', '')
|
||||||
string = string.replace("^\\circ", "")
|
string = string.replace('^\\circ', '')
|
||||||
|
|
||||||
# remove dollar signs
|
# remove dollar signs
|
||||||
string = string.replace("\\$", "")
|
string = string.replace('\\$', '')
|
||||||
string = string.replace("$", "")
|
string = string.replace('$', '')
|
||||||
|
|
||||||
string = string.replace("\\text", "")
|
string = string.replace('\\text', '')
|
||||||
string = string.replace("x\\in", "")
|
string = string.replace('x\\in', '')
|
||||||
|
|
||||||
# remove percentage
|
# remove percentage
|
||||||
string = string.replace("\\%", "")
|
string = string.replace('\\%', '')
|
||||||
string = string.replace("\%", "")
|
string = string.replace('\%', '')
|
||||||
string = string.replace("%", "")
|
string = string.replace('%', '')
|
||||||
|
|
||||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||||
string = string.replace(" .", " 0.")
|
string = string.replace(' .', ' 0.')
|
||||||
string = string.replace("{.", "{0.")
|
string = string.replace('{.', '{0.')
|
||||||
|
|
||||||
# cdot
|
# cdot
|
||||||
string = string.replace("\\cdot", "")
|
string = string.replace('\\cdot', '')
|
||||||
|
|
||||||
# inf
|
# inf
|
||||||
string = string.replace("infinity", "\\infty")
|
string = string.replace('infinity', '\\infty')
|
||||||
if "\\infty" not in string:
|
if '\\infty' not in string:
|
||||||
string = string.replace("inf", "\\infty")
|
string = string.replace('inf', '\\infty')
|
||||||
string = string.replace("+\\inity", "\\infty")
|
string = string.replace('+\\inity', '\\infty')
|
||||||
|
|
||||||
# and
|
# and
|
||||||
string = string.replace("and", "")
|
string = string.replace('and', '')
|
||||||
string = string.replace("\\mathbf", "")
|
string = string.replace('\\mathbf', '')
|
||||||
|
|
||||||
# use regex to remove \mbox{...}
|
# use regex to remove \mbox{...}
|
||||||
string = re.sub(r"\\mbox{.*?}", "", string)
|
string = re.sub(r'\\mbox{.*?}', '', string)
|
||||||
|
|
||||||
# quote
|
# quote
|
||||||
string.replace("'", "")
|
string.replace("'", '')
|
||||||
string.replace('"', "")
|
string.replace('"', '')
|
||||||
|
|
||||||
# i, j
|
# i, j
|
||||||
if "j" in string and "i" not in string:
|
if 'j' in string and 'i' not in string:
|
||||||
string = string.replace("j", "i")
|
string = string.replace('j', 'i')
|
||||||
|
|
||||||
# replace a.000b where b is not number or b is end, with ab, use regex
|
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||||
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
|
string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
|
||||||
string = re.sub(r"(\d+)\.0+$", r"\1", string)
|
string = re.sub(r'(\d+)\.0+$', r'\1', string)
|
||||||
|
|
||||||
# if empty, return empty string
|
# if empty, return empty string
|
||||||
if len(string) == 0:
|
if len(string) == 0:
|
||||||
return string
|
return string
|
||||||
if string[0] == ".":
|
if string[0] == '.':
|
||||||
string = "0" + string
|
string = '0' + string
|
||||||
|
|
||||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||||
if len(string.split("=")) == 2:
|
if len(string.split('=')) == 2:
|
||||||
if len(string.split("=")[0]) <= 2:
|
if len(string.split('=')[0]) <= 2:
|
||||||
string = string.split("=")[1]
|
string = string.split('=')[1]
|
||||||
|
|
||||||
string = _fix_sqrt(string)
|
string = _fix_sqrt(string)
|
||||||
string = string.replace(" ", "")
|
string = string.replace(' ', '')
|
||||||
|
|
||||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||||
string = _fix_fracs(string)
|
string = _fix_fracs(string)
|
||||||
|
@ -299,9 +283,9 @@ def strip_string(string):
|
||||||
|
|
||||||
|
|
||||||
def last_boxed_only_string(string):
|
def last_boxed_only_string(string):
|
||||||
idx = string.rfind("\\boxed")
|
idx = string.rfind('\\boxed')
|
||||||
if idx < 0:
|
if idx < 0:
|
||||||
idx = string.rfind("\\fbox")
|
idx = string.rfind('\\fbox')
|
||||||
if idx < 0:
|
if idx < 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -309,9 +293,9 @@ def last_boxed_only_string(string):
|
||||||
right_brace_idx = None
|
right_brace_idx = None
|
||||||
num_left_braces_open = 0
|
num_left_braces_open = 0
|
||||||
while i < len(string):
|
while i < len(string):
|
||||||
if string[i] == "{":
|
if string[i] == '{':
|
||||||
num_left_braces_open += 1
|
num_left_braces_open += 1
|
||||||
if string[i] == "}":
|
if string[i] == '}':
|
||||||
num_left_braces_open -= 1
|
num_left_braces_open -= 1
|
||||||
if num_left_braces_open == 0:
|
if num_left_braces_open == 0:
|
||||||
right_brace_idx = i
|
right_brace_idx = i
|
||||||
|
@ -329,13 +313,13 @@ def last_boxed_only_string(string):
|
||||||
def extract_answer(pred_str):
|
def extract_answer(pred_str):
|
||||||
if 'boxed' not in pred_str:
|
if 'boxed' not in pred_str:
|
||||||
return ''
|
return ''
|
||||||
ans = pred_str.split('boxed')[-1]
|
answer = pred_str.split('boxed')[-1]
|
||||||
if len(ans) == 0:
|
if len(answer) == 0:
|
||||||
return ''
|
return ''
|
||||||
elif (ans[0] == '{'):
|
elif (answer[0] == '{'):
|
||||||
stack = 1
|
stack = 1
|
||||||
a = ''
|
a = ''
|
||||||
for c in ans[1:]:
|
for c in answer[1:]:
|
||||||
if (c == '{'):
|
if (c == '{'):
|
||||||
stack += 1
|
stack += 1
|
||||||
a += c
|
a += c
|
||||||
|
@ -346,9 +330,9 @@ def extract_answer(pred_str):
|
||||||
else:
|
else:
|
||||||
a += c
|
a += c
|
||||||
else:
|
else:
|
||||||
a = ans.split('$')[0].strip()
|
a = answer.split('$')[0].strip()
|
||||||
|
|
||||||
pred = a.split("\n")[0]
|
pred = a.split('\n')[0]
|
||||||
if pred != '' and pred[0] == ':':
|
if pred != '' and pred[0] == ':':
|
||||||
pred = pred[1:]
|
pred = pred[1:]
|
||||||
if pred != '' and pred[-1] == '.':
|
if pred != '' and pred[-1] == '.':
|
||||||
|
@ -361,7 +345,7 @@ def extract_answer(pred_str):
|
||||||
|
|
||||||
def is_digit(s):
|
def is_digit(s):
|
||||||
try:
|
try:
|
||||||
float(str(s).replace(",", ""))
|
float(str(s).replace(',', ''))
|
||||||
return True
|
return True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
@ -375,15 +359,15 @@ def math_equal(
|
||||||
tolerance: float = 1e-4,
|
tolerance: float = 1e-4,
|
||||||
timeout: bool = False,
|
timeout: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""Exact match of math if and only if:
|
||||||
Exact match of math if and only if:
|
|
||||||
1. numerical equal: both can convert to float and are equal
|
1. numerical equal: both can convert to float and are equal
|
||||||
2. symbolic equal: both can convert to sympy expression and are equal
|
2. symbolic equal: both can convert to sympy expression and are equal
|
||||||
"""
|
"""
|
||||||
try: # 1. numerical equal
|
try: # 1. numerical equal
|
||||||
if is_digit(prediction) and is_digit(reference):
|
if is_digit(prediction) and is_digit(reference):
|
||||||
prediction = float(str(prediction).replace(",", ""))
|
prediction = float(str(prediction).replace(',', ''))
|
||||||
reference = float(str(reference).replace(",", ""))
|
reference = float(str(reference).replace(',', ''))
|
||||||
# number questions
|
# number questions
|
||||||
if include_percentage:
|
if include_percentage:
|
||||||
gt_result = [reference / 100, reference, reference * 100]
|
gt_result = [reference / 100, reference, reference * 100]
|
||||||
|
@ -412,25 +396,25 @@ def math_equal(
|
||||||
|
|
||||||
## deal with [], (), {}
|
## deal with [], (), {}
|
||||||
pred_str, ref_str = prediction, reference
|
pred_str, ref_str = prediction, reference
|
||||||
if (prediction.startswith("[") and prediction.endswith("]")
|
if (prediction.startswith('[') and prediction.endswith(']')
|
||||||
and not reference.startswith("(")) or (
|
and not reference.startswith('(')) or (
|
||||||
prediction.startswith("(") and prediction.endswith(")")
|
prediction.startswith('(') and prediction.endswith(')')
|
||||||
and not reference.startswith("[")):
|
and not reference.startswith('[')):
|
||||||
pred_str = pred_str.strip("[]()")
|
pred_str = pred_str.strip('[]()')
|
||||||
ref_str = ref_str.strip("[]()")
|
ref_str = ref_str.strip('[]()')
|
||||||
for s in ["{", "}", "(", ")"]:
|
for s in ['{', '}', '(', ')']:
|
||||||
ref_str = ref_str.replace(s, "")
|
ref_str = ref_str.replace(s, '')
|
||||||
pred_str = pred_str.replace(s, "")
|
pred_str = pred_str.replace(s, '')
|
||||||
if pred_str == ref_str:
|
if pred_str == ref_str:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
## [a, b] vs. [c, d], return a==c and b==d
|
## [a, b] vs. [c, d], return a==c and b==d
|
||||||
if ((prediction.startswith("[") and prediction.endswith("]")) and
|
if ((prediction.startswith('[') and prediction.endswith(']')) and
|
||||||
(reference.startswith("[") and reference.endswith("]"))
|
(reference.startswith('[') and reference.endswith(']'))
|
||||||
or (prediction.startswith("(") and prediction.endswith(")")) and
|
or (prediction.startswith('(') and prediction.endswith(')')) and
|
||||||
(reference.startswith("(") and reference.endswith(")"))):
|
(reference.startswith('(') and reference.endswith(')'))):
|
||||||
pred_parts = prediction[1:-1].split(",")
|
pred_parts = prediction[1:-1].split(',')
|
||||||
ref_parts = reference[1:-1].split(",")
|
ref_parts = reference[1:-1].split(',')
|
||||||
if len(pred_parts) == len(ref_parts):
|
if len(pred_parts) == len(ref_parts):
|
||||||
if all([
|
if all([
|
||||||
math_equal(pred_parts[i], ref_parts[i], include_percentage,
|
math_equal(pred_parts[i], ref_parts[i], include_percentage,
|
||||||
|
@ -488,8 +472,9 @@ def symbolic_equal_process(a, b, output_queue):
|
||||||
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
||||||
output_queue = multiprocessing.Queue()
|
output_queue = multiprocessing.Queue()
|
||||||
process_args = args + (output_queue, )
|
process_args = args + (output_queue, )
|
||||||
process = multiprocessing.Process(
|
process = multiprocessing.Process(target=func,
|
||||||
target=func, args=process_args, kwargs=kwargs)
|
args=process_args,
|
||||||
|
kwargs=kwargs)
|
||||||
process.start()
|
process.start()
|
||||||
process.join(timeout)
|
process.join(timeout)
|
||||||
|
|
||||||
|
@ -501,22 +486,25 @@ def call_with_timeout(func, *args, timeout=1, **kwargs):
|
||||||
return output_queue.get()
|
return output_queue.get()
|
||||||
|
|
||||||
|
|
||||||
def init_agent(backend: str, model_path: str, tp=1, **kwargs):
|
def init_agent(backend: str, model_path: str, tp: int, **kwargs):
|
||||||
if backend == 'lmdeploy':
|
if backend == 'lmdeploy':
|
||||||
model = LMDeployPipeline(
|
model = LMDeployPipeline(path=model_path,
|
||||||
path=model_path, meta_template=INTERNLM2_META, tp=tp, **kwargs)
|
meta_template=INTERNLM2_META,
|
||||||
|
tp=tp,
|
||||||
|
**kwargs)
|
||||||
elif backend == 'hf':
|
elif backend == 'hf':
|
||||||
model = HFTransformer(
|
model = HFTransformer(path=model_path,
|
||||||
path=model_path, meta_template=INTERNLM2_META, **kwargs)
|
meta_template=INTERNLM2_META,
|
||||||
|
**kwargs)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
agent = Internlm2Agent(
|
agent = Internlm2Agent(llm=model,
|
||||||
llm=model,
|
protocol=Internlm2Protocol(
|
||||||
protocol=Internlm2Protocol(
|
meta_prompt=None,
|
||||||
meta_prompt=None, interpreter_prompt=DEFAULT_PROMPT),
|
interpreter_prompt=DEFAULT_PROMPT),
|
||||||
interpreter_executor=ActionExecutor(
|
interpreter_executor=ActionExecutor(
|
||||||
actions=[get_tool('IPythonInteractive')]))
|
actions=[get_tool('IPythonInteractive')]))
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@ -533,8 +521,8 @@ def predict(args):
|
||||||
d['pred'], d['steps'], d['error'] = None, [], []
|
d['pred'], d['steps'], d['error'] = None, [], []
|
||||||
return d
|
return d
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset('gsm8k', 'main',
|
||||||
'gsm8k', 'main', split='test').map(process, True)
|
split='test').map(process, True)
|
||||||
|
|
||||||
elif args.dataset == 'math':
|
elif args.dataset == 'math':
|
||||||
|
|
||||||
|
@ -554,8 +542,8 @@ def predict(args):
|
||||||
d['error'] = None
|
d['error'] = None
|
||||||
return d
|
return d
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset('lighteval/MATH',
|
||||||
"lighteval/MATH", split='test').map(process, True)
|
split='test').map(process, True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -570,15 +558,11 @@ def predict(args):
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
)
|
)
|
||||||
gen_kwargs = {
|
|
||||||
'max_new_tokens': args.max_tokens
|
|
||||||
} if args.backend == 'hf' else {}
|
|
||||||
|
|
||||||
with jsonlines.open(args.output_path, 'w') as f:
|
with jsonlines.open(args.output_path, 'w') as f:
|
||||||
for item in tqdm(
|
for item in tqdm(
|
||||||
dataset if not args.debug else dataset.select(range(50))):
|
dataset if not args.debug else dataset.select(range(50))):
|
||||||
try:
|
try:
|
||||||
ret = agent.chat(item['query'], **gen_kwargs)
|
ret = agent.chat(item['query'])
|
||||||
item['steps'] = ret.inner_steps
|
item['steps'] = ret.inner_steps
|
||||||
|
|
||||||
lang = [
|
lang = [
|
||||||
|
@ -607,7 +591,7 @@ def evaluate(args):
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
iterator = future.result()
|
iterator = future.result()
|
||||||
with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
|
with tqdm(total=len(samples), desc='Evaluate') as progress_bar:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
result = next(iterator)
|
result = next(iterator)
|
||||||
|
@ -641,15 +625,15 @@ def evaluate(args):
|
||||||
col_means = np.array(score_mat).mean(axis=0)
|
col_means = np.array(score_mat).mean(axis=0)
|
||||||
mean_score = list(np.round(col_means * 100, decimals=1))
|
mean_score = list(np.round(col_means * 100, decimals=1))
|
||||||
|
|
||||||
result_str = f"Num samples: {len(samples)}\n" \
|
result_str = f'Num samples: {len(samples)}\n' \
|
||||||
f"Num scores: {len(scores)}\n" \
|
f'Num scores: {len(scores)}\n' \
|
||||||
f"Sum scores: {sum(scores)}\n" \
|
f'Sum scores: {sum(scores)}\n' \
|
||||||
f"Timeout samples: {timeout_cnt}\n" \
|
f'Timeout samples: {timeout_cnt}\n' \
|
||||||
f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
|
f"Empty samples: {len([s for s in samples if not s['pred'][-1]])}\n" \
|
||||||
f"Mean score: {mean_score}\n"
|
f'Mean score: {mean_score}\n'
|
||||||
|
|
||||||
# each type score
|
# each type score
|
||||||
if "type" in samples[0]:
|
if 'type' in samples[0]:
|
||||||
type_scores = {}
|
type_scores = {}
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
if sample['type'] not in type_scores:
|
if sample['type'] not in type_scores:
|
||||||
|
@ -663,7 +647,7 @@ def evaluate(args):
|
||||||
k: v
|
k: v
|
||||||
for k, v in sorted(type_scores.items(), key=lambda item: item[0])
|
for k, v in sorted(type_scores.items(), key=lambda item: item[0])
|
||||||
}
|
}
|
||||||
result_str += f"Type scores: {type_scores}\n"
|
result_str += f'Type scores: {type_scores}\n'
|
||||||
|
|
||||||
print(result_str)
|
print(result_str)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue