|
|
|
@ -24,6 +24,7 @@ for k, v in inputs.items():
|
|
|
|
|
new_shape[0] = 16 |
|
|
|
|
inputs[k] = v.to("cuda").repeat(*new_shape) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): |
|
|
|
|
model = transformers.LlamaForCausalLM( |
|
|
|
|
transformers.LlamaConfig( |
|
|
|
@ -58,7 +59,6 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_si
|
|
|
|
|
@parameterize("pp_size", [2]) |
|
|
|
|
@parameterize("max_output_len", [4]) |
|
|
|
|
@parameterize("micro_batch_size", [1]) |
|
|
|
|
|
|
|
|
|
@clear_cache_before_run() |
|
|
|
|
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): |
|
|
|
|
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) |
|
|
|
@ -76,7 +76,6 @@ def check_tp_pipeline_inference(rank, world_size, port):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") |
|
|
|
|
|
|
|
|
|
@pytest.mark.dist |
|
|
|
|
@rerun_if_address_is_in_use() |
|
|
|
|
@clear_cache_before_run() |
|
|
|
|