diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 9e6d1066e..da19c7d22 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -5,7 +5,9 @@ from .reward_utils import extract_solution, validate_response_structure
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    reward = torch.tensor(0.0).to(input_ids.device)
+    reward = torch.tensor(0.0)
+    format_reward = torch.tensor(0.0)
+    acc_reward = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]
     if gt_answer is None:
         return reward
@@ -15,13 +17,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not format_valid:
-        return reward
-    else:
+
+    # Check format accuracy
+    if format_valid:
+        format_reward += 1.0
         reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 2.0
-        return reward
+
+    # Check answer accuracy
+    if (
+        final_answer is not None
+        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+    ):
+        acc_reward += 5.0
+        reward += 5.0
+
+    return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
 
 
 def gsm8k_reward_fn(input_ids, **kwargs):
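
For context, the patch changes math_reward_fn from returning a scalar reward to returning a 3-element tensor of [total reward, format reward, accuracy reward]. The sketch below shows how a caller might unpack that tensor; the split_reward helper is illustrative only and not part of the repository.

import torch

def split_reward(reward_vec: torch.Tensor):
    # Unpack the 3-element reward vector produced by the patched math_reward_fn:
    # index 0 = total reward, 1 = format reward, 2 = accuracy reward.
    total, fmt, acc = reward_vec.unbind(dim=-1)
    return total, fmt, acc

# Example: a response with a valid format (+1.0) and a correct answer (+5.0)
# yields a total reward of 6.0.
example = torch.tensor([6.0, 1.0, 5.0])
total, fmt, acc = split_reward(example)
print(total.item(), fmt.item(), acc.item())  # 6.0 1.0 5.0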