mirror of https://github.com/InternLM/InternLM
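"""Smoke tests for the InternLM2 chat and base models, loaded through the
Hugging Face `transformers` AutoTokenizer/AutoModelForCausalLM interfaces."""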
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

prompts = ['你好', "what's your name"]


def assert_model(response):
    # A valid response is non-empty and never leaks the reserved
    # placeholder token 'UNUSED_TOKEN'.
    assert len(response) != 0
    assert 'UNUSED_TOKEN' not in response


class TestChat:
    """Test cases for chat model."""

    @pytest.mark.parametrize(
        'model_name',
        [
            'internlm/internlm2-chat-7b',
            'internlm/internlm2-chat-7b-sft',
        ],
    )
    def test_demo_default(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                  trust_remote_code=True)
        # Set `torch_dtype=torch.float16` to load the model in float16;
        # otherwise it is loaded as float32 and may cause an out-of-memory
        # error.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16,
            trust_remote_code=True).cuda()
        model = model.eval()

        # Single-turn chat: each prompt starts from an empty history.
        for prompt in prompts:
            response, history = model.chat(tokenizer, prompt, history=[])
            print(response)
            assert_model(response)

        # Streaming chat: print only the newly generated suffix at each step.
        for prompt in prompts:
            length = 0
            for response, history in model.stream_chat(tokenizer,
                                                        prompt,
                                                        history=[]):
                print(response[length:], flush=True, end='')
                length = len(response)
            assert_model(response)


class TestBase:
    """Test cases for base model."""

    @pytest.mark.parametrize(
        'model_name',
        [
            'internlm/internlm2-7b',
            'internlm/internlm2-base-7b',
        ],
    )
    def test_demo_default(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                  trust_remote_code=True)
        # Set `torch_dtype=torch.float16` to load the model in float16;
        # otherwise it is loaded as float32 and may cause an out-of-memory
        # error.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16,
            trust_remote_code=True).cuda()
        for prompt in prompts:
            # Tokenize the prompt and move every input tensor onto the GPU.
            inputs = tokenizer(prompt, return_tensors='pt')
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            gen_kwargs = {
                'max_length': 128,
                'top_p': 10,
                'temperature': 1.0,
                'do_sample': True,
                'repetition_penalty': 1.0,
            }
            output = model.generate(**inputs, **gen_kwargs)
            output = tokenizer.decode(output[0].tolist(),
                                      skip_special_tokens=True)
            print(output)
            assert_model(output)
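

# Optional entry point, a minimal sketch for convenience: it assumes pytest is
# installed and a CUDA-capable GPU is available, and simply hands this file to
# the pytest runner (with output capturing disabled) when executed directly.
if __name__ == '__main__':
    raise SystemExit(pytest.main([__file__, '-s']))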