# InternLM/tests/test_hf_model.py

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Prompts sent to every model under test: a Chinese greeting ("Hello")
# and a short English question.
prompts = [
    "你好",
    "what's your name",
]
def assert_model(response):
    """Fail the calling test unless *response* looks like usable model output.

    The response must be non-empty and must not contain the literal
    substring ``"UNUSED_TOKEN"`` (presumably placeholder vocabulary entries
    that should never surface in decoded text -- confirm against the
    tokenizer's vocab).

    Raises:
        AssertionError: if either check fails.
    """
    is_empty = len(response) == 0
    assert not is_empty
    assert "UNUSED_TOKEN" not in response
class TestChat:
    """Smoke tests for the InternLM2 chat models loaded through Hugging Face.

    Each parametrized model is exercised through both its one-shot ``chat``
    method and its incremental ``stream_chat`` method.
    """

    @pytest.mark.parametrize(
        "model_name",
        [
            "internlm/internlm2-chat-7b",
            "internlm/internlm2-chat-7b-sft",
        ],
    )
    def test_demo_default(self, model_name):
        """Chat with *model_name* on every prompt, blocking and streaming.

        Needs a CUDA device (the model is moved there with ``.cuda()``) and
        the weights for ``model_name`` from the Hub or a local cache.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Set `torch_dtype=torch.float16` to load model in float16, otherwise
        # it will be loaded as float32 and might cause OOM Error.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, trust_remote_code=True
        ).cuda()
        model = model.eval()

        # One-shot chat: every prompt starts from an empty history.
        for prompt in prompts:
            response, _ = model.chat(tokenizer, prompt, history=[])
            print(response)
            assert_model(response)

        # Streaming chat: the loop treats each yielded ``response`` as the
        # cumulative text so far and prints only the newly added suffix.
        for prompt in prompts:
            # Reset per prompt: if the stream yields nothing, assert_model must
            # see an empty string rather than a stale response left over from
            # the previous prompt or from the one-shot loop above.
            response = ""
            length = 0
            for response, _ in model.stream_chat(tokenizer, prompt, history=[]):
                print(response[length:], flush=True, end="")
                length = len(response)
            print()  # terminate the streamed line before the next prompt
            assert_model(response)
class TestBase:
    """Smoke tests for the InternLM2 base (non-chat) models via ``generate``."""

    @pytest.mark.parametrize(
        "model_name",
        [
            "internlm/internlm2-7b",
            "internlm/internlm2-base-7b",
        ],
    )
    def test_demo_default(self, model_name):
        """Sample a continuation of every prompt with *model_name* on the GPU.

        Needs a CUDA device (the model and inputs are moved there with
        ``.cuda()``) and the weights for ``model_name`` from the Hub or a
        local cache.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Set `torch_dtype=torch.float16` to load model in float16, otherwise
        # it will be loaded as float32 and might cause OOM Error.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, trust_remote_code=True
        ).cuda()
        # Same as TestChat: make inference mode explicit.
        model = model.eval()

        # Sampling settings do not depend on the prompt, so build them once.
        gen_kwargs = {
            # Per the transformers GenerationConfig docs, max_length counts the
            # prompt tokens as well as the generated ones.
            "max_length": 128,
            # top_p is a cumulative probability and must lie in (0, 1]; the
            # previous value of 10 was out of range. 1.0 disables nucleus
            # filtering, matching the neutral temperature/repetition_penalty
            # below (transformers appears to skip top-p for values >= 1.0, so
            # this should not change sampling -- confirm for the pinned version).
            "top_p": 1.0,
            "temperature": 1.0,
            "do_sample": True,
            "repetition_penalty": 1.0,
        }
        for prompt in prompts:
            encoded = tokenizer(prompt, return_tensors="pt")
            # Move every input tensor to the GPU alongside the model.
            inputs = {key: tensor.cuda() for key, tensor in encoded.items()}
            output = model.generate(**inputs, **gen_kwargs)
            # NOTE: for a causal LM, generate() output begins with the prompt
            # ids, so the decoded text echoes the prompt before the continuation.
            output = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
            print(output)
            assert_model(output)