mirror of https://github.com/hpcaitech/ColossalAI
[Refactor] refactor policy search and quant type controlling in inference (#5035)
* [Refactor] refactor policy search and quant type controling in inferencefeature/inference-refactor
parent
c6295c3381
commit
361cf63cb0
|
@ -142,7 +142,7 @@ jobs:
|
|||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
|
||||
timeout-minutes: 60
|
||||
timeout-minutes: 100
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
|
@ -51,7 +51,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
|||
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
}
|
||||
if self.shard_config.quant == "gptq":
|
||||
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
|
||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||
|
@ -95,7 +95,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
|||
],
|
||||
)
|
||||
|
||||
elif self.shard_config.quant == "smoothquant":
|
||||
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
|
||||
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
|
||||
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
|
||||
ColW8A8BFP32OFP32Linear,
|
||||
|
|
|
@ -81,7 +81,7 @@ Following are the description `ShardConfig`'s arguments:
|
|||
|
||||
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
|
||||
|
||||
- `inference_only`: Whether only doing forward passing. Defaults to False.
|
||||
- `extra_kwargs`: A dict to store extra kwargs for ShardFomer.
|
||||
|
||||
### Write your own policy
|
||||
|
||||
|
@ -185,8 +185,8 @@ class ShardConfig:
|
|||
|
||||
# Some possible future config fields
|
||||
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
|
||||
inference_only: bool # only inject inference-suitable sharding policy
|
||||
use_flash_attention: bool # whether to use flash attention to speed up attention
|
||||
extra_kwargs: Dict[str, Any] # extra kwargs for the shardformer
|
||||
```
|
||||
|
||||
### Policy
|
||||
|
|
|
@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
|||
:class:`Policy`: The auto policy for the model
|
||||
"""
|
||||
full_name = _fullname(model)
|
||||
if shard_config.inference_only:
|
||||
inference_only = shard_config.extra_kwargs.get("inference_only", None)
|
||||
if inference_only:
|
||||
policy_location = _INFER_POLICY_LIST.get(full_name, None)
|
||||
else:
|
||||
policy_location = _POLICY_LIST.get(full_name, None)
|
||||
|
@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
|||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
||||
)
|
||||
else:
|
||||
policy = import_policy(policy_location, shard_config.inference_only)
|
||||
policy = import_policy(policy_location, inference_only)
|
||||
return policy()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
@ -33,11 +33,9 @@ class ShardConfig:
|
|||
enable_flash_attention: bool = False
|
||||
enable_jit_fused: bool = False
|
||||
enable_all_optimization: bool = False
|
||||
inference_only: bool = False
|
||||
inference_gptq: bool = False
|
||||
enable_sequence_parallelism: bool = False
|
||||
enable_sequence_overlap: bool = False
|
||||
quant: str = None
|
||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
# pipeline_parallel_size: int
|
||||
# data_parallel_size: int
|
||||
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
||||
|
|
|
@ -28,7 +28,9 @@ def bench_bloom(args):
|
|||
|
||||
# init TPInferEngine and shard the original model
|
||||
# To benchmark torch original, comment out the line of optimizing model
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
# prepare data for generation
|
||||
|
|
|
@ -30,7 +30,9 @@ def run_chatglm2_test(args):
|
|||
model = model.half()
|
||||
model.config
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
||||
|
|
|
@ -30,7 +30,9 @@ def run_llama_test(args):
|
|||
model = model.half()
|
||||
model.config
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=1, do_sample=False)
|
||||
|
|
|
@ -34,7 +34,9 @@ def bench_bloom(args):
|
|||
model = model.half()
|
||||
|
||||
model_config = model.config
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
|
||||
|
@ -46,7 +48,8 @@ def bench_bloom(args):
|
|||
# init TPInferEngine and shard the original model
|
||||
# To benchmark torch original, comment out the line of optimizing model
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||
extra_kwargs={"inference_only": True, "quant": "gptq"},
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
|
|
|
@ -33,7 +33,8 @@ def run_llama_test(args):
|
|||
|
||||
model_config = model.config
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||
extra_kwargs={"inference_only": True, "quant": "gptq"},
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
|
|
Loading…
Reference in New Issue