mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 4926 (#5007)
Co-authored-by: github-actions <github-actions@github.com>
pull/5024/head
parent 1a3315e336
commit c36e782d80
@@ -218,7 +218,7 @@ class TPInferEngine:
        ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
        model_name = model.__class__.__name__
        assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

        model = model.model if self.shard_config.inference_gptq else model
        policy = get_autopolicy(model, shard_config=self.shard_config)

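For readability, here is a minimal, self-contained sketch of the check-and-unwrap pattern in the hunk above. SUPPORTED_MODELS and select_model_for_sharding are illustrative stand-ins, not ColossalAI API; only LlamaForCausalLM is known from this diff to be exercised, and the real engine's supported_models may list more classes.

# Hypothetical stand-ins for TPInferEngine.supported_models and shard_config.inference_gptq.
SUPPORTED_MODELS = {"LlamaForCausalLM"}  # assumption: the real engine supports additional classes

def select_model_for_sharding(model, inference_gptq: bool):
    model_name = model.__class__.__name__
    assert model_name in SUPPORTED_MODELS, f"Unsupported model cls {model_name} for TP inference."
    # GPTQ-quantized checkpoints wrap the transformer under `.model`, so shard the inner module.
    return model.model if inference_gptq else model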
@@ -311,7 +311,7 @@ class TPInferEngine:
            seq_start_indexes[i] = start_index
            start_index += curr_seq_len
            max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

        block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to("cuda")

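A minimal sketch of the per-batch bookkeeping computed in this hunk, assuming seq_lengths is a 1-D tensor of prompt lengths. build_batch_metadata is an illustrative helper, not part of TPInferEngine, and it allocates on CPU where the engine uses device="cuda"; in the engine the results feed a BatchInferState.

import torch

def build_batch_metadata(seq_lengths: torch.Tensor, max_input_len: int, max_output_len: int):
    batch_size = seq_lengths.numel()
    seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32)
    start_index = 0
    max_len_in_batch = 0
    for i, curr_seq_len in enumerate(seq_lengths.tolist()):
        seq_start_indexes[i] = start_index       # offset of sequence i in the packed token buffer
        start_index += curr_seq_len              # advance past this sequence
        max_len_in_batch = max(max_len_in_batch, curr_seq_len)
    # one slot per possible token (prompt + generated) for the KV-cache block table
    block_loc = torch.empty((batch_size, max_input_len + max_output_len), dtype=torch.long)
    return seq_start_indexes, max_len_in_batch, block_loc

starts, max_len, block_loc = build_batch_metadata(torch.tensor([5, 3, 7]), max_input_len=8, max_output_len=4)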
@@ -24,6 +24,7 @@ for k, v in inputs.items():
    new_shape[0] = 16
    inputs[k] = v.to("cuda").repeat(*new_shape)


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
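The two lines at the top of this hunk tile every tokenized tensor along dim 0 to build a batch of 16. A small standalone reproduction, with made-up token ids in place of a real tokenizer and the .to("cuda") call dropped so it runs on CPU, might look like:

import torch

inputs = {
    "input_ids": torch.tensor([[1, 15043, 29892, 590, 11203]]),  # illustrative token ids
    "attention_mask": torch.ones(1, 5, dtype=torch.long),
}
for k, v in inputs.items():
    new_shape = [1] * v.dim()
    new_shape[0] = 16                 # replicate only along the batch dimension
    inputs[k] = v.repeat(*new_shape)  # the test additionally moves tensors to CUDA first
assert inputs["input_ids"].shape == (16, 5)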
@@ -58,7 +59,6 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])

@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)

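If it helps to read the stacked decorators above: here is a toy re-implementation of a parameterize decorator, assuming semantics like colossalai.testing.parameterize, where each decorator injects one keyword argument and stacking them runs the wrapped test for every combination of the listed values. The real utility may differ in details; demo is a made-up function.

import functools

def parameterize(arg_name, values):
    """Toy version: call the wrapped function once per value, passed as a keyword argument."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(**kwargs):
            for value in values:
                fn(**kwargs, **{arg_name: value})
        return wrapper
    return decorator

@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
def demo(pp_size, max_output_len, micro_batch_size):
    print(pp_size, max_output_len, micro_batch_size)

demo()  # prints: 2 4 1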
@@ -76,7 +76,6 @@ def check_tp_pipeline_inference(rank, world_size, port):


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")

@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

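The CUDA_SUPPORT flag that gates this test is defined outside the hunk. A plausible way such a guard is computed, matching the skip reason above but an assumption rather than the test file's exact code, is:

import torch
from packaging import version

# Assumed definition: require the installed torch CUDA build to be newer than 11.5.
CUDA_SUPPORT = (
    torch.version.cuda is not None
    and version.parse(torch.version.cuda) > version.parse("11.5")
)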