mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Suport extra_kwargs in ShardConfig (#5031)
* [refactor]: replace inference args with extra_kwargs in ShardConfig * modify shardconfig * polish code * fix policy bug in llama * fix bug in auto policy * remove setattr in ShardConfigpull/4836/head^2
parent
576a2f7b10
commit
70885d707d
|
@ -67,7 +67,9 @@ class Worker:
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
|
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
self.infer_engine = TPInferEngine(
|
self.infer_engine = TPInferEngine(
|
||||||
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
||||||
)
|
)
|
||||||
|
|
|
@ -45,8 +45,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||||
if self.shard_config.inference_gptq:
|
|
||||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||||
|
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
|
|
|
@ -44,7 +44,7 @@ class TPInferEngine:
|
||||||
>>> # define model and shard config for your inference
|
>>> # define model and shard config for your inference
|
||||||
>>> model = ...
|
>>> model = ...
|
||||||
>>> generate_kwargs = ...
|
>>> generate_kwargs = ...
|
||||||
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
|
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
|
||||||
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
|
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
|
||||||
"""
|
"""
|
||||||
|
@ -181,7 +181,7 @@ class TPInferEngine:
|
||||||
In further generation, use the sharded model instead of original model.
|
In further generation, use the sharded model instead of original model.
|
||||||
"""
|
"""
|
||||||
# NOTE we will change to use an inference config later with additional attrs we want
|
# NOTE we will change to use an inference config later with additional attrs we want
|
||||||
assert self.shard_config.inference_only is True
|
assert self.shard_config.extra_kwargs["inference_only"] is True
|
||||||
shardformer = ShardFormer(shard_config=self.shard_config)
|
shardformer = ShardFormer(shard_config=self.shard_config)
|
||||||
self._prepare_with_shard_config(shard_config=self.shard_config)
|
self._prepare_with_shard_config(shard_config=self.shard_config)
|
||||||
self._shard_model_by(shardformer, model)
|
self._shard_model_by(shardformer, model)
|
||||||
|
@ -203,10 +203,10 @@ class TPInferEngine:
|
||||||
enable_all_optimization=False,
|
enable_all_optimization=False,
|
||||||
enable_flash_attention=False,
|
enable_flash_attention=False,
|
||||||
enable_jit_fused=False,
|
enable_jit_fused=False,
|
||||||
inference_only=True,
|
extra_kwargs={"inference_only": True},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
shard_config.inference_only = True
|
shard_config.extra_kwargs = {"inference_only": True}
|
||||||
shard_config.pipeline_stage_manager = None
|
shard_config.pipeline_stage_manager = None
|
||||||
if shard_config.enable_tensor_parallelism:
|
if shard_config.enable_tensor_parallelism:
|
||||||
self.tp_size = shard_config.tensor_parallel_size
|
self.tp_size = shard_config.tensor_parallel_size
|
||||||
|
@ -221,13 +221,11 @@ class TPInferEngine:
|
||||||
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
|
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
|
||||||
model_name = model.__class__.__name__
|
model_name = model.__class__.__name__
|
||||||
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
|
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
|
||||||
|
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||||
model = model.model if self.shard_config.inference_gptq else model
|
model = model.model
|
||||||
policy = get_autopolicy(model, shard_config=self.shard_config)
|
policy = get_autopolicy(model, shard_config=self.shard_config)
|
||||||
|
|
||||||
self.model, _ = shardformer.optimize(model, policy)
|
self.model, _ = shardformer.optimize(model, policy)
|
||||||
|
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||||
if self.shard_config.inference_gptq:
|
|
||||||
self._post_init_gptq_buffer(self.model)
|
self._post_init_gptq_buffer(self.model)
|
||||||
|
|
||||||
self.model = self.model.cuda()
|
self.model = self.model.cuda()
|
||||||
|
|
|
@ -4,7 +4,6 @@ import torch
|
||||||
from torch.nn import LayerNorm
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
|
|
||||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||||
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
|
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
|
||||||
|
|
||||||
|
@ -38,35 +37,39 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
|
||||||
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
|
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
|
||||||
|
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
if self.shard_config.inference_gptq:
|
|
||||||
|
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||||
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
|
|
||||||
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
policy[BloomBlock] = ModulePolicyDescription(
|
||||||
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
attribute_replacement={
|
||||||
|
"self_attention.hidden_size": self.model.config.hidden_size
|
||||||
|
// self.shard_config.tensor_parallel_size,
|
||||||
|
"self_attention.split_size": self.model.config.hidden_size
|
||||||
|
// self.shard_config.tensor_parallel_size,
|
||||||
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||||
},
|
},
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attention.query_key_value",
|
suffix="self_attention.query_key_value",
|
||||||
target_module=ColCaiQuantLinear,
|
target_module=ColCaiQuantLinear,
|
||||||
kwargs={'split_num': 3}),
|
kwargs={"split_num": 3},
|
||||||
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attention.dense",
|
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
|
||||||
target_module=RowCaiQuantLinear,
|
),
|
||||||
kwargs={'split_num': 1}),
|
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attention.attention_dropout",
|
suffix="self_attention.attention_dropout",
|
||||||
target_module=col_nn.DropoutForParallelInput,
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.dense_h_to_4h",
|
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
|
||||||
target_module=ColCaiQuantLinear,
|
),
|
||||||
kwargs={'split_num': 1}),
|
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.dense_4h_to_h",
|
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
|
||||||
target_module=RowCaiQuantLinear,
|
),
|
||||||
kwargs={'split_num': 1}),
|
],
|
||||||
])
|
)
|
||||||
# NOTE set inference mode to shard config
|
# NOTE set inference mode to shard config
|
||||||
self.shard_config._infer()
|
self.shard_config._infer()
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
|
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
|
||||||
|
|
||||||
HAS_TRITON_RMSNORM = True
|
HAS_TRITON_RMSNORM = True
|
||||||
except:
|
except:
|
||||||
print("you should install triton from https://github.com/openai/triton")
|
print("you should install triton from https://github.com/openai/triton")
|
||||||
|
@ -21,6 +22,7 @@ except:
|
||||||
|
|
||||||
def get_triton_rmsnorm_forward():
|
def get_triton_rmsnorm_forward():
|
||||||
if HAS_TRITON_RMSNORM:
|
if HAS_TRITON_RMSNORM:
|
||||||
|
|
||||||
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
||||||
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
|
@ -36,7 +38,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
|
||||||
if self.shard_config.inference_gptq:
|
if self.shard_config.extra_kwargs.get("inference_gptq", False):
|
||||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||||
|
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
|
|
|
@ -81,8 +81,6 @@ Following are the description `ShardConfig`'s arguments:
|
||||||
|
|
||||||
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
|
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
|
||||||
|
|
||||||
- `inference_only`: Whether only doing forward passing. Defaults to False.
|
|
||||||
|
|
||||||
### Write your own policy
|
### Write your own policy
|
||||||
|
|
||||||
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
|
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
|
||||||
|
@ -185,7 +183,6 @@ class ShardConfig:
|
||||||
|
|
||||||
# Some possible future config fields
|
# Some possible future config fields
|
||||||
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
|
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
|
||||||
inference_only: bool # only inject inference-suitable sharding policy
|
|
||||||
use_flash_attention: bool # whether to use flash attention to speed up attention
|
use_flash_attention: bool # whether to use flash attention to speed up attention
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
||||||
:class:`Policy`: The auto policy for the model
|
:class:`Policy`: The auto policy for the model
|
||||||
"""
|
"""
|
||||||
full_name = _fullname(model)
|
full_name = _fullname(model)
|
||||||
if shard_config.inference_only:
|
inference_only = shard_config.extra_kwargs.get("inference_only", False)
|
||||||
|
if inference_only:
|
||||||
policy_location = _INFER_POLICY_LIST.get(full_name, None)
|
policy_location = _INFER_POLICY_LIST.get(full_name, None)
|
||||||
else:
|
else:
|
||||||
policy_location = _POLICY_LIST.get(full_name, None)
|
policy_location = _POLICY_LIST.get(full_name, None)
|
||||||
|
@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
||||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
policy = import_policy(policy_location, shard_config.inference_only)
|
policy = import_policy(policy_location, inference_only)
|
||||||
return policy()
|
return policy()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
@ -24,7 +24,6 @@ class ShardConfig:
|
||||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
|
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
|
||||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
|
enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
|
||||||
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
|
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
|
||||||
inference_only (bool): Whether only doing forward passing. Defaults to False.
|
|
||||||
"""
|
"""
|
||||||
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
||||||
pipeline_stage_manager: Optional[PipelineStageManager] = None
|
pipeline_stage_manager: Optional[PipelineStageManager] = None
|
||||||
|
@ -33,10 +32,9 @@ class ShardConfig:
|
||||||
enable_flash_attention: bool = False
|
enable_flash_attention: bool = False
|
||||||
enable_jit_fused: bool = False
|
enable_jit_fused: bool = False
|
||||||
enable_all_optimization: bool = False
|
enable_all_optimization: bool = False
|
||||||
inference_only: bool = False
|
|
||||||
inference_gptq: bool = False
|
|
||||||
enable_sequence_parallelism: bool = False
|
enable_sequence_parallelism: bool = False
|
||||||
enable_sequence_overlap: bool = False
|
enable_sequence_overlap: bool = False
|
||||||
|
extra_kwargs: Dict[str, bool] = field(default_factory=dict)
|
||||||
# pipeline_parallel_size: int
|
# pipeline_parallel_size: int
|
||||||
# data_parallel_size: int
|
# data_parallel_size: int
|
||||||
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
||||||
|
@ -77,4 +75,3 @@ class ShardConfig:
|
||||||
Set default params for inference.
|
Set default params for inference.
|
||||||
"""
|
"""
|
||||||
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
|
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
|
||||||
pass
|
|
||||||
|
|
|
@ -28,7 +28,9 @@ def bench_bloom(args):
|
||||||
|
|
||||||
# init TPInferEngine and shard the original model
|
# init TPInferEngine and shard the original model
|
||||||
# To benchmark torch original, comment out the line of optimizing model
|
# To benchmark torch original, comment out the line of optimizing model
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
|
|
||||||
# prepare data for generation
|
# prepare data for generation
|
||||||
|
|
|
@ -30,7 +30,9 @@ def run_chatglm2_test(args):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
model.config
|
model.config
|
||||||
|
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
|
|
||||||
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
||||||
|
|
|
@ -30,7 +30,9 @@ def run_llama_test(args):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
model.config
|
model.config
|
||||||
|
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
|
|
||||||
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
||||||
|
|
|
@ -34,7 +34,9 @@ def bench_bloom(args):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||||
|
|
||||||
|
@ -46,7 +48,8 @@ def bench_bloom(args):
|
||||||
# init TPInferEngine and shard the original model
|
# init TPInferEngine and shard the original model
|
||||||
# To benchmark torch original, comment out the line of optimizing model
|
# To benchmark torch original, comment out the line of optimizing model
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
|
enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||||
|
extra_kwargs={"inference_only": True, "inference_gptq": True},
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,8 @@ def run_llama_test(args):
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
|
enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||||
|
extra_kwargs={"inference_only": True, "inference_gptq": True},
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||||
|
|
||||||
|
|
|
@ -68,7 +68,9 @@ class Worker:
|
||||||
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
|
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
|
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
self.infer_engine = TPInferEngine(
|
self.infer_engine = TPInferEngine(
|
||||||
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
||||||
)
|
)
|
||||||
|
|
|
@ -100,7 +100,9 @@ class ColossalInferenceHandler(BaseHandler, ABC):
|
||||||
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
|
||||||
logger.info("Initializing TPInferEngine ...")
|
logger.info("Initializing TPInferEngine ...")
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
self.infer_engine = TPInferEngine(
|
self.infer_engine = TPInferEngine(
|
||||||
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,7 +19,7 @@ def build_model(
|
||||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||||
enable_flash_attention=enable_flash_attention,
|
enable_flash_attention=enable_flash_attention,
|
||||||
enable_jit_fused=enable_jit_fused,
|
enable_jit_fused=enable_jit_fused,
|
||||||
inference_only=True,
|
extra_kwargs={"inference_only": True},
|
||||||
)
|
)
|
||||||
model_copy = copy.deepcopy(org_model)
|
model_copy = copy.deepcopy(org_model)
|
||||||
shard_former = ShardFormer(shard_config=shard_config)
|
shard_former = ShardFormer(shard_config=shard_config)
|
||||||
|
|
|
@ -11,7 +11,6 @@ from colossalai.shardformer import ShardConfig
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import lightllm
|
|
||||||
HAS_LIGHTLLM_KERNEL = True
|
HAS_LIGHTLLM_KERNEL = True
|
||||||
except:
|
except:
|
||||||
HAS_LIGHTLLM_KERNEL = False
|
HAS_LIGHTLLM_KERNEL = False
|
||||||
|
@ -38,7 +37,7 @@ def run(test_config):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||||
|
@ -58,7 +57,10 @@ def check_bloom(rank, world_size, port):
|
||||||
run()
|
run()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
@pytest.mark.skipif(
|
||||||
|
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
|
||||||
|
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
|
||||||
|
)
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
|
|
|
@ -49,7 +49,7 @@ def run_chatglm2_test(test_config):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||||
|
|
|
@ -34,7 +34,7 @@ def run():
|
||||||
model = LlamaForCausalLM(llama_config)
|
model = LlamaForCausalLM(llama_config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
|
shard_config = ShardConfig(enable_tensor_parallelism=False, extra_kwargs={"inference_only": True})
|
||||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
|
|
||||||
dynamic_batch_manager = DynamicBatchManager(
|
dynamic_batch_manager = DynamicBatchManager(
|
||||||
|
|
|
@ -57,7 +57,9 @@ def run():
|
||||||
model = LlamaForCausalLM(llama_config)
|
model = LlamaForCausalLM(llama_config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
|
shard_config = ShardConfig(
|
||||||
|
enable_tensor_parallelism=True if TP_SIZE > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
|
)
|
||||||
|
|
||||||
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
|
batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
|
||||||
|
|
|
@ -36,7 +36,7 @@ def run(test_config):
|
||||||
|
|
||||||
# 1. check TPInferEngine init and model optimization
|
# 1. check TPInferEngine init and model optimization
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@ from colossalai.shardformer import ShardConfig
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import lightllm
|
|
||||||
HAS_LIGHTLLM_KERNEL = True
|
HAS_LIGHTLLM_KERNEL = True
|
||||||
except:
|
except:
|
||||||
HAS_LIGHTLLM_KERNEL = False
|
HAS_LIGHTLLM_KERNEL = False
|
||||||
|
@ -43,7 +42,7 @@ def run_llama_test(test_config):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||||
|
@ -63,7 +62,10 @@ def check_llama(rank, world_size, port):
|
||||||
run_llama_test()
|
run_llama_test()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
@pytest.mark.skipif(
|
||||||
|
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
|
||||||
|
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
|
||||||
|
)
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
|
|
|
@ -13,7 +13,6 @@ from colossalai.shardformer import ShardConfig
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import lightllm
|
|
||||||
HAS_LIGHTLLM_KERNEL = True
|
HAS_LIGHTLLM_KERNEL = True
|
||||||
except:
|
except:
|
||||||
HAS_LIGHTLLM_KERNEL = False
|
HAS_LIGHTLLM_KERNEL = False
|
||||||
|
@ -41,7 +40,7 @@ def run_llama_test(test_config):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
shard_config = ShardConfig(
|
shard_config = ShardConfig(
|
||||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
|
||||||
)
|
)
|
||||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||||
|
@ -61,7 +60,10 @@ def check_llama(rank, world_size, port):
|
||||||
run_llama_test()
|
run_llama_test()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
@pytest.mark.skipif(
|
||||||
|
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
|
||||||
|
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
|
||||||
|
)
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
|
|
Loading…
Reference in New Issue