mirror of https://github.com/InternLM/InternLM
commit a99d681d63 (parent f36805270c): update
@@ -160,6 +160,7 @@ class TestMath:
         assert_model(response)
         assert '2' in response
 
+
 class TestReward:
     """Test cases for reward model."""
 
@@ -181,46 +182,57 @@ class TestReward:
         tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                   trust_remote_code=True,
                                                   use_fast=usefast)
-        model = AutoModel.from_pretrained(model_name, device_map="cuda",
-                                          torch_dtype=torch.float16,
-                                          trust_remote_code=True,)
+        model = AutoModel.from_pretrained(
+            model_name,
+            device_map='cuda',
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+        )
         tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                   trust_remote_code=True)
 
-        chat_1 = [
-            {"role": "user", "content": "Hello! What's your name?"},
-            {"role": "assistant", "content": "My name is InternLM2! A helpful AI assistant. What can I do for you?"}
-        ]
-        chat_2 = [
-            {"role": "user", "content": "Hello! What's your name?"},
-            {"role": "assistant", "content": "I have no idea."}
-        ]
+        chat_1 = [{
+            'role': 'user',
+            'content': "Hello! What's your name?"
+        }, {
+            'role':
+            'assistant',
+            'content':
+            'I am InternLM2! A helpful AI assistant. What can I do for you?'
+        }]
+        chat_2 = [{
+            'role': 'user',
+            'content': "Hello! What's your name?"
+        }, {
+            'role': 'assistant',
+            'content': 'I have no idea.'
+        }]
 
         # get reward score for a single chat
         score1 = model.get_score(tokenizer, chat_1)
         score2 = model.get_score(tokenizer, chat_2)
-        print("score1: ", score1)
-        print("score2: ", score2)
-        assert score1 > 0.5 && score1 < 1 && score2 < 0
+        print('score1: ', score1)
+        print('score2: ', score2)
+        assert score1 > 0.5 and score1 < 1 and score2 < 0
 
         # batch inference, get multiple scores at once
         scores = model.get_scores(tokenizer, [chat_1, chat_2])
-        print("scores: ", scores)
-        assert scores[0] > 0.5 && scores[0] < 1 && scores[1] < 0
+        print('scores: ', scores)
+        assert scores[0] > 0.5 and scores[0] < 1 and scores[1] < 0
 
         # compare whether chat_1 is better than chat_2
         compare_res = model.compare(tokenizer, chat_1, chat_2)
-        print("compare_res: ", compare_res)
+        print('compare_res: ', compare_res)
         assert compare_res
         # >>> compare_res: True
 
         # rank multiple chats; it will return the ranking index of each chat
         # the chat with the highest score will have ranking index 0
         rank_res = model.rank(tokenizer, [chat_1, chat_2])
-        print("rank_res: ", rank_res)  # lower index means higher score
+        print('rank_res: ', rank_res)  # lower index means higher score
         # >>> rank_res: [0, 1]
-        assert rank_res[0] == 0 && rank_res[1] == 1
-        )
+        assert rank_res[0] == 0 and rank_res[1] == 1
 
 
 class TestMMModel:
     """Test cases for base model."""
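The updated assertions chain their conditions with "and": Python has no "&&" operator (it is a syntax error), and "&" is the bitwise AND, which binds tighter than comparison operators, so a condition written with "&" mis-parses. A minimal demonstration of the pitfall:

# "score1 > 0.5 & score1 < 1" parses as "score1 > (0.5 & score1) < 1",
# because & binds tighter than comparisons; for floats that raises TypeError.
score1 = 0.8
try:
    assert score1 > 0.5 & score1 < 1
except TypeError as err:
    print(err)  # unsupported operand type(s) for &: 'float' and 'float'

# Correct forms: boolean "and", or an equivalent chained comparison.
assert score1 > 0.5 and score1 < 1
assert 0.5 < score1 < 1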
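For reference, the reward-model API exercised above (get_score, get_scores, compare, rank) can be run outside the test roughly as follows. This is a minimal sketch: the checkpoint name is an assumed placeholder (the test's model_name value is not shown in the diff), and the get_* / compare / rank methods come from the checkpoint's trust_remote_code implementation.

import torch
from transformers import AutoModel, AutoTokenizer

model_name = 'internlm/internlm2-7b-reward'  # assumption: placeholder checkpoint
model = AutoModel.from_pretrained(
    model_name,
    device_map='cuda',
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

chat_good = [
    {'role': 'user', 'content': "Hello! What's your name?"},
    {'role': 'assistant', 'content': 'I am InternLM2! A helpful AI assistant.'},
]
chat_bad = [
    {'role': 'user', 'content': "Hello! What's your name?"},
    {'role': 'assistant', 'content': 'I have no idea.'},
]

score = model.get_score(tokenizer, chat_good)                # one chat -> one score
scores = model.get_scores(tokenizer, [chat_good, chat_bad])  # batched scoring
assert model.compare(tokenizer, chat_good, chat_bad)         # first chat should win
ranks = model.rank(tokenizer, [chat_good, chat_bad])         # index 0 = best chat
print(score, scores, ranks)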