# ColossalAI/colossalai/inference/modeling/policy/pixart_alpha.py
#
# Source listing metadata: 35 lines, 984 B, Python.

from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import Policy
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
    """Inference policy that points ``DiffusionPipe.forward`` at ``pixart_alpha_forward``.

    The policy performs no other model surgery: ``preprocess`` and
    ``postprocess`` hand back ``self.model`` as-is and the config sanity check
    is a no-op. For RPC transport the policy is represented by its class name
    only, and is rebuilt on the receiving side with a no-argument constructor.
    """

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Build the policy mapping holding the ``forward`` replacement.

        Returns:
            The dict passed as ``policy`` to
            ``append_or_create_method_replacement``, registered under the
            ``DiffusionPipe`` target key.
        """
        forward_override = {"forward": pixart_alpha_forward}
        replacements = {}
        # The helper's return value is ignored, as in the original; this relies
        # on it filling ``replacements`` in place (presumably -- confirm against
        # the base Policy implementation).
        self.append_or_create_method_replacement(
            description=forward_override,
            policy=replacements,
            target_key=DiffusionPipe,
        )
        return replacements

    def preprocess(self) -> nn.Module:
        """Return ``self.model`` without modifying it."""
        model = self.model
        return model

    def postprocess(self):
        """Return ``self.model`` without modifying it."""
        model = self.model
        return model

    def config_sanity_check(self):
        """Intentionally a no-op: this policy validates nothing."""
        return None

    def to_rpc_param(self) -> str:
        """Return the name of this class as the RPC payload."""
        # ``__class__`` is the implicit cell for the class that lexically
        # defines this method, so the result does not vary with type(self).
        defining_cls = __class__
        return defining_cls.__name__

    @staticmethod
    def from_rpc_param() -> "PixArtAlphaInferPolicy":
        """Create a fresh ``PixArtAlphaInferPolicy``; no payload is needed."""
        restored = PixArtAlphaInferPolicy()
        return restored