mirror of https://github.com/hpcaitech/ColossalAI
35 lines
984 B
Python
35 lines
984 B
Python
![]() |
from torch import nn
|
||
|
|
||
|
from colossalai.inference.config import RPC_PARAM
|
||
|
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||
|
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
|
||
|
from colossalai.shardformer.policies.base_policy import Policy
|
||
|
|
||
|
|
||
|
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
    """Inference policy that routes ``DiffusionPipe.forward`` to PixArt-Alpha.

    The policy registers a single method replacement (``"forward"`` ->
    ``pixart_alpha_forward``) keyed on ``DiffusionPipe`` and otherwise leaves
    the model untouched. It also implements the ``RPC_PARAM`` helpers so the
    policy can be identified by name and rebuilt from that identification.
    """

    def __init__(self) -> None:
        """Initialise the policy by delegating to the parent initialisers."""
        super().__init__()

    def module_policy(self):
        """Build and return the policy mapping for this model.

        Returns:
            A dict populated by ``append_or_create_method_replacement`` with
            the ``forward`` replacement registered under ``DiffusionPipe``.
        """
        method_overrides = {"forward": pixart_alpha_forward}
        policy = {}
        # The helper fills ``policy`` in place; its return value is not used.
        self.append_or_create_method_replacement(
            description=method_overrides,
            policy=policy,
            target_key=DiffusionPipe,
        )
        return policy

    def preprocess(self) -> nn.Module:
        """Return ``self.model`` unchanged; no pre-processing is applied.

        NOTE(review): ``self.model`` is not assigned in this class --
        presumably it is provided by the ``Policy`` base; confirm.
        """
        return self.model

    def postprocess(self):
        """Return ``self.model`` unchanged; no post-processing is applied."""
        return self.model

    def config_sanity_check(self):
        """Perform no configuration checks for this policy."""
        return None

    def to_rpc_param(self) -> str:
        """Return the name of this class as the RPC parameter.

        ``__class__`` here is the implicit class cell of the defining class,
        so the result is this class's name even when called on a subclass
        instance.
        """
        return __class__.__name__

    @staticmethod
    def from_rpc_param() -> "PixArtAlphaInferPolicy":
        """Create a fresh ``PixArtAlphaInferPolicy`` instance."""
        return PixArtAlphaInferPolicy()
|