mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.9 KiB
91 lines
2.9 KiB
from abc import ABC, abstractmethod
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from colossalai.cluster import ProcessGroupMesh
|
|
from colossalai.inference.config import ModelShardInferenceConfig
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
|
from colossalai.shardformer.policies.base_policy import Policy
|
|
|
|
|
|
class BaseEngine(ABC):
    """Abstract interface shared by inference engines.

    Concrete engines must implement every ``@abstractmethod`` below.
    ``capture_model`` is an optional hook, and ``_shardformer`` is a shared
    helper that reads ``self.inference_config`` (which this base class does
    not set itself; the subclass is expected to assign it before calling).
    """

    @abstractmethod
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        pass

    @abstractmethod
    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        """
        Init Model for Engine
        """

    @abstractmethod
    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        """
        Generate output for coming requests
        """

    @abstractmethod
    def add_request(self, prompts, request_ids=None, **kwargs):
        """
        Add new request to Engine
        """

    @abstractmethod
    def step(self):
        """
        Perform one new step forward
        """

    @abstractmethod
    def _verify_args(self):
        """
        Verify the parameters and members of class
        """

    @torch.inference_mode()
    def capture_model(self):
        """
        Use cuda graph to capture model

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses
                that support graph capture are expected to override this method.
        """
        # Raise (rather than return) the exception so that a missing override
        # fails loudly instead of handing the caller an exception object.
        raise NotImplementedError("This method should be implemented by subclasses")

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
        **kwargs,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Tensor parallelism is enabled only when ``self.inference_config.tp_size``
        is greater than 1; every other ShardConfig optimization switch set here
        is turned off.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy to shardformer model which is determined by the model type.
            model_shard_infer_config (ModelShardInferenceConfig, optional): Forwarded to the ShardConfig
                through ``extra_kwargs["model_shard_infer_config"]``. Defaults to None.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
            **kwargs: Additional entries merged into the ShardConfig ``extra_kwargs`` dict. A
                ``model_shard_infer_config`` key cannot be supplied here because it is already a
                named parameter.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        # optimize() returns a pair; only the first element (the sharded model) is used.
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model
|