From d474c10daede0c38d4a5307c38bf036aeed983c0 Mon Sep 17 00:00:00 2001 From: zhulin1 Date: Mon, 1 Jul 2024 17:30:17 +0800 Subject: [PATCH] update --- tests/test_hf_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py index 15c1c1e..566f3b6 100644 --- a/tests/test_hf_model.py +++ b/tests/test_hf_model.py @@ -219,8 +219,8 @@ class TestReward: # batch inference, get multiple scores at once scores = model.get_scores(tokenizer, [chat_1, chat_2]) print('scores: ', scores) - assert scores[0] > 0 - assert scores[1] < 0 + assert scores[0][0] > 0 + assert scores[1][0] < 0 # compare whether chat_1 is better than chat_2 compare_res = model.compare(tokenizer, chat_1, chat_2) @@ -313,6 +313,7 @@ class TestReward: # print the best response best_response = sorted_candidates[0][1][-1]['content'] + print(sorted_candidates) print(best_response) assert len(sorted_candidates) == 3