# 2024-07-08 08:02:07 +00:00  (stray VCS timestamp; commented out — was a SyntaxError)
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any , Callable , Dict , List , Optional , Union
import torch
from diffusers . pipelines . stable_diffusion_3 . pipeline_stable_diffusion_3 import retrieve_timesteps
# 2024-07-30 02:43:26 +00:00  (stray VCS timestamp; commented out — was a SyntaxError)
from . . layers . diffusion import DiffusionPipe
# 2024-07-08 08:02:07 +00:00  (stray VCS timestamp; commented out — was a SyntaxError)
# TODO(@lry89757) temporarily image, please support more return output
@torch.no_grad()
def sd3_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    # Fixed annotation: the callback is invoked below as
    # callback_on_step_end(self, i, t, callback_kwargs) and its returned dict
    # is consumed via .pop(...), so Callable[[int, int, Dict], None] was wrong.
    callback_on_step_end: Optional[
        Callable[[DiffusionPipe, int, torch.Tensor, Dict[str, Any]], Dict[str, Any]]
    ] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
):
    """Run the Stable Diffusion 3 text-to-image denoising forward pass.

    Adapted from diffusers' ``StableDiffusion3Pipeline.__call__``. Encodes the
    prompts, prepares latents and timesteps, runs the transformer denoising
    loop with optional classifier-free guidance, then decodes the latents
    through the VAE.

    Args:
        self: The pipeline object (provides ``encode_prompt``,
            ``prepare_latents``, ``scheduler``, ``transformer``, ``vae``, ...).
        prompt / prompt_2 / prompt_3: Text prompts for the three text encoders;
            required unless ``prompt_embeds`` is supplied.
        height / width: Output image size; default to
            ``default_sample_size * vae_scale_factor``.
        num_inference_steps: Number of denoising steps.
        timesteps: Optional custom timestep schedule passed to
            ``retrieve_timesteps``.
        guidance_scale: Classifier-free guidance weight.
        negative_prompt / negative_prompt_2 / negative_prompt_3: Negative
            prompts for classifier-free guidance.
        num_images_per_prompt: Images generated per prompt.
        generator: RNG(s) for latent initialization.
        latents: Optional pre-made initial latents.
        prompt_embeds / negative_prompt_embeds / pooled_prompt_embeds /
            negative_pooled_prompt_embeds: Optional precomputed embeddings
            that bypass text encoding.
        output_type: "latent" to skip VAE decode, otherwise forwarded to the
            image processor (e.g. "pil").
        return_dict: Currently unused — see the module-level TODO; only the
            image is returned.
        joint_attention_kwargs: Extra kwargs forwarded to the transformer.
        clip_skip: Number of final CLIP layers to skip during encoding.
        callback_on_step_end: Per-step callback; may override tensors by
            returning them in its result dict.
        callback_on_step_end_tensor_inputs: Names of locals handed to the
            callback. Defaults to ``["latents"]`` (built per call — a mutable
            list default here would be shared across calls).

    Returns:
        The raw latents when ``output_type == "latent"``, otherwise the
        post-processed image(s) from the pipeline's image processor.
    """
    # Avoid the shared-mutable-default pitfall: materialize the list per call.
    if callback_on_step_end_tensor_inputs is None:
        callback_on_step_end_tensor_inputs = ["latents"]

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # No prompt text: batch size comes from the precomputed embeddings.
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompts with the three text encoders
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        # Stack [uncond, cond] so one transformer call covers both branches.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                # NOTE: locals()[k] requires the requested names to match local
                # variable names in this function exactly (upstream idiom).
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # Let the callback override the tensors it was handed.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if output_type == "latent":
        image = latents
    else:
        # Undo the VAE latent normalization (SD3 uses both scale and shift).
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    return image