
[Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference
Branch: feature/inference-refactor
Author: Zhongkai Zhao, committed by GitHub, 1 year ago
commit 361cf63cb0
  1. .github/workflows/build_on_pr.yml (2 changed lines)
  2. colossalai/inference/hybridengine/polices/llama.py (4 changed lines)
  3. colossalai/shardformer/README.md (4 changed lines)
  4. colossalai/shardformer/policies/auto_policy.py (5 changed lines)
  5. colossalai/shardformer/shard/shard_config.py (8 changed lines)
  6. examples/inference/bench_bloom.py (4 changed lines)
  7. examples/inference/bench_chatglm2.py (4 changed lines)
  8. examples/inference/bench_llama.py (4 changed lines)
  9. examples/inference/gptq_bloom.py (7 changed lines)
  10. examples/inference/gptq_llama.py (3 changed lines)
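
Taken together, the change replaces the dedicated `inference_only`, `inference_gptq`, and `quant` fields of `ShardConfig` with a single `extra_kwargs` dict. A minimal before/after sketch of the caller-side migration (the short import path is assumed from the repository layout; the values mirror the gptq examples below):

```python
from colossalai.shardformer import ShardConfig  # assumed import path

# Before this commit (dedicated fields, now removed):
#   ShardConfig(enable_tensor_parallelism=..., inference_only=True, inference_gptq=True)

# After this commit: the inference/quant switches live in one extra_kwargs dict.
# enable_tensor_parallelism=False keeps the sketch runnable outside a distributed
# launch; the example scripts below enable it when args.tp_size > 1.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,
    extra_kwargs={"inference_only": True, "quant": "gptq"},
)
```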

.github/workflows/build_on_pr.yml (2 changed lines)

@@ -142,7 +142,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 100
     defaults:
       run:
         shell: bash

colossalai/inference/hybridengine/polices/llama.py (4 changed lines)

@@ -51,7 +51,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                 // self.shard_config.tensor_parallel_size,
             }
-        if self.shard_config.quant == "gptq":
+        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
 
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
@@ -95,7 +95,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 ],
             )
-        elif self.shard_config.quant == "smoothquant":
+        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
             from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
             from colossalai.inference.quant.smoothquant.models.parallel_linear import (
                 ColW8A8BFP32OFP32Linear,
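
On the read side, the policy no longer consults `shard_config.quant`; it looks the key up in `extra_kwargs` with a `None` default, so configs that never set a quant type fall through to the unquantized path. A stripped-down sketch of that dispatch (the function name is hypothetical and shown only to illustrate the lookup, not the actual policy code):

```python
from types import SimpleNamespace


def pick_linear_backend(shard_config) -> str:
    """Illustrative only: mirrors how LlamaModelInferPolicy branches on the quant type."""
    quant = shard_config.extra_kwargs.get("quant", None)
    if quant == "gptq":
        return "ColCaiQuantLinear / RowCaiQuantLinear"  # GPTQ kernels (see diff above)
    elif quant == "smoothquant":
        return "ColW8A8BFP32OFP32Linear and friends"    # SmoothQuant kernels (see diff above)
    return "default tensor-parallel Linear"             # no quant type requested


print(pick_linear_backend(SimpleNamespace(extra_kwargs={"quant": "gptq"})))  # GPTQ branch
print(pick_linear_backend(SimpleNamespace(extra_kwargs={})))                 # default branch
```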

colossalai/shardformer/README.md (4 changed lines)

@@ -81,7 +81,7 @@ Following are the description `ShardConfig`'s arguments:
 - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
-- `inference_only`: Whether only doing forward passing. Defaults to False.
+- `extra_kwargs`: A dict to store extra kwargs for ShardFomer.
 
 ### Write your own policy
@@ -185,8 +185,8 @@ class ShardConfig:
     # Some possible future config fields
     tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
-    inference_only: bool # only inject inference-suitable sharding policy
     use_flash_attention: bool # whether to use flash attention to speed up attention
+    extra_kwargs: Dict[str, Any] # extra kwargs for the shardformer
 ```
 
 ### Policy
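
Per the updated README, `extra_kwargs` is a catch-all dict for ShardFormer-specific settings that are not first-class fields. A hedged sketch of setting and reading such switches (import path assumed; `some_future_flag` is a hypothetical key used only to show the default-on-missing behaviour):

```python
from colossalai.shardformer import ShardConfig  # assumed import path

# enable_tensor_parallelism=False only so this snippet runs outside a distributed launch.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,
    extra_kwargs={"inference_only": True, "quant": "smoothquant"},  # ShardFormer-specific switches
)

# Consumers read switches back with .get(), so unset keys simply fall back to a default:
assert shard_config.extra_kwargs.get("quant", None) == "smoothquant"
assert shard_config.extra_kwargs.get("some_future_flag", False) is False  # hypothetical key
```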

colossalai/shardformer/policies/auto_policy.py (5 changed lines)

@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if shard_config.inference_only:
+    inference_only = shard_config.extra_kwargs.get("inference_only", None)
+    if inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, shard_config.inference_only)
+        policy = import_policy(policy_location, inference_only)
     return policy()
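
With the flag moved into `extra_kwargs`, a config that never mentions `inference_only` gets `None` from `.get(...)` and resolves against the training policy table; only an explicit `{"inference_only": True}` switches the lookup to `_INFER_POLICY_LIST`. A hedged usage sketch, assuming Llama is registered in both tables as the inference policy in this commit suggests (the short `ShardConfig` import path is also assumed):

```python
from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer import ShardConfig  # assumed import path
from colossalai.shardformer.policies.auto_policy import get_autopolicy

# A tiny Llama instance is enough: get_autopolicy only does a name-based lookup.
model = LlamaForCausalLM(
    LlamaConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4, num_hidden_layers=2)
)

# No "inference_only" key -> .get() returns None -> training table (_POLICY_LIST).
train_policy = get_autopolicy(model, ShardConfig(enable_tensor_parallelism=False))

# Explicit {"inference_only": True} -> inference table (_INFER_POLICY_LIST).
infer_policy = get_autopolicy(
    model,
    ShardConfig(enable_tensor_parallelism=False, extra_kwargs={"inference_only": True}),
)
```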

colossalai/shardformer/shard/shard_config.py (8 changed lines)

@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
 
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -33,11 +33,9 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     enable_all_optimization: bool = False
-    inference_only: bool = False
-    inference_gptq: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    quant: str = None
+    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
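
Note the `field(default_factory=dict)`: dataclasses reject a bare mutable default like `{}` precisely because it would be shared across instances, while the factory gives every `ShardConfig` its own dict. A minimal standalone illustration (generic dataclass, not the real `ShardConfig`):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class DemoConfig:
    # extra: Dict[str, Any] = {}   # rejected: "mutable default ... is not allowed"
    extra: Dict[str, Any] = field(default_factory=dict)  # a fresh dict per instance


a, b = DemoConfig(), DemoConfig()
a.extra["quant"] = "gptq"
assert b.extra == {}  # no state shared between instances
```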

examples/inference/bench_bloom.py (4 changed lines)

@@ -28,7 +28,9 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     # prepare data for generation

examples/inference/bench_chatglm2.py (4 changed lines)

@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

examples/inference/bench_llama.py (4 changed lines)

@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)

examples/inference/gptq_bloom.py (7 changed lines)

@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

examples/inference/gptq_llama.py (3 changed lines)

@@ -33,7 +33,8 @@ def run_llama_test(args):
     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
