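"""ShardFormer inference policy for PixArt-Alpha (Distrifusion-aware)."""
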
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from torch import nn

from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.distrifusion import (
    DistrifusionConv2D,
    DistrifusionPatchEmbed,
    DistriSelfAttention,
    PixArtAlphaTransformer2DModel_forward,
)
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
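    """ShardFormer policy for PixArt-Alpha inference.

    Registers the module, attribute, and method replacements that route
    generation through `pixart_alpha_forward` and, when patched parallelism
    is enabled, through the Distrifusion layers.
    """
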
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = {}

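        # Distrifusion-style patched parallelism shards the latent into
        # patches across devices; these replacements only apply when more
        # than one device participates.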
        if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
            policy[PixArtTransformer2DModel] = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="pos_embed.proj",
                        target_module=DistrifusionConv2D,
                        kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
                    ),
                    SubModuleReplacementDescription(
                        suffix="pos_embed",
                        target_module=DistrifusionPatchEmbed,
                        kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
                    ),
                ],
                attribute_replacement={
                    "patched_parallel_size": self.shard_config.extra_kwargs[
                        "model_shard_infer_config"
                    ].patched_parallelism_size
                },
                method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
            )

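            # Swap every transformer block's self-attention (`attn1`) for the
            # distributed Distrifusion variant.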
            policy[BasicTransformerBlock] = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attn1",
                        target_module=DistriSelfAttention,
                        kwargs={
                            "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
                        },
                    )
                ]
            )

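        # The DiffusionPipe forward is rebound unconditionally, even without
        # patched parallelism.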
        self.append_or_create_method_replacement(
            description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
        )

        return policy

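    # Inference requires no pre-/post-sharding transforms; both hooks return
    # the model unchanged.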
    def preprocess(self) -> nn.Module:
        return self.model

    def postprocess(self):
        return self.model

    def config_sanity_check(self):
        pass

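    # The policy carries no state, so it crosses the RPC boundary as just its
    # class name and is re-instantiated on the remote side.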
    def to_rpc_param(self) -> str:
        return __class__.__name__

    @staticmethod
    def from_rpc_param() -> "PixArtAlphaInferPolicy":
        return PixArtAlphaInferPolicy()