diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py index 566f3b6..5429d3b 100644 --- a/tests/test_hf_model.py +++ b/tests/test_hf_model.py @@ -219,8 +219,8 @@ class TestReward: # batch inference, get multiple scores at once scores = model.get_scores(tokenizer, [chat_1, chat_2]) print('scores: ', scores) - assert scores[0][0] > 0 - assert scores[1][0] < 0 + assert scores[0] > 0 + assert scores[1] < 0 # compare whether chat_1 is better than chat_2 compare_res = model.compare(tokenizer, chat_1, chat_2)