diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py index debef03..15c1c1e 100644 --- a/tests/test_hf_model.py +++ b/tests/test_hf_model.py @@ -213,12 +213,14 @@ class TestReward: score2 = model.get_score(tokenizer, chat_2) print('score1: ', score1) print('score2: ', score2) - assert score1.startswith('0.') & score2.startswith('-') + assert score1 > 0 + assert score2 < 0 # batch inference, get multiple scores at once scores = model.get_scores(tokenizer, [chat_1, chat_2]) print('scores: ', scores) - assert scores[0].startswith('0.') & scores[1].startswith('-') + assert scores[0] > 0 + assert scores[1] < 0 # compare whether chat_1 is better than chat_2 compare_res = model.compare(tokenizer, chat_1, chat_2)