pull/717/head
zhulin1 2024-02-29 11:33:52 +08:00
parent 3002cbd265
commit 4a61dc799e
1 changed file with 56 additions and 7 deletions


@@ -1,5 +1,7 @@
 import pytest
 import torch
+from auto_gptq.modeling import BaseGPTQForCausalLM
+from lmdeploy import TurbomindEngineConfig, pipeline
 from PIL import Image
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
@@ -20,6 +22,7 @@ class TestChat:
         'model_name',
         [
             'internlm/internlm2-chat-7b', 'internlm/internlm2-chat-7b-sft',
+            'internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-sft',
             'internlm/internlm2-chat-1_8b', 'internlm/internlm2-chat-1_8b-sft'
         ],
     )
@@ -57,6 +60,23 @@ class TestChat:
         assert_model(response)
 
 
+class TestChatAwq:
+    """Test cases for awq quantized chat model."""
+
+    @pytest.mark.parametrize(
+        'model_name',
+        ['internlm/internlm2-chat-20b-4bits'],
+    )
+    def test_demo_default(self, model_name):
+        engine_config = TurbomindEngineConfig(model_format='awq')
+        pipe = pipeline(model_name,
+                        backend_config=engine_config)
+        responses = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+        print(responses)
+        for response in responses:
+            assert_model(response.text)
+
+
 class TestBase:
     """Test cases for base model."""
@@ -64,6 +84,7 @@ class TestBase:
         'model_name',
         [
             'internlm/internlm2-7b', 'internlm/internlm2-base-7b',
+            'internlm/internlm2-20b', 'internlm/internlm2-base-20b',
             'internlm/internlm2-1_8b'
         ],
     )
@@ -142,6 +163,7 @@ class TestMMModel:
         'model_name',
         [
             'internlm/internlm-xcomposer2-7b',
+            'internlm/internlm-xcomposer2-7b-4bit'
         ],
     )
     def test_demo_default(self, model_name):
@@ -150,12 +172,16 @@ class TestMMModel:
         # Set `torch_dtype=torch.float16` to load model in float16, otherwise
         # it will be loaded as float32 and might cause OOM Error.
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.float32,
-            trust_remote_code=True).cuda()
+        if '4bit' in model_name:
+            model = InternLMXComposer2QForCausalLM.from_quantized(
+                model_name, trust_remote_code=True, device='cuda:0').eval()
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name, torch_dtype=torch.float16,
+                trust_remote_code=True).cuda()
         tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                   trust_remote_code=True)
         model = model.eval()
 
         img_path_list = [
             'tests/panda.jpg',
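The branch added above dispatches on the checkpoint name: '4bit' models go through the auto_gptq wrapper defined at the end of this file, everything else through transformers in float16. A minimal sketch of that dispatch pulled out into a helper; load_model is an illustrative name, not part of this commit:

import torch
from transformers import AutoModelForCausalLM

def load_model(model_name: str):
    # sketch of the test's loading logic; assumes the
    # InternLMXComposer2QForCausalLM wrapper defined later in this file
    if '4bit' in model_name:
        # GPTQ-style 4-bit checkpoints need auto_gptq's loader
        return InternLMXComposer2QForCausalLM.from_quantized(
            model_name, trust_remote_code=True, device='cuda:0').eval()
    # full-precision checkpoints: load in float16 to avoid OOM
    return AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16,
        trust_remote_code=True).cuda().eval()

(For reference, the Chinese prompt in the next hunk asks the model to write an essay based on the images, "My favorite animal," choosing its own angle, thesis, genre, and title.)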
@@ -177,7 +203,7 @@ class TestMMModel:
                                do_sample=False)
         print(response)
         assert len(response) != 0
-        assert 'panda' in response
+        assert ' panda ' in response
 
         query = '<ImageHere> <ImageHere>请根据图片写一篇作文:我最喜欢的小动物。' \
             + '要求:选准角度,确定立意,明确文体,自拟标题。'
@@ -199,6 +225,7 @@ class TestMMVlModel:
         'model_name',
         [
             'internlm/internlm-xcomposer2-vl-7b',
+            'internlm/internlm-xcomposer2-vl-7b-4bit'
         ],
     )
     def test_demo_default(self, model_name):
@@ -208,8 +235,13 @@ class TestMMVlModel:
         torch.set_grad_enabled(False)
 
         # init model and tokenizer
-        model = AutoModel.from_pretrained(
-            model_name, trust_remote_code=True).cuda().eval()
+        if '4bit' in model_name:
+            model = InternLMXComposer2QForCausalLM.from_quantized(
+                model_name, trust_remote_code=True, device='cuda:0').eval()
+        else:
+            model = AutoModel.from_pretrained(
+                model_name, trust_remote_code=True).cuda().eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                   trust_remote_code=True)
@@ -225,3 +257,20 @@ class TestMMVlModel:
         assert len(response) != 0
         assert 'Oscar Wilde' in response
         assert 'Live life with no excuses, travel with no regret' in response
+
+
+class InternLMXComposer2QForCausalLM(BaseGPTQForCausalLM):
+    layers_block_name = 'model.layers'
+    outside_layer_modules = [
+        'vit',
+        'vision_proj',
+        'model.tok_embeddings',
+        'model.norm',
+        'output',
+    ]
+    inside_layer_modules = [
+        ['attention.wqkv.linear'],
+        ['attention.wo.linear'],
+        ['feed_forward.w1.linear', 'feed_forward.w3.linear'],
+        ['feed_forward.w2.linear'],
+    ]
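The wrapper class maps auto_gptq's generic GPTQ machinery onto InternLM-XComposer2's module layout: layers_block_name points at the stack of repeated transformer blocks, outside_layer_modules lists the parts left unquantized (vision tower, projector, embeddings, final norm, output head), and inside_layer_modules names the linear layers inside each block that hold the 4-bit weights, grouped in quantization order. A minimal usage sketch follows, assuming the published 4-bit checkpoint exposes the same custom chat() method as the full-precision XComposer2 models:

import torch
from transformers import AutoTokenizer

# a minimal sketch; reuses the test image shipped with this repo
model = InternLMXComposer2QForCausalLM.from_quantized(
    'internlm/internlm-xcomposer2-vl-7b-4bit',
    trust_remote_code=True, device='cuda:0').eval()
tokenizer = AutoTokenizer.from_pretrained(
    'internlm/internlm-xcomposer2-vl-7b-4bit', trust_remote_code=True)
query = '<ImageHere>Please describe this image in detail.'
with torch.cuda.amp.autocast():
    response, _ = model.chat(tokenizer, query=query,
                             image='tests/panda.jpg', history=[],
                             do_sample=False)
print(response)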