|
|
@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine |
|
|
|
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy |
|
|
|
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy |
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
|
|
|
|
|
|
|
|
|
|
|
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base" |
|
|
|
|
|
|
|
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base" |
|
|
|
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): |
|
|
|
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): |
|
|
|
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") |
|
|
|
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") |
|
|
|
|
|
|
|
|
|
|
|
if ret: |
|
|
|
if ret: |
|
|
|
ret[rank] = func_to_run(**kwargs) |
|
|
|
ret[rank] = func_to_run(**kwargs) |
|
|
@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): |
|
|
|
@parameterize("prompt_template", [None, "baichuan"]) |
|
|
|
@parameterize("prompt_template", [None, "baichuan"]) |
|
|
|
@parameterize("do_sample", [False]) |
|
|
|
@parameterize("do_sample", [False]) |
|
|
|
@parameterize("use_cuda_kernel", [True]) |
|
|
|
@parameterize("use_cuda_kernel", [True]) |
|
|
|
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel): |
|
|
|
def check_tp_engine(prompt_template, do_sample, use_cuda_kernel): |
|
|
|
kwargs1 = { |
|
|
|
kwargs1 = { |
|
|
|
"use_engine": True, |
|
|
|
"use_engine": True, |
|
|
|
"prompt_template": prompt_template, |
|
|
|
"prompt_template": prompt_template, |
|
|
@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel): |
|
|
|
@pytest.mark.dist |
|
|
|
@pytest.mark.dist |
|
|
|
@rerun_if_address_is_in_use() |
|
|
|
@rerun_if_address_is_in_use() |
|
|
|
def test_inference_engine(): |
|
|
|
def test_inference_engine(): |
|
|
|
test_tp_engine() |
|
|
|
check_tp_engine() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
if __name__ == "__main__": |
|
|
|