mirror of https://github.com/hpcaitech/ColossalAI
update reward fn
parent 9d9d51614e
commit 754b16dfbf
@@ -5,7 +5,9 @@ from .reward_utils import extract_solution, validate_response_structure
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    reward = torch.tensor(0.0).to(input_ids.device)
+    reward = torch.tensor(0.0)
+    format_reward = torch.tensor(0.0)
+    acc_reward = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]
     if gt_answer is None:
         return reward
@@ -15,13 +17,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not format_valid:
-        return reward
-    else:
+
+    # Check format accuracy
+    if format_valid:
+        format_reward += 1.0
         reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 2.0
-    return reward
+
+    # Check answer accuracy
+    if (
+        final_answer is not None
+        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+    ):
+        acc_reward += 5.0
+        reward += 5.0
+
+    return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
 
 
 def gsm8k_reward_fn(input_ids, **kwargs):
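For quick reference, here is a standalone sketch of the new scoring scheme; the `demo_reward` helper below is illustrative only and is not part of the repository. It distills the change: format validity and answer correctness are now scored independently and returned as a [total, format_reward, acc_reward] vector, so a response whose answer matches can still earn the +5 accuracy reward even when its format check fails, whereas the old code returned 0 early in that case.

import torch

def demo_reward(format_valid: bool, answer_correct: bool) -> torch.Tensor:
    # Mirrors the scoring in the updated math_reward_fn:
    # +1 for a well-formed response, +5 for a matching answer,
    # returned as [total, format_reward, acc_reward].
    reward = torch.tensor(0.0)
    format_reward = torch.tensor(0.0)
    acc_reward = torch.tensor(0.0)
    if format_valid:
        format_reward += 1.0
        reward += 1.0
    if answer_correct:
        acc_reward += 5.0
        reward += 5.0
    return torch.tensor([reward, format_reward, acc_reward])

print(demo_reward(True, True))    # tensor([6., 1., 5.])
print(demo_reward(True, False))   # tensor([1., 1., 0.])
print(demo_reward(False, True))   # tensor([5., 0., 5.])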