add more models into daily testcases (#717)

Co-authored-by: zhulin1 <zhulin1@pjlab.org.cn>
Branch: pull/721/head
Author: zhulinJulia24
Date: 2024-03-06 11:06:29 +08:00
Committed by: GitHub
Parent: bd57ff3ce7
Commit: 43b7582201
4 changed files with 60 additions and 9 deletions

File: GitHub Actions daily-test workflow (path not shown)

@@ -27,7 +27,7 @@ jobs:
         conda create -n internlm-model-latest --clone ${CONDA_BASE_ENV}
         source activate internlm-model-latest
         pip install transformers==${{ matrix.transformers-version }}
-        pip install sentencepiece auto-gptq
+        pip install sentencepiece auto-gptq==0.6.0 lmdeploy[all]
         srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --gpus-per-task=2 pytest -s -v --color=yes ./tests/test_hf_model.py
         conda deactivate
     - name: load_latest_hf_model
@@ -36,7 +36,7 @@ jobs:
         conda create -n internlm-model-latest --clone ${CONDA_BASE_ENV}
         source activate internlm-model-latest
         pip install transformers
-        pip install sentencepiece auto-gptq
+        pip install sentencepiece auto-gptq==0.6.0 lmdeploy[all]
         srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --gpus-per-task=2 pytest -s -v --color=yes ./tests/test_hf_model.py
         conda deactivate
     - name: remove_env

File: binary image (contents not shown); size before: 1.7 KiB, after: 16 KiB

File: binary image (contents not shown); size before: 2.8 KiB, after: 45 KiB

File: tests/test_hf_model.py

@@ -1,5 +1,7 @@
 import pytest
 import torch
+from auto_gptq.modeling import BaseGPTQForCausalLM
+from lmdeploy import TurbomindEngineConfig, pipeline
 from PIL import Image
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
@@ -9,6 +11,8 @@ prompts = ['你好', "what's your name"]
 def assert_model(response):
     assert len(response) != 0
     assert 'UNUSED_TOKEN' not in response
+    assert 'Mynameis' not in response
+    assert 'Iama' not in response


 class TestChat:
@@ -18,6 +22,7 @@ class TestChat:
         'model_name',
         [
             'internlm/internlm2-chat-7b', 'internlm/internlm2-chat-7b-sft',
+            'internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-sft',
             'internlm/internlm2-chat-1_8b', 'internlm/internlm2-chat-1_8b-sft'
         ],
     )
@@ -55,6 +60,23 @@ class TestChat:
             assert_model(response)


+class TestChatAwq:
+    """Test cases for chat model."""
+
+    @pytest.mark.parametrize(
+        'model_name',
+        ['internlm/internlm2-chat-20b-4bits'],
+    )
+    def test_demo_default(self, model_name):
+        engine_config = TurbomindEngineConfig(model_format='awq')
+        pipe = pipeline('internlm/internlm2-chat-20b-4bits',
+                        backend_config=engine_config)
+        responses = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+        print(responses)
+        for response in responses:
+            assert_model(response.text)
+
+
 class TestBase:
     """Test cases for base model."""
@@ -62,6 +84,7 @@ class TestBase:
         'model_name',
         [
             'internlm/internlm2-7b', 'internlm/internlm2-base-7b',
+            'internlm/internlm2-20b', 'internlm/internlm2-base-20b',
             'internlm/internlm2-1_8b'
         ],
     )
@@ -140,6 +163,7 @@ class TestMMModel:
         'model_name',
         [
             'internlm/internlm-xcomposer2-7b',
+            'internlm/internlm-xcomposer2-7b-4bit'
         ],
     )
     def test_demo_default(self, model_name):
@@ -148,12 +172,16 @@ class TestMMModel:

         # Set `torch_dtype=torch.float16` to load model in float16, otherwise
         # it will be loaded as float32 and might cause OOM Error.
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.float32,
-            trust_remote_code=True).cuda()
+        if '4bit' in model_name:
+            model = InternLMXComposer2QForCausalLM.from_quantized(
+                model_name, trust_remote_code=True, device='cuda:0').eval()
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name, torch_dtype=torch.float32,
+                trust_remote_code=True).cuda()
         tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                   trust_remote_code=True)
         model = model.eval()

         img_path_list = [
             'tests/panda.jpg',
@@ -175,7 +203,7 @@ class TestMMModel:
                               do_sample=False)
         print(response)
         assert len(response) != 0
-        assert 'panda' in response
+        assert ' panda' in response

         query = '<ImageHere> <ImageHere>请根据图片写一篇作文:我最喜欢的小动物。' \
             + '要求:选准角度,确定立意,明确文体,自拟标题。'
@@ -197,6 +225,7 @@ class TestMMVlModel:
         'model_name',
         [
             'internlm/internlm-xcomposer2-vl-7b',
+            'internlm/internlm-xcomposer2-vl-7b-4bit'
         ],
     )
     def test_demo_default(self, model_name):
@@ -206,8 +235,13 @@ class TestMMVlModel:
         torch.set_grad_enabled(False)

         # init model and tokenizer
-        model = AutoModel.from_pretrained(
-            model_name, trust_remote_code=True).cuda().eval()
+        if '4bit' in model_name:
+            model = InternLMXComposer2QForCausalLM.from_quantized(
+                model_name, trust_remote_code=True, device='cuda:0').eval()
+        else:
+            model = AutoModel.from_pretrained(
+                model_name, trust_remote_code=True).cuda().eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                   trust_remote_code=True)
@@ -223,3 +257,20 @@ class TestMMVlModel:
         assert len(response) != 0
         assert 'Oscar Wilde' in response
         assert 'Live life with no excuses, travel with no regret' in response
+
+
+class InternLMXComposer2QForCausalLM(BaseGPTQForCausalLM):
+    layers_block_name = 'model.layers'
+    outside_layer_modules = [
+        'vit',
+        'vision_proj',
+        'model.tok_embeddings',
+        'model.norm',
+        'output',
+    ]
+    inside_layer_modules = [
+        ['attention.wqkv.linear'],
+        ['attention.wo.linear'],
+        ['feed_forward.w1.linear', 'feed_forward.w3.linear'],
+        ['feed_forward.w2.linear'],
+    ]