@ -1,10 +1,11 @@
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any , Optional
import numpy as np
import torch
import torch . nn as nn
import numpy as np
from einops import rearrange
from typing import Optional , Any
try :
from lightning . pytorch . utilities import rank_zero_info
@ -38,14 +39,14 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = emb . to ( device = timesteps . device )
emb = timesteps . float ( ) [ : , None ] * emb [ None , : ]
emb = torch . cat ( [ torch . sin ( emb ) , torch . cos ( emb ) ] , dim = 1 )
if embedding_dim % 2 == 1 : # zero pad
emb = torch . nn . functional . pad ( emb , ( 0 , 1 , 0 , 0 ) )
if embedding_dim % 2 == 1 : # zero pad
emb = torch . nn . functional . pad ( emb , ( 0 , 1 , 0 , 0 ) )
return emb
def nonlinearity ( x ) :
# swish
return x * torch . sigmoid ( x )
return x * torch . sigmoid ( x )
def Normalize ( in_channels , num_groups = 32 ) :
@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32):
class Upsample ( nn . Module ) :
def __init__ ( self , in_channels , with_conv ) :
super ( ) . __init__ ( )
self . with_conv = with_conv
if self . with_conv :
self . conv = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 3 , stride = 1 , padding = 1 )
def forward ( self , x ) :
x = torch . nn . functional . interpolate ( x , scale_factor = 2.0 , mode = " nearest " )
@ -71,20 +69,17 @@ class Upsample(nn.Module):
class Downsample ( nn . Module ) :
def __init__ ( self , in_channels , with_conv ) :
super ( ) . __init__ ( )
self . with_conv = with_conv
if self . with_conv :
# no asymmetric padding in torch conv, must do it ourselves
self . conv = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 3 ,
stride = 2 ,
padding = 0 )
self . conv = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 3 , stride = 2 , padding = 0 )
def forward ( self , x ) :
if self . with_conv :
pad = ( 0 , 1 , 0 , 1 )
pad = ( 0 , 1 , 0 , 1 )
x = torch . nn . functional . pad ( x , pad , mode = " constant " , value = 0 )
x = self . conv ( x )
else :
@ -93,8 +88,8 @@ class Downsample(nn.Module):
class ResnetBlock ( nn . Module ) :
def __init__ ( self , * , in_channels , out_channels = None , conv_shortcut = False ,
dropout , temb_channels = 512 ) :
def __init__ ( self , * , in_channels , out_channels = None , conv_shortcut = False , dropout , temb_channels = 512 ) :
super ( ) . __init__ ( )
self . in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -102,34 +97,17 @@ class ResnetBlock(nn.Module):
self . use_conv_shortcut = conv_shortcut
self . norm1 = Normalize ( in_channels )
self . conv1 = torch . nn . Conv2d ( in_channels ,
out_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv1 = torch . nn . Conv2d ( in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
if temb_channels > 0 :
self . temb_proj = torch . nn . Linear ( temb_channels ,
out_channels )
self . temb_proj = torch . nn . Linear ( temb_channels , out_channels )
self . norm2 = Normalize ( out_channels )
self . dropout = torch . nn . Dropout ( dropout )
self . conv2 = torch . nn . Conv2d ( out_channels ,
out_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv2 = torch . nn . Conv2d ( out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
if self . in_channels != self . out_channels :
if self . use_conv_shortcut :
self . conv_shortcut = torch . nn . Conv2d ( in_channels ,
out_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_shortcut = torch . nn . Conv2d ( in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
else :
self . nin_shortcut = torch . nn . Conv2d ( in_channels ,
out_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . nin_shortcut = torch . nn . Conv2d ( in_channels , out_channels , kernel_size = 1 , stride = 1 , padding = 0 )
def forward ( self , x , temb ) :
h = x
@ -138,7 +116,7 @@ class ResnetBlock(nn.Module):
h = self . conv1 ( h )
if temb is not None :
h = h + self . temb_proj ( nonlinearity ( temb ) ) [ : , : , None , None ]
h = h + self . temb_proj ( nonlinearity ( temb ) ) [ : , : , None , None ]
h = self . norm2 ( h )
h = nonlinearity ( h )
@ -151,35 +129,20 @@ class ResnetBlock(nn.Module):
else :
x = self . nin_shortcut ( x )
return x + h
return x + h
class AttnBlock ( nn . Module ) :
def __init__ ( self , in_channels ) :
super ( ) . __init__ ( )
self . in_channels = in_channels
self . norm = Normalize ( in_channels )
self . q = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . k = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . v = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . proj_out = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . q = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . k = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . v = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . proj_out = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
def forward ( self , x ) :
h_ = x
@ -189,23 +152,24 @@ class AttnBlock(nn.Module):
v = self . v ( h_ )
# compute attention
b , c , h , w = q . shape
q = q . reshape ( b , c , h * w )
q = q . permute ( 0 , 2 , 1 ) # b,hw,c
k = k . reshape ( b , c , h * w ) # b,c,hw
w_ = torch . bmm ( q , k ) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
b , c , h , w = q . shape
q = q . reshape ( b , c , h * w )
q = q . permute ( 0 , 2 , 1 ) # b,hw,c
k = k . reshape ( b , c , h * w ) # b,c,hw
w_ = torch . bmm ( q , k ) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * ( int ( c ) * * ( - 0.5 ) )
w_ = torch . nn . functional . softmax ( w_ , dim = 2 )
# attend to values
v = v . reshape ( b , c , h * w )
w_ = w_ . permute ( 0 , 2 , 1 ) # b,hw,hw (first hw of k, second of q)
h_ = torch . bmm ( v , w_ ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_ . reshape ( b , c , h , w )
v = v . reshape ( b , c , h * w )
w_ = w_ . permute ( 0 , 2 , 1 ) # b,hw,hw (first hw of k, second of q)
h_ = torch . bmm ( v , w_ ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_ . reshape ( b , c , h , w )
h_ = self . proj_out ( h_ )
return x + h_
return x + h_
class MemoryEfficientAttnBlock ( nn . Module ) :
"""
@ -213,32 +177,17 @@ class MemoryEfficientAttnBlock(nn.Module):
see https : / / github . com / MatthieuTPHR / diffusers / blob / d80b531ff8060ec1ea982b65a1b8df70f73aa67c / src / diffusers / models / attention . py #L223
Note : this is a single - head self - attention operation
"""
#
def __init__ ( self , in_channels ) :
super ( ) . __init__ ( )
self . in_channels = in_channels
self . norm = Normalize ( in_channels )
self . q = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . k = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . v = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . proj_out = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 1 ,
stride = 1 ,
padding = 0 )
self . q = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . k = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . v = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . proj_out = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 1 , stride = 1 , padding = 0 )
self . attention_op : Optional [ Any ] = None
def forward ( self , x ) :
@ -253,27 +202,20 @@ class MemoryEfficientAttnBlock(nn.Module):
q , k , v = map ( lambda x : rearrange ( x , ' b c h w -> b (h w) c ' ) , ( q , k , v ) )
q , k , v = map (
lambda t : t . unsqueeze ( 3 )
. reshape ( B , t . shape [ 1 ] , 1 , C )
. permute ( 0 , 2 , 1 , 3 )
. reshape ( B * 1 , t . shape [ 1 ] , C )
. contiguous ( ) ,
lambda t : t . unsqueeze ( 3 ) . reshape ( B , t . shape [ 1 ] , 1 , C ) . permute ( 0 , 2 , 1 , 3 ) . reshape ( B * 1 , t . shape [ 1 ] , C ) .
contiguous ( ) ,
( q , k , v ) ,
)
out = xformers . ops . memory_efficient_attention ( q , k , v , attn_bias = None , op = self . attention_op )
out = (
out . unsqueeze ( 0 )
. reshape ( B , 1 , out . shape [ 1 ] , C )
. permute ( 0 , 2 , 1 , 3 )
. reshape ( B , out . shape [ 1 ] , C )
)
out = ( out . unsqueeze ( 0 ) . reshape ( B , 1 , out . shape [ 1 ] , C ) . permute ( 0 , 2 , 1 , 3 ) . reshape ( B , out . shape [ 1 ] , C ) )
out = rearrange ( out , ' b (h w) c -> b c h w ' , b = B , h = H , w = W , c = C )
out = self . proj_out ( out )
return x + out
return x + out
class MemoryEfficientCrossAttentionWrapper ( MemoryEfficientCrossAttention ) :
def forward ( self , x , context = None , mask = None ) :
b , c , h , w = x . shape
x = rearrange ( x , ' b c h w -> b (h w) c ' )
@ -283,10 +225,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def make_attn ( in_channels , attn_type = " vanilla " , attn_kwargs = None ) :
assert attn_type in [ " vanilla " , " vanilla-xformers " , " memory-efficient-cross-attn " , " linear " , " none " ] , f ' attn_type { attn_type } unknown '
assert attn_type in [ " vanilla " , " vanilla-xformers " , " memory-efficient-cross-attn " , " linear " ,
" none " ] , f ' attn_type { attn_type } unknown '
if XFORMERS_IS_AVAILBLE and attn_type == " vanilla " :
attn_type = " vanilla-xformers "
rank_zero_info ( f " making attention of type ' { attn_type } ' with { in_channels } in_channels " )
if attn_type == " vanilla " :
assert attn_kwargs is None
return AttnBlock ( in_channels )
@ -303,13 +245,26 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
class Model ( nn . Module ) :
def __init__ ( self , * , ch , out_ch , ch_mult = ( 1 , 2 , 4 , 8 ) , num_res_blocks ,
attn_resolutions , dropout = 0.0 , resamp_with_conv = True , in_channels ,
resolution , use_timestep = True , use_linear_attn = False , attn_type = " vanilla " ) :
def __init__ ( self ,
* ,
ch ,
out_ch ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
num_res_blocks ,
attn_resolutions ,
dropout = 0.0 ,
resamp_with_conv = True ,
in_channels ,
resolution ,
use_timestep = True ,
use_linear_attn = False ,
attn_type = " vanilla " ) :
super ( ) . __init__ ( )
if use_linear_attn : attn_type = " linear "
if use_linear_attn :
attn_type = " linear "
self . ch = ch
self . temb_ch = self . ch * 4
self . temb_ch = self . ch * 4
self . num_resolutions = len ( ch_mult )
self . num_res_blocks = num_res_blocks
self . resolution = resolution
@ -320,39 +275,34 @@ class Model(nn.Module):
# timestep embedding
self . temb = nn . Module ( )
self . temb . dense = nn . ModuleList ( [
torch . nn . Linear ( self . ch ,
self . temb_ch ) ,
torch . nn . Linear ( self . temb_ch ,
self . temb_ch ) ,
torch . nn . Linear ( self . ch , self . temb_ch ) ,
torch . nn . Linear ( self . temb_ch , self . temb_ch ) ,
] )
# downsampling
self . conv_in = torch . nn . Conv2d ( in_channels ,
self . ch ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_in = torch . nn . Conv2d ( in_channels , self . ch , kernel_size = 3 , stride = 1 , padding = 1 )
curr_res = resolution
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
self . down = nn . ModuleList ( )
for i_level in range ( self . num_resolutions ) :
block = nn . ModuleList ( )
attn = nn . ModuleList ( )
block_in = ch * in_ch_mult [ i_level ]
block_out = ch * ch_mult [ i_level ]
block_in = ch * in_ch_mult [ i_level ]
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks ) :
block . append ( ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block . append (
ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block_in = block_out
if curr_res in attn_resolutions :
attn . append ( make_attn ( block_in , attn_type = attn_type ) )
down = nn . Module ( )
down . block = block
down . attn = attn
if i_level != self . num_resolutions - 1 :
if i_level != self . num_resolutions - 1 :
down . downsample = Downsample ( block_in , resamp_with_conv )
curr_res = curr_res / / 2
self . down . append ( down )
@ -374,15 +324,16 @@ class Model(nn.Module):
for i_level in reversed ( range ( self . num_resolutions ) ) :
block = nn . ModuleList ( )
attn = nn . ModuleList ( )
block_out = ch * ch_mult [ i_level ]
skip_in = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
block_out = ch * ch_mult [ i_level ]
skip_in = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
if i_block == self . num_res_blocks :
skip_in = ch * in_ch_mult [ i_level ]
block . append ( ResnetBlock ( in_channels = block_in + skip_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
skip_in = ch * in_ch_mult [ i_level ]
block . append (
ResnetBlock ( in_channels = block_in + skip_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block_in = block_out
if curr_res in attn_resolutions :
attn . append ( make_attn ( block_in , attn_type = attn_type ) )
@ -392,15 +343,11 @@ class Model(nn.Module):
if i_level != 0 :
up . upsample = Upsample ( block_in , resamp_with_conv )
curr_res = curr_res * 2
self . up . insert ( 0 , up ) # prepend to get consistent order
self . up . insert ( 0 , up ) # prepend to get consistent order
# end
self . norm_out = Normalize ( block_in )
self . conv_out = torch . nn . Conv2d ( block_in ,
out_ch ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_out = torch . nn . Conv2d ( block_in , out_ch , kernel_size = 3 , stride = 1 , padding = 1 )
def forward ( self , x , t = None , context = None ) :
#assert x.shape[2] == x.shape[3] == self.resolution
@ -425,7 +372,7 @@ class Model(nn.Module):
if len ( self . down [ i_level ] . attn ) > 0 :
h = self . down [ i_level ] . attn [ i_block ] ( h )
hs . append ( h )
if i_level != self . num_resolutions - 1 :
if i_level != self . num_resolutions - 1 :
hs . append ( self . down [ i_level ] . downsample ( hs [ - 1 ] ) )
# middle
@ -436,9 +383,8 @@ class Model(nn.Module):
# upsampling
for i_level in reversed ( range ( self . num_resolutions ) ) :
for i_block in range ( self . num_res_blocks + 1 ) :
h = self . up [ i_level ] . block [ i_block ] (
torch . cat ( [ h , hs . pop ( ) ] , dim = 1 ) , temb )
for i_block in range ( self . num_res_blocks + 1 ) :
h = self . up [ i_level ] . block [ i_block ] ( torch . cat ( [ h , hs . pop ( ) ] , dim = 1 ) , temb )
if len ( self . up [ i_level ] . attn ) > 0 :
h = self . up [ i_level ] . attn [ i_block ] ( h )
if i_level != 0 :
@ -455,12 +401,26 @@ class Model(nn.Module):
class Encoder ( nn . Module ) :
def __init__ ( self , * , ch , out_ch , ch_mult = ( 1 , 2 , 4 , 8 ) , num_res_blocks ,
attn_resolutions , dropout = 0.0 , resamp_with_conv = True , in_channels ,
resolution , z_channels , double_z = True , use_linear_attn = False , attn_type = " vanilla " ,
def __init__ ( self ,
* ,
ch ,
out_ch ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
num_res_blocks ,
attn_resolutions ,
dropout = 0.0 ,
resamp_with_conv = True ,
in_channels ,
resolution ,
z_channels ,
double_z = True ,
use_linear_attn = False ,
attn_type = " vanilla " ,
* * ignore_kwargs ) :
super ( ) . __init__ ( )
if use_linear_attn : attn_type = " linear "
if use_linear_attn :
attn_type = " linear "
self . ch = ch
self . temb_ch = 0
self . num_resolutions = len ( ch_mult )
@ -469,33 +429,30 @@ class Encoder(nn.Module):
self . in_channels = in_channels
# downsampling
self . conv_in = torch . nn . Conv2d ( in_channels ,
self . ch ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_in = torch . nn . Conv2d ( in_channels , self . ch , kernel_size = 3 , stride = 1 , padding = 1 )
curr_res = resolution
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
self . in_ch_mult = in_ch_mult
self . down = nn . ModuleList ( )
for i_level in range ( self . num_resolutions ) :
block = nn . ModuleList ( )
attn = nn . ModuleList ( )
block_in = ch * in_ch_mult [ i_level ]
block_out = ch * ch_mult [ i_level ]
block_in = ch * in_ch_mult [ i_level ]
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks ) :
block . append ( ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block . append (
ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block_in = block_out
if curr_res in attn_resolutions :
attn . append ( make_attn ( block_in , attn_type = attn_type ) )
down = nn . Module ( )
down . block = block
down . attn = attn
if i_level != self . num_resolutions - 1 :
if i_level != self . num_resolutions - 1 :
down . downsample = Downsample ( block_in , resamp_with_conv )
curr_res = curr_res / / 2
self . down . append ( down )
@ -515,7 +472,7 @@ class Encoder(nn.Module):
# end
self . norm_out = Normalize ( block_in )
self . conv_out = torch . nn . Conv2d ( block_in ,
2 * z_channels if double_z else z_channels ,
2 * z_channels if double_z else z_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
@ -532,7 +489,7 @@ class Encoder(nn.Module):
if len ( self . down [ i_level ] . attn ) > 0 :
h = self . down [ i_level ] . attn [ i_block ] ( h )
hs . append ( h )
if i_level != self . num_resolutions - 1 :
if i_level != self . num_resolutions - 1 :
hs . append ( self . down [ i_level ] . downsample ( hs [ - 1 ] ) )
# middle
@ -549,12 +506,27 @@ class Encoder(nn.Module):
class Decoder ( nn . Module ) :
def __init__ ( self , * , ch , out_ch , ch_mult = ( 1 , 2 , 4 , 8 ) , num_res_blocks ,
attn_resolutions , dropout = 0.0 , resamp_with_conv = True , in_channels ,
resolution , z_channels , give_pre_end = False , tanh_out = False , use_linear_attn = False ,
attn_type = " vanilla " , * * ignorekwargs ) :
def __init__ ( self ,
* ,
ch ,
out_ch ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
num_res_blocks ,
attn_resolutions ,
dropout = 0.0 ,
resamp_with_conv = True ,
in_channels ,
resolution ,
z_channels ,
give_pre_end = False ,
tanh_out = False ,
use_linear_attn = False ,
attn_type = " vanilla " ,
* * ignorekwargs ) :
super ( ) . __init__ ( )
if use_linear_attn : attn_type = " linear "
if use_linear_attn :
attn_type = " linear "
self . ch = ch
self . temb_ch = 0
self . num_resolutions = len ( ch_mult )
@ -565,19 +537,14 @@ class Decoder(nn.Module):
self . tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
block_in = ch * ch_mult [ self . num_resolutions - 1 ]
curr_res = resolution / / 2 * * ( self . num_resolutions - 1 )
self . z_shape = ( 1 , z_channels , curr_res , curr_res )
rank_zero_info ( " Working with z of shape {} = {} dimensions. " . format (
self . z_shape , np . prod ( self . z_shape ) ) )
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
block_in = ch * ch_mult [ self . num_resolutions - 1 ]
curr_res = resolution / / 2 * * ( self . num_resolutions - 1 )
self . z_shape = ( 1 , z_channels , curr_res , curr_res )
rank_zero_info ( " Working with z of shape {} = {} dimensions. " . format ( self . z_shape , np . prod ( self . z_shape ) ) )
# z to block_in
self . conv_in = torch . nn . Conv2d ( z_channels ,
block_in ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_in = torch . nn . Conv2d ( z_channels , block_in , kernel_size = 3 , stride = 1 , padding = 1 )
# middle
self . mid = nn . Module ( )
@ -596,12 +563,13 @@ class Decoder(nn.Module):
for i_level in reversed ( range ( self . num_resolutions ) ) :
block = nn . ModuleList ( )
attn = nn . ModuleList ( )
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
block . append ( ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
block . append (
ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block_in = block_out
if curr_res in attn_resolutions :
attn . append ( make_attn ( block_in , attn_type = attn_type ) )
@ -611,15 +579,11 @@ class Decoder(nn.Module):
if i_level != 0 :
up . upsample = Upsample ( block_in , resamp_with_conv )
curr_res = curr_res * 2
self . up . insert ( 0 , up ) # prepend to get consistent order
self . up . insert ( 0 , up ) # prepend to get consistent order
# end
self . norm_out = Normalize ( block_in )
self . conv_out = torch . nn . Conv2d ( block_in ,
out_ch ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_out = torch . nn . Conv2d ( block_in , out_ch , kernel_size = 3 , stride = 1 , padding = 1 )
def forward ( self , z ) :
#assert z.shape[1:] == self.z_shape[1:]
@ -638,7 +602,7 @@ class Decoder(nn.Module):
# upsampling
for i_level in reversed ( range ( self . num_resolutions ) ) :
for i_block in range ( self . num_res_blocks + 1 ) :
for i_block in range ( self . num_res_blocks + 1 ) :
h = self . up [ i_level ] . block [ i_block ] ( h , temb )
if len ( self . up [ i_level ] . attn ) > 0 :
h = self . up [ i_level ] . attn [ i_block ] ( h )
@ -658,31 +622,24 @@ class Decoder(nn.Module):
class SimpleDecoder ( nn . Module ) :
def __init__ ( self , in_channels , out_channels , * args , * * kwargs ) :
super ( ) . __init__ ( )
self . model = nn . ModuleList ( [ nn . Conv2d ( in_channels , in_channels , 1 ) ,
ResnetBlock ( in_channels = in_channels ,
out_channels = 2 * in_channels ,
temb_channels = 0 , dropout = 0.0 ) ,
ResnetBlock ( in_channels = 2 * in_channels ,
out_channels = 4 * in_channels ,
temb_channels = 0 , dropout = 0.0 ) ,
ResnetBlock ( in_channels = 4 * in_channels ,
out_channels = 2 * in_channels ,
temb_channels = 0 , dropout = 0.0 ) ,
nn . Conv2d ( 2 * in_channels , in_channels , 1 ) ,
Upsample ( in_channels , with_conv = True ) ] )
self . model = nn . ModuleList ( [
nn . Conv2d ( in_channels , in_channels , 1 ) ,
ResnetBlock ( in_channels = in_channels , out_channels = 2 * in_channels , temb_channels = 0 , dropout = 0.0 ) ,
ResnetBlock ( in_channels = 2 * in_channels , out_channels = 4 * in_channels , temb_channels = 0 , dropout = 0.0 ) ,
ResnetBlock ( in_channels = 4 * in_channels , out_channels = 2 * in_channels , temb_channels = 0 , dropout = 0.0 ) ,
nn . Conv2d ( 2 * in_channels , in_channels , 1 ) ,
Upsample ( in_channels , with_conv = True )
] )
# end
self . norm_out = Normalize ( in_channels )
self . conv_out = torch . nn . Conv2d ( in_channels ,
out_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_out = torch . nn . Conv2d ( in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
def forward ( self , x ) :
for i , layer in enumerate ( self . model ) :
if i in [ 1 , 2 , 3 ] :
if i in [ 1 , 2 , 3 ] :
x = layer ( x , None )
else :
x = layer ( x )
@ -694,25 +651,26 @@ class SimpleDecoder(nn.Module):
class UpsampleDecoder ( nn . Module ) :
def __init__ ( self , in_channels , out_channels , ch , num_res_blocks , resolution ,
ch_mult = ( 2 , 2 ) , dropout = 0.0 ) :
def __init__ ( self , in_channels , out_channels , ch , num_res_blocks , resolution , ch_mult = ( 2 , 2 ) , dropout = 0.0 ) :
super ( ) . __init__ ( )
# upsampling
self . temb_ch = 0
self . num_resolutions = len ( ch_mult )
self . num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution / / 2 * * ( self . num_resolutions - 1 )
curr_res = resolution / / 2 * * ( self . num_resolutions - 1 )
self . res_blocks = nn . ModuleList ( )
self . upsample_blocks = nn . ModuleList ( )
for i_level in range ( self . num_resolutions ) :
res_block = [ ]
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
res_block . append ( ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
res_block . append (
ResnetBlock ( in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ) )
block_in = block_out
self . res_blocks . append ( nn . ModuleList ( res_block ) )
if i_level != self . num_resolutions - 1 :
@ -721,11 +679,7 @@ class UpsampleDecoder(nn.Module):
# end
self . norm_out = Normalize ( block_in )
self . conv_out = torch . nn . Conv2d ( block_in ,
out_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . conv_out = torch . nn . Conv2d ( block_in , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
def forward ( self , x ) :
# upsampling
@ -742,35 +696,35 @@ class UpsampleDecoder(nn.Module):
class LatentRescaler ( nn . Module ) :
def __init__ ( self , factor , in_channels , mid_channels , out_channels , depth = 2 ) :
super ( ) . __init__ ( )
# residual block, interpolate, residual block
self . factor = factor
self . conv_in = nn . Conv2d ( in_channels ,
mid_channels ,
kernel_size = 3 ,
stride = 1 ,
padding = 1 )
self . res_block1 = nn . ModuleList ( [ ResnetBlock ( in_channels = mid_channels ,
out_channels = mid_channels ,
temb_channels = 0 ,
dropout = 0.0 ) for _ in range ( depth ) ] )
self . conv_in = nn . Conv2d ( in_channels , mid_channels , kernel_size = 3 , stride = 1 , padding = 1 )
self . res_block1 = nn . ModuleList ( [
ResnetBlock ( in_channels = mid_channels , out_channels = mid_channels , temb_channels = 0 , dropout = 0.0 )
for _ in range ( depth )
] )
self . attn = AttnBlock ( mid_channels )
self . res_block2 = nn . ModuleList ( [ ResnetBlock ( in_channels = mid_channels ,
out_channels = mid_channels ,
temb_channels = 0 ,
dropout = 0.0 ) for _ in range ( depth ) ] )
self . conv_out = nn . Conv2d ( mid_channels ,
out_channels ,
kernel_size = 1 ,
)
self . res_block2 = nn . ModuleList ( [
ResnetBlock ( in_channels = mid_channels , out_channels = mid_channels , temb_channels = 0 , dropout = 0.0 )
for _ in range ( depth )
] )
self . conv_out = nn . Conv2d (
mid_channels ,
out_channels ,
kernel_size = 1 ,
)
def forward ( self , x ) :
x = self . conv_in ( x )
for block in self . res_block1 :
x = block ( x , None )
x = torch . nn . functional . interpolate ( x , size = ( int ( round ( x . shape [ 2 ] * self . factor ) ) , int ( round ( x . shape [ 3 ] * self . factor ) ) ) )
x = torch . nn . functional . interpolate ( x ,
size = ( int ( round ( x . shape [ 2 ] * self . factor ) ) ,
int ( round ( x . shape [ 3 ] * self . factor ) ) ) )
x = self . attn ( x )
for block in self . res_block2 :
x = block ( x , None )
@ -779,17 +733,37 @@ class LatentRescaler(nn.Module):
class MergedRescaleEncoder ( nn . Module ) :
def __init__ ( self , in_channels , ch , resolution , out_ch , num_res_blocks ,
attn_resolutions , dropout = 0.0 , resamp_with_conv = True ,
ch_mult = ( 1 , 2 , 4 , 8 ) , rescale_factor = 1.0 , rescale_module_depth = 1 ) :
def __init__ ( self ,
in_channels ,
ch ,
resolution ,
out_ch ,
num_res_blocks ,
attn_resolutions ,
dropout = 0.0 ,
resamp_with_conv = True ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
rescale_factor = 1.0 ,
rescale_module_depth = 1 ) :
super ( ) . __init__ ( )
intermediate_chn = ch * ch_mult [ - 1 ]
self . encoder = Encoder ( in_channels = in_channels , num_res_blocks = num_res_blocks , ch = ch , ch_mult = ch_mult ,
z_channels = intermediate_chn , double_z = False , resolution = resolution ,
attn_resolutions = attn_resolutions , dropout = dropout , resamp_with_conv = resamp_with_conv ,
self . encoder = Encoder ( in_channels = in_channels ,
num_res_blocks = num_res_blocks ,
ch = ch ,
ch_mult = ch_mult ,
z_channels = intermediate_chn ,
double_z = False ,
resolution = resolution ,
attn_resolutions = attn_resolutions ,
dropout = dropout ,
resamp_with_conv = resamp_with_conv ,
out_ch = None )
self . rescaler = LatentRescaler ( factor = rescale_factor , in_channels = intermediate_chn ,
mid_channels = intermediate_chn , out_channels = out_ch , depth = rescale_module_depth )
self . rescaler = LatentRescaler ( factor = rescale_factor ,
in_channels = intermediate_chn ,
mid_channels = intermediate_chn ,
out_channels = out_ch ,
depth = rescale_module_depth )
def forward ( self , x ) :
x = self . encoder ( x )
@ -798,15 +772,36 @@ class MergedRescaleEncoder(nn.Module):
class MergedRescaleDecoder ( nn . Module ) :
def __init__ ( self , z_channels , out_ch , resolution , num_res_blocks , attn_resolutions , ch , ch_mult = ( 1 , 2 , 4 , 8 ) ,
dropout = 0.0 , resamp_with_conv = True , rescale_factor = 1.0 , rescale_module_depth = 1 ) :
def __init__ ( self ,
z_channels ,
out_ch ,
resolution ,
num_res_blocks ,
attn_resolutions ,
ch ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
dropout = 0.0 ,
resamp_with_conv = True ,
rescale_factor = 1.0 ,
rescale_module_depth = 1 ) :
super ( ) . __init__ ( )
tmp_chn = z_channels * ch_mult [ - 1 ]
self . decoder = Decoder ( out_ch = out_ch , z_channels = tmp_chn , attn_resolutions = attn_resolutions , dropout = dropout ,
resamp_with_conv = resamp_with_conv , in_channels = None , num_res_blocks = num_res_blocks ,
ch_mult = ch_mult , resolution = resolution , ch = ch )
self . rescaler = LatentRescaler ( factor = rescale_factor , in_channels = z_channels , mid_channels = tmp_chn ,
out_channels = tmp_chn , depth = rescale_module_depth )
tmp_chn = z_channels * ch_mult [ - 1 ]
self . decoder = Decoder ( out_ch = out_ch ,
z_channels = tmp_chn ,
attn_resolutions = attn_resolutions ,
dropout = dropout ,
resamp_with_conv = resamp_with_conv ,
in_channels = None ,
num_res_blocks = num_res_blocks ,
ch_mult = ch_mult ,
resolution = resolution ,
ch = ch )
self . rescaler = LatentRescaler ( factor = rescale_factor ,
in_channels = z_channels ,
mid_channels = tmp_chn ,
out_channels = tmp_chn ,
depth = rescale_module_depth )
def forward ( self , x ) :
x = self . rescaler ( x )
@ -815,16 +810,26 @@ class MergedRescaleDecoder(nn.Module):
class Upsampler ( nn . Module ) :
def __init__ ( self , in_size , out_size , in_channels , out_channels , ch_mult = 2 ) :
super ( ) . __init__ ( )
assert out_size > = in_size
num_blocks = int ( np . log2 ( out_size / / in_size ) ) + 1
factor_up = 1. + ( out_size % in_size )
rank_zero_info ( f " Building { self . __class__ . __name__ } with in_size: { in_size } --> out_size { out_size } and factor { factor_up } " )
self . rescaler = LatentRescaler ( factor = factor_up , in_channels = in_channels , mid_channels = 2 * in_channels ,
num_blocks = int ( np . log2 ( out_size / / in_size ) ) + 1
factor_up = 1. + ( out_size % in_size )
rank_zero_info (
f " Building { self . __class__ . __name__ } with in_size: { in_size } --> out_size { out_size } and factor { factor_up } "
)
self . rescaler = LatentRescaler ( factor = factor_up ,
in_channels = in_channels ,
mid_channels = 2 * in_channels ,
out_channels = in_channels )
self . decoder = Decoder ( out_ch = out_channels , resolution = out_size , z_channels = in_channels , num_res_blocks = 2 ,
attn_resolutions = [ ] , in_channels = None , ch = in_channels ,
self . decoder = Decoder ( out_ch = out_channels ,
resolution = out_size ,
z_channels = in_channels ,
num_res_blocks = 2 ,
attn_resolutions = [ ] ,
in_channels = None ,
ch = in_channels ,
ch_mult = [ ch_mult for _ in range ( num_blocks ) ] )
def forward ( self , x ) :
@ -834,23 +839,21 @@ class Upsampler(nn.Module):
class Resize ( nn . Module ) :
def __init__ ( self , in_channels = None , learned = False , mode = " bilinear " ) :
super ( ) . __init__ ( )
self . with_conv = learned
self . mode = mode
if self . with_conv :
rank_zero_info ( f " Note: { self . __class__ . __name } uses learned downsampling and will ignore the fixed { mode } mode " )
rank_zero_info (
f " Note: { self . __class__ . __name } uses learned downsampling and will ignore the fixed { mode } mode " )
raise NotImplementedError ( )
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self . conv = torch . nn . Conv2d ( in_channels ,
in_channels ,
kernel_size = 4 ,
stride = 2 ,
padding = 1 )
self . conv = torch . nn . Conv2d ( in_channels , in_channels , kernel_size = 4 , stride = 2 , padding = 1 )
def forward ( self , x , scale_factor = 1.0 ) :
if scale_factor == 1.0 :
if scale_factor == 1.0 :
return x
else :
x = torch . nn . functional . interpolate ( x , mode = self . mode , align_corners = False , scale_factor = scale_factor )