[hotfix] Support extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig

* modify shardconfig

* polish code

* fix policy bug in llama

* fix bug in auto policy

* remove setattr in ShardConfig
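
In short, per-inference flags that used to live directly on `ShardConfig` (such as `inference_only` and `inference_gptq`) now travel in a single `extra_kwargs` dict. A minimal before/after sketch of the call-site change made throughout this diff:

    # before: inference flags were top-level ShardConfig fields
    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True, inference_gptq=True)

    # after: the same flags are collected in extra_kwargs
    shard_config = ShardConfig(
        enable_tensor_parallelism=True,
        extra_kwargs={"inference_only": True, "inference_gptq": True},
    )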
Zhongkai Zhao 2023-11-10 10:49:50 +08:00 committed by GitHub
parent 576a2f7b10
commit 70885d707d
23 changed files with 98 additions and 77 deletions

View File

@@ -67,7 +67,9 @@ class Worker:
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
         )
-        shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
+        shard_config = ShardConfig(
+            enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
+        )
         self.infer_engine = TPInferEngine(
             self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
         )

View File

@@ -45,8 +45,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        if self.shard_config.inference_gptq:
+        if self.shard_config.extra_kwargs.get("inference_gptq", False):
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

             decoder_attribute_replacement = {

View File

@@ -44,7 +44,7 @@ class TPInferEngine:
         >>> # define model and shard config for your inference
         >>> model = ...
         >>> generate_kwargs = ...
-        >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+        >>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
         >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
         >>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
     """

@@ -181,7 +181,7 @@ class TPInferEngine:
         In further generation, use the sharded model instead of original model.
         """
         # NOTE we will change to use an inference config later with additional attrs we want
-        assert self.shard_config.inference_only is True
+        assert self.shard_config.extra_kwargs["inference_only"] is True
         shardformer = ShardFormer(shard_config=self.shard_config)
         self._prepare_with_shard_config(shard_config=self.shard_config)
         self._shard_model_by(shardformer, model)

@@ -203,10 +203,10 @@ class TPInferEngine:
                 enable_all_optimization=False,
                 enable_flash_attention=False,
                 enable_jit_fused=False,
-                inference_only=True,
+                extra_kwargs={"inference_only": True},
             )
         else:
-            shard_config.inference_only = True
+            shard_config.extra_kwargs = {"inference_only": True}
             shard_config.pipeline_stage_manager = None
         if shard_config.enable_tensor_parallelism:
             self.tp_size = shard_config.tensor_parallel_size

@@ -221,13 +221,11 @@ class TPInferEngine:
         ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
         model_name = model.__class__.__name__
         assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
-        model = model.model if self.shard_config.inference_gptq else model
+        if self.shard_config.extra_kwargs.get("inference_gptq", False):
+            model = model.model
         policy = get_autopolicy(model, shard_config=self.shard_config)
         self.model, _ = shardformer.optimize(model, policy)

-        if self.shard_config.inference_gptq:
+        if self.shard_config.extra_kwargs.get("inference_gptq", False):
             self._post_init_gptq_buffer(self.model)
         self.model = self.model.cuda()
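
The hunks above mix two access styles on the new dict: direct indexing for `inference_only` (which the engine also sets itself when it builds or adjusts the config) and `.get(..., False)` for the optional `inference_gptq` flag. A minimal sketch of the difference, in plain Python rather than engine internals:

    extra_kwargs = {"inference_only": True}
    extra_kwargs["inference_only"]             # True; would raise KeyError if the key were missing
    extra_kwargs.get("inference_gptq", False)  # False; a missing key falls back to the default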

View File

@@ -4,7 +4,6 @@ import torch
 from torch.nn import LayerNorm

 import colossalai.shardformer.layer as col_nn
-from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

@@ -38,35 +37,39 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
         from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel

         policy = super().module_policy()
-        if self.shard_config.inference_gptq:
+        if self.shard_config.extra_kwargs.get("inference_gptq", False):
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

-            policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
-                "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
-            },
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attention.query_key_value",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={'split_num': 3}),
-                    SubModuleReplacementDescription(
-                        suffix="self_attention.dense",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={'split_num': 1}),
-                    SubModuleReplacementDescription(
-                        suffix="self_attention.attention_dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.dense_h_to_4h",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={'split_num': 1}),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.dense_4h_to_h",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={'split_num': 1}),
-                ])
+            policy[BloomBlock] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attention.hidden_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.split_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.query_key_value",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 3},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.attention_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+                    ),
+                ],
+            )
             # NOTE set inference mode to shard config
             self.shard_config._infer()

View File

@@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards
 try:
     from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")

@@ -21,6 +22,7 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
+
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

@@ -36,7 +38,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        if self.shard_config.inference_gptq:
+        if self.shard_config.extra_kwargs.get("inference_gptq", False):
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

             decoder_attribute_replacement = {

View File

@@ -81,8 +81,6 @@ Following are the description `ShardConfig`'s arguments:
 - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.

-- `inference_only`: Whether only doing forward passing. Defaults to False.
-
 ### Write your own policy

 If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).

@@ -185,7 +183,6 @@ class ShardConfig:
     # Some possible future config fields
     tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
-    inference_only: bool # only inject inference-suitable sharding policy
     use_flash_attention: bool # whether to use flash attention to speed up attention
 ```

View File

@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if shard_config.inference_only:
+    inference_only = shard_config.extra_kwargs.get("inference_only", False)
+    if inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)

@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, shard_config.inference_only)
+        policy = import_policy(policy_location, inference_only)
     return policy()

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Dict, Optional
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
@ -24,7 +24,6 @@ class ShardConfig:
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
inference_only (bool): Whether only doing forward passing. Defaults to False.
""" """
tensor_parallel_process_group: Optional[ProcessGroup] = None tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None pipeline_stage_manager: Optional[PipelineStageManager] = None
@ -33,10 +32,9 @@ class ShardConfig:
enable_flash_attention: bool = False enable_flash_attention: bool = False
enable_jit_fused: bool = False enable_jit_fused: bool = False
enable_all_optimization: bool = False enable_all_optimization: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, bool] = field(default_factory=dict)
# pipeline_parallel_size: int # pipeline_parallel_size: int
# data_parallel_size: int # data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@ -77,4 +75,3 @@ class ShardConfig:
Set default params for inference. Set default params for inference.
""" """
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass
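
Since `extra_kwargs` defaults to an empty dict via `field(default_factory=dict)`, a config built without it still works with the `.get()`-style reads used by the policies. A small usage sketch, assuming only the names visible in this diff:

    from colossalai.shardformer import ShardConfig

    cfg = ShardConfig(enable_tensor_parallelism=False)
    cfg.extra_kwargs                                # {} by default
    cfg.extra_kwargs.get("inference_gptq", False)   # False when the flag was never passed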

View File

@@ -28,7 +28,9 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

     # prepare data for generation

View File

@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

View File

@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

View File

@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)

@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "inference_gptq": True},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

View File

@@ -33,7 +33,8 @@ def run_llama_test(args):
     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "inference_gptq": True},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

View File

@@ -68,7 +68,9 @@ class Worker:
             self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
         )
-        shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
+        shard_config = ShardConfig(
+            enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
+        )
         self.infer_engine = TPInferEngine(
             self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
         )

View File

@@ -100,7 +100,9 @@ class ColossalInferenceHandler(BaseHandler, ABC):
             colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
         logger.info("Initializing TPInferEngine ...")
-        shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
+        shard_config = ShardConfig(
+            enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+        )
         self.infer_engine = TPInferEngine(
             self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
         )

View File

@@ -19,7 +19,7 @@ def build_model(
         enable_tensor_parallelism=enable_tensor_parallelism,
         enable_flash_attention=enable_flash_attention,
         enable_jit_fused=enable_jit_fused,
-        inference_only=True,
+        extra_kwargs={"inference_only": True},
     )
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)

View File

@@ -11,7 +11,6 @@ from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 try:
-    import lightllm
     HAS_LIGHTLLM_KERNEL = True
 except:
     HAS_LIGHTLLM_KERNEL = False

@@ -38,7 +37,7 @@ def run(test_config):
     model = model.half()
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
     )
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

@@ -58,7 +57,10 @@ def check_bloom(rank, world_size, port):
     run()

-@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()

View File

@@ -49,7 +49,7 @@ def run_chatglm2_test(test_config):
     model = model.half()
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
     )
     infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

View File

@@ -34,7 +34,7 @@ def run():
     model = LlamaForCausalLM(llama_config)
     model = model.half()
-    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
+    shard_config = ShardConfig(enable_tensor_parallelism=False, extra_kwargs={"inference_only": True})
     infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     dynamic_batch_manager = DynamicBatchManager(

View File

@@ -57,7 +57,9 @@ def run():
     model = LlamaForCausalLM(llama_config)
     model = model.half()
-    shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if TP_SIZE > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

View File

@@ -36,7 +36,7 @@ def run(test_config):
     # 1. check TPInferEngine init and model optimization
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
     )
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

View File

@@ -13,7 +13,6 @@ from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 try:
-    import lightllm
     HAS_LIGHTLLM_KERNEL = True
 except:
     HAS_LIGHTLLM_KERNEL = False

@@ -43,7 +42,7 @@ def run_llama_test(test_config):
     model = model.half()
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
     )
     infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

@@ -63,7 +62,10 @@ def check_llama(rank, world_size, port):
     run_llama_test()

-@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()

View File

@@ -13,7 +13,6 @@ from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 try:
-    import lightllm
     HAS_LIGHTLLM_KERNEL = True
 except:
     HAS_LIGHTLLM_KERNEL = False

@@ -41,7 +40,7 @@ def run_llama_test(test_config):
     model = model.half()
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
     )
     infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

@@ -61,7 +60,10 @@ def check_llama(rank, world_size, port):
     run_llama_test()

-@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()