# ColossalAI/colossalai/inference/modeling/policy/stablediffusion3.py
# (79 lines, 2.9 KiB, Python)

from diffusers.models.attention import JointTransformerBlock
from diffusers.models.transformers import SD3Transformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.distrifusion import (
DistrifusionConv2D,
DistrifusionFusedAttention,
DistrifusionPatchEmbed,
SD3Transformer2DModel_forward,
)
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
    """Inference policy describing module replacements for Stable Diffusion 3.

    The policy maps model classes to ``ModulePolicyDescription`` objects.
    When the configured ``patched_parallelism_size`` is greater than one,
    the transformer and its joint blocks get Distrifusion replacements;
    the ``DiffusionPipe`` forward is always replaced with ``sd3_forward``.
    """

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Build and return the ``{module class: policy description}`` mapping.

        Reads ``model_shard_infer_config`` from ``self.shard_config.extra_kwargs``;
        a missing key raises ``KeyError``.
        """
        infer_config = self.shard_config.extra_kwargs["model_shard_infer_config"]

        def _swap(suffix, target_module):
            # Each description receives its own kwargs dict carrying the
            # shared inference config object.
            return SubModuleReplacementDescription(
                suffix=suffix,
                target_module=target_module,
                kwargs={"model_shard_infer_config": infer_config},
            )

        policy = {}

        patch_size = infer_config.patched_parallelism_size
        if patch_size > 1:
            # Transformer: replace the patch embedding (and its projection),
            # expose the patch size as an attribute, and override forward.
            transformer_swaps = [
                _swap("pos_embed.proj", DistrifusionConv2D),
                _swap("pos_embed", DistrifusionPatchEmbed),
            ]
            policy[SD3Transformer2DModel] = ModulePolicyDescription(
                sub_module_replacement=transformer_swaps,
                attribute_replacement={"patched_parallel_size": patch_size},
                method_replacement={"forward": SD3Transformer2DModel_forward},
            )

            # Joint transformer blocks: replace the attention sub-module.
            block_swaps = [_swap("attn", DistrifusionFusedAttention)]
            policy[JointTransformerBlock] = ModulePolicyDescription(
                sub_module_replacement=block_swaps,
            )

        # The pipeline forward is replaced regardless of the patch size.
        self.append_or_create_method_replacement(
            description={"forward": sd3_forward},
            policy=policy,
            target_key=DiffusionPipe,
        )
        return policy

    def preprocess(self) -> nn.Module:
        """Return the model as-is; no preprocessing is applied."""
        return self.model

    def postprocess(self):
        """Return the model as-is; no postprocessing is applied."""
        return self.model

    def config_sanity_check(self):
        """Perform no configuration checks."""

    def to_rpc_param(self) -> str:
        """Return the name of this policy class as its RPC parameter."""
        # ``__class__`` is the defining class, not the runtime type of ``self``.
        return __class__.__name__

    @staticmethod
    def from_rpc_param() -> "StableDiffusion3InferPolicy":
        """Create a new policy instance on the receiving side of an RPC."""
        return StableDiffusion3InferPolicy()