You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/inference/core/base_engine.py

91 lines
2.9 KiB

from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
class BaseEngine(ABC):
    """
    Abstract interface for inference engines.

    Concrete engines must implement every ``@abstractmethod`` below. The
    ``_shardformer`` helper reads ``self.inference_config.tp_size``, so a
    subclass has to set ``self.inference_config`` before calling it.
    """

    @abstractmethod
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        """
        Construct the engine (to be implemented by subclasses)
        """

    @abstractmethod
    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        """
        Init Model for Engine
        """

    @abstractmethod
    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        """
        Generate output for coming requests
        """

    @abstractmethod
    def add_request(self, prompts, request_ids=None, **kwargs):
        """
        Add new request to Engine
        """

    @abstractmethod
    def step(self):
        """
        Perform one new step forward
        """

    @abstractmethod
    def _verify_args(self):
        """
        Verify the parameters and members of class
        """

    @torch.inference_mode()
    def capture_model(self):
        """
        Use cuda graph to capture model

        This base implementation is only a placeholder; subclasses that
        support graph capture are expected to override it.

        Raises:
            NotImplementedError: Always, when the subclass has not overridden this method.
        """
        # Raise (rather than return) the exception so that a missing override
        # fails loudly instead of handing the caller a truthy exception object.
        raise NotImplementedError("This method should be implemented by subclasses")

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
        **kwargs,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Tensor parallelism is enabled only when ``self.inference_config.tp_size > 1``;
        every other ShardConfig optimization switch set here is turned off.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy to shardformer model which is determined by the model type.
            model_shard_infer_config (ModelShardInferenceConfig, optional): Forwarded to ShardConfig through
                ``extra_kwargs["model_shard_infer_config"]``. Defaults to None.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
            **kwargs: Additional entries merged into the ShardConfig ``extra_kwargs`` dict.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """
        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        # optimize() returns a pair; only the sharded model is needed here.
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model