@@ -1,24 +1,27 @@
# pylama:ignore=W0613,W0612
# pytorch_diffusion + derived encoder decoder
import gc
import logging
import math
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import nn
from imaginairy.modules.attention import LinearAttention
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.modules.attention import MemoryEfficientCrossAttention
logger = logging.getLogger(__name__)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except ImportError:
XFORMERS_IS_AVAILBLE = False
# print("No module 'xformers'. Proceeding without it.")
def get_timestep_embedding(timesteps, embedding_dim):
"""
Matches the implementation in Denoising Diffusion Probabilistic Models:
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
@@ -39,11 +42,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
@@ -126,30 +125,18 @@ class ResnetBlock(nn.Module):
)
def forward(self, x, temb):
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h5 = self.norm2(h4)
del h4
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
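# the projected timestep embedding is added as a per-channel bias, broadcast over H and W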
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@@ -157,14 +144,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h8
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
return x + h
class AttnBlock(nn.Module):
@@ -187,8 +167,6 @@ class AttnBlock(nn.Module):
)
def forward(self, x):
if get_device() == "cuda":
return self.forward_cuda(x)
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
@@ -214,83 +192,276 @@ class AttnBlock(nn.Module):
return x + h_
def forward_cuda(self, x):
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h * w)
del q1
q = q2.permute(0, 2, 1)  # b,hw,c
del q2
k = k1.reshape(b, c, h * w)  # b,c,hw
del k1
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
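# flatten the spatial grid so each pixel becomes a token: (B, C, H, W) -> (B, H*W, C)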
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
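# memory_efficient_attention computes softmax(q @ k^T / sqrt(C)) @ v without materializing the full (hw x hw) score matrix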
h_ = torch.zeros_like(k, device=q.device)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x + out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, "b c h w -> b (h w) c")
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
return x + out
def make_attn ( in_channels , attn_type = " vanilla " , attn_kwargs = None ) :
assert attn_type in [
" vanilla " ,
" vanilla-xformers " ,
" memory-efficient-cross-attn " ,
" linear " ,
" none " ,
] , f " attn_type { attn_type } unknown "
if XFORMERS_IS_AVAILBLE and attn_type == " vanilla " :
attn_type = " vanilla-xformers "
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == " vanilla " :
assert attn_kwargs is None
return AttnBlock ( in_channels )
elif attn_type == " vanilla-xformers " :
# print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock ( in_channels )
elif type == " memory-efficient-cross-attn " :
attn_kwargs [ " query_dim " ] = in_channels
return MemoryEfficientCrossAttentionWrapper ( * * attn_kwargs )
elif attn_type == " none " :
return nn . Identity ( in_channels )
else :
raise NotImplementedError ( )
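# AttnBlock and MemoryEfficientAttnBlock define the same parameters (norm, q, k, v, proj_out), so existing checkpoints should load into either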
stats = torch.cuda.memory_stats(q.device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
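# estimate: the (b, hw, hw) attention matrix (with ~2.5x headroom) must fit in free VRAM, otherwise query rows are processed in slices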
steps = 1
class Model ( nn . Module ) :
def __init__ (
self ,
* ,
ch ,
out_ch ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
num_res_blocks ,
attn_resolutions ,
dropout = 0.0 ,
resamp_with_conv = True ,
in_channels ,
resolution ,
use_timestep = True ,
use_linear_attn = False ,
attn_type = " vanilla " ,
) :
super ( ) . __init__ ( )
if use_linear_attn :
attn_type = " linear "
self . ch = ch
self . temb_ch = self . ch * 4
self . num_resolutions = len ( ch_mult )
self . num_res_blocks = num_res_blocks
self . resolution = resolution
self . in_channels = in_channels
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
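# choose a power-of-two slice count so each slice of the q @ k^T product fits in the free memory estimated above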
self . use_timestep = use_timestep
if self . use_timestep :
# timestep embedding
self . temb = nn . Module ( )
self . temb . dense = nn . ModuleList (
[
torch . nn . Linear ( self . ch , self . temb_ch ) ,
torch . nn . Linear ( self . temb_ch , self . temb_ch ) ,
]
)
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
# downsampling
self . conv_in = torch . nn . Conv2d (
in_channels , self . ch , kernel_size = 3 , stride = 1 , padding = 1
)
w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c) ** (-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
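# w3 now holds the normalized attention weights for this slice of query rows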
curr_res = resolution
in_ch_mult = ( 1 , ) + tuple ( ch_mult )
self . down = nn . ModuleList ( )
for i_level in range ( self . num_resolutions ) :
block = nn . ModuleList ( )
attn = nn . ModuleList ( )
block_in = ch * in_ch_mult [ i_level ]
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks ) :
block . append (
ResnetBlock (
in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ,
)
)
block_in = block_out
if curr_res in attn_resolutions :
attn . append ( make_attn ( block_in , attn_type = attn_type ) )
down = nn . Module ( )
down . block = block
down . attn = attn
if i_level != self . num_resolutions - 1 :
down . downsample = Downsample ( block_in , resamp_with_conv )
curr_res = curr_res // 2
self . down . append ( down )
# attend to values
v1 = v.reshape(b, c, h * w)
w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
del w3
# middle
self . mid = nn . Module ( )
self . mid . block_1 = ResnetBlock (
in_channels = block_in ,
out_channels = block_in ,
temb_channels = self . temb_ch ,
dropout = dropout ,
)
self . mid . attn_1 = make_attn ( block_in , attn_type = attn_type )
self . mid . block_2 = ResnetBlock (
in_channels = block_in ,
out_channels = block_in ,
temb_channels = self . temb_ch ,
dropout = dropout ,
)
h_[:, :, i:end] = torch.bmm(
v1, w4
)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
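# the attended values for this slice of query positions have now been written into the output buffer h_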
# upsampling
self . up = nn . ModuleList ( )
for i_level in reversed ( range ( self . num_resolutions ) ) :
block = nn . ModuleList ( )
attn = nn . ModuleList ( )
block_out = ch * ch_mult [ i_level ]
skip_in = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
if i_block == self . num_res_blocks :
skip_in = ch * in_ch_mult [ i_level ]
block . append (
ResnetBlock (
in_channels = block_in + skip_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ,
)
)
block_in = block_out
if curr_res in attn_resolutions :
attn . append ( make_attn ( block_in , attn_type = attn_type ) )
up = nn . Module ( )
up . block = block
up . attn = attn
if i_level != 0 :
up . upsample = Upsample ( block_in , resamp_with_conv )
curr_res = curr_res * 2
self . up . insert ( 0 , up ) # prepend to get consistent order
h2 = h_ . reshape ( b , c , h , w )
del h_
# end
self . norm_out = Normalize ( block_in )
self . conv_out = torch . nn . Conv2d (
block_in , out_ch , kernel_size = 3 , stride = 1 , padding = 1
)
h3 = self . proj_out ( h2 )
del h2
def forward ( self , x , t = None , context = None ) :
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None :
# assume aligned context, cat along channel axis
x = torch . cat ( ( x , context ) , dim = 1 )
if self . use_timestep :
# timestep embedding
assert t is not None
temb = get_timestep_embedding ( t , self . ch )
temb = self . temb . dense [ 0 ] ( temb )
temb = nonlinearity ( temb )
temb = self . temb . dense [ 1 ] ( temb )
else :
temb = None
h3 += x
# downsampling
hs = [ self . conv_in ( x ) ]
for i_level in range ( self . num_resolutions ) :
for i_block in range ( self . num_res_blocks ) :
h = self . down [ i_level ] . block [ i_block ] ( hs [ - 1 ] , temb )
if len ( self . down [ i_level ] . attn ) > 0 :
h = self . down [ i_level ] . attn [ i_block ] ( h )
hs . append ( h )
if i_level != self . num_resolutions - 1 :
hs . append ( self . down [ i_level ] . downsample ( hs [ - 1 ] ) )
return h3
# middle
h = hs [ - 1 ]
h = self . mid . block_1 ( h , temb )
h = self . mid . attn_1 ( h )
h = self . mid . block_2 ( h , temb )
# upsampling
for i_level in reversed ( range ( self . num_resolutions ) ) :
for i_block in range ( self . num_res_blocks + 1 ) :
h = self . up [ i_level ] . block [ i_block ] (
torch . cat ( [ h , hs . pop ( ) ] , dim = 1 ) , temb
)
if len ( self . up [ i_level ] . attn ) > 0 :
h = self . up [ i_level ] . attn [ i_block ] ( h )
if i_level != 0 :
h = self . up [ i_level ] . upsample ( h )
def make_attn ( in_channels , attn_type = " vanilla " ) :
assert attn_type in [ " vanilla " , " linear " , " none " ] , f " attn_type { attn_type } unknown "
logger . debug (
f " making attention of type ' { attn_type } ' with { in_channels } in_channels "
)
if attn_type == " vanilla " :
return AttnBlock ( in_channels )
if attn_type == " none " :
return nn . Identity ( in_channels )
# end
h = self . norm_out ( h )
h = nonlinearity ( h )
h = self . conv_out ( h )
return h
return LinAttnBlock ( in_channels )
def get_last_layer ( self ) :
return self . conv_out . weight
class Encoder ( nn . Module ) :
@@ -447,9 +618,11 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logger.debug(
f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
)
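# z_shape describes the latent layout the decoder expects: (1, z_channels, curr_res, curr_res) at the bottleneck resolution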
# print(
# "Working with z of shape {} = {} dimensions.".format(
# self.z_shape, np.prod(self.z_shape)
# )
# )
# z to block_in
self . conv_in = torch . nn . Conv2d (
@@ -503,7 +676,6 @@ class Decoder(nn.Module):
self . conv_out = torch . nn . Conv2d (
block_in , out_ch , kernel_size = 3 , stride = 1 , padding = 1
)
self . last_z_shape = None
def forward ( self , z ) :
# assert z.shape[1:] == self.z_shape[1:]
@@ -513,53 +685,136 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h1 = self.conv_in(z)
h = self.conv_in(z)
# middle
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
# prepare for upsampling
gc.collect()
torch.cuda.empty_cache()
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
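# the mid section runs resnet -> self-attention -> resnet at the lowest spatial resolution; upsampling follows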
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
t = h
h = self.up[i_level].attn[i_block](t)
del t
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
t = h
h = self.up[i_level].upsample(t)
del t
h = self.up[i_level].upsample(h)
# end
if self . give_pre_end :
return h
h1 = self . norm_out ( h )
del h
h = self . norm_out ( h )
h = nonlinearity ( h )
h = self . conv_out ( h )
if self . tanh_out :
h = torch . tanh ( h )
return h
h2 = nonlinearity ( h1 )
del h1
h = self . conv_out ( h2 )
del h2
class SimpleDecoder ( nn . Module ) :
def __init__ ( self , in_channels , out_channels , * args , * * kwargs ) :
super ( ) . __init__ ( )
self . model = nn . ModuleList (
[
nn . Conv2d ( in_channels , in_channels , 1 ) ,
ResnetBlock (
in_channels = in_channels ,
out_channels = 2 * in_channels ,
temb_channels = 0 ,
dropout = 0.0 ,
) ,
ResnetBlock (
in_channels = 2 * in_channels ,
out_channels = 4 * in_channels ,
temb_channels = 0 ,
dropout = 0.0 ,
) ,
ResnetBlock (
in_channels = 4 * in_channels ,
out_channels = 2 * in_channels ,
temb_channels = 0 ,
dropout = 0.0 ,
) ,
nn . Conv2d ( 2 * in_channels , in_channels , 1 ) ,
Upsample ( in_channels , with_conv = True ) ,
]
)
# end
self . norm_out = Normalize ( in_channels )
self . conv_out = torch . nn . Conv2d (
in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1
)
if self . tanh_out :
t = h
h = torch . tanh ( t )
del t
def forward ( self , x ) :
for i , layer in enumerate ( self . model ) :
if i in [ 1 , 2 , 3 ] :
x = layer ( x , None )
else :
x = layer ( x )
h = self . norm_out ( x )
h = nonlinearity ( h )
x = self . conv_out ( h )
return x
class UpsampleDecoder ( nn . Module ) :
def __init__ (
self ,
in_channels ,
out_channels ,
ch ,
num_res_blocks ,
resolution ,
ch_mult = ( 2 , 2 ) ,
dropout = 0.0 ,
) :
super ( ) . __init__ ( )
# upsampling
self . temb_ch = 0
self . num_resolutions = len ( ch_mult )
self . num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self . res_blocks = nn . ModuleList ( )
self . upsample_blocks = nn . ModuleList ( )
for i_level in range ( self . num_resolutions ) :
res_block = [ ]
block_out = ch * ch_mult [ i_level ]
for i_block in range ( self . num_res_blocks + 1 ) :
res_block . append (
ResnetBlock (
in_channels = block_in ,
out_channels = block_out ,
temb_channels = self . temb_ch ,
dropout = dropout ,
)
)
block_in = block_out
self . res_blocks . append ( nn . ModuleList ( res_block ) )
if i_level != self . num_resolutions - 1 :
self . upsample_blocks . append ( Upsample ( block_in , True ) )
curr_res = curr_res * 2
# end
self . norm_out = Normalize ( block_in )
self . conv_out = torch . nn . Conv2d (
block_in , out_channels , kernel_size = 3 , stride = 1 , padding = 1
)
def forward ( self , x ) :
# upsampling
h = x
for k , i_level in enumerate ( range ( self . num_resolutions ) ) :
for i_block in range ( self . num_res_blocks + 1 ) :
h = self . res_blocks [ i_level ] [ i_block ] ( h , None )
if i_level != self . num_resolutions - 1 :
h = self . upsample_blocks [ k ] ( h )
h = self . norm_out ( h )
h = nonlinearity ( h )
h = self . conv_out ( h )
return h
@@ -619,15 +874,102 @@ class LatentRescaler(nn.Module):
return x
class MergedRescaleEncoder ( nn . Module ) :
def __init__ (
self ,
in_channels ,
ch ,
resolution ,
out_ch ,
num_res_blocks ,
attn_resolutions ,
dropout = 0.0 ,
resamp_with_conv = True ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
rescale_factor = 1.0 ,
rescale_module_depth = 1 ,
) :
super ( ) . __init__ ( )
intermediate_chn = ch * ch_mult [ - 1 ]
self . encoder = Encoder (
in_channels = in_channels ,
num_res_blocks = num_res_blocks ,
ch = ch ,
ch_mult = ch_mult ,
z_channels = intermediate_chn ,
double_z = False ,
resolution = resolution ,
attn_resolutions = attn_resolutions ,
dropout = dropout ,
resamp_with_conv = resamp_with_conv ,
out_ch = None ,
)
self . rescaler = LatentRescaler (
factor = rescale_factor ,
in_channels = intermediate_chn ,
mid_channels = intermediate_chn ,
out_channels = out_ch ,
depth = rescale_module_depth ,
)
def forward ( self , x ) :
x = self . encoder ( x )
x = self . rescaler ( x )
return x
class MergedRescaleDecoder ( nn . Module ) :
def __init__ (
self ,
z_channels ,
out_ch ,
resolution ,
num_res_blocks ,
attn_resolutions ,
ch ,
ch_mult = ( 1 , 2 , 4 , 8 ) ,
dropout = 0.0 ,
resamp_with_conv = True ,
rescale_factor = 1.0 ,
rescale_module_depth = 1 ,
) :
super ( ) . __init__ ( )
tmp_chn = z_channels * ch_mult [ - 1 ]
self . decoder = Decoder (
out_ch = out_ch ,
z_channels = tmp_chn ,
attn_resolutions = attn_resolutions ,
dropout = dropout ,
resamp_with_conv = resamp_with_conv ,
in_channels = None ,
num_res_blocks = num_res_blocks ,
ch_mult = ch_mult ,
resolution = resolution ,
ch = ch ,
)
self . rescaler = LatentRescaler (
factor = rescale_factor ,
in_channels = z_channels ,
mid_channels = tmp_chn ,
out_channels = tmp_chn ,
depth = rescale_module_depth ,
)
def forward ( self , x ) :
x = self . rescaler ( x )
x = self . decoder ( x )
return x
class Upsampler ( nn . Module ) :
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1.0 + (out_size % in_size)
logger.debug(
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
)
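# factor_up is 1 plus the remainder of out_size modulo in_size; it is handed to the LatentRescaler below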
# print (
# f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
# )
self . rescaler = LatentRescaler (
factor = factor_up ,
in_channels = in_channels ,
@@ -657,98 +999,21 @@ class Resize(nn.Module):
self . with_conv = learned
self . mode = mode
if self . with_conv :
logger.info(
f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"  # noqa
)
raise NotImplementedError()
# assert in_channels is not None
# # no asymmetric padding in torch conv, must do it ourselves
# self.conv = torch.nn.Conv2d(
# in_channels, in_channels, kernel_size=4, stride=2, padding=1
# print(
# f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
# )
raise NotImplementedError ( )
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self . conv = torch . nn . Conv2d (
in_channels , in_channels , kernel_size = 4 , stride = 2 , padding = 1
)
def forward ( self , x , scale_factor = 1.0 ) :
if scale_factor == 1.0 :
return x
x = torch . nn . functional . interpolate (
x , mode = self . mode , align_corners = False , scale_factor = scale_factor
)
return x
class FirstStagePostProcessor ( nn . Module ) :
def __init__ (
self ,
ch_mult : list ,
in_channels ,
pretrained_model : nn . Module = None ,
reshape = False ,
n_channels = None ,
dropout = 0.0 ,
pretrained_config = None ,
) :
super ( ) . __init__ ( )
if pretrained_config is None:
assert (
pretrained_model is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert (
pretrained_config is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self . do_reshape = reshape
if n_channels is None :
n_channels = self . pretrained_model . encoder . ch
self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
self . proj = nn . Conv2d (
in_channels , n_channels , kernel_size = 3 , stride = 1 , padding = 1
)
blocks = [ ]
downs = [ ]
ch_in = n_channels
for m in ch_mult :
blocks . append (
ResnetBlock (
in_channels = ch_in , out_channels = m * n_channels , dropout = dropout
)
x = torch . nn . functional . interpolate (
x , mode = self . mode , align_corners = False , scale_factor = scale_factor
)
ch_in = m * n_channels
downs . append ( Downsample ( ch_in , with_conv = False ) )
self . model = nn . ModuleList ( blocks )
self . downsampler = nn . ModuleList ( downs )
def instantiate_pretrained ( self , config ) :
model = instantiate_from_config ( config )
self . pretrained_model = model . eval ( )
# self.pretrained_model.train = False
for param in self . pretrained_model . parameters ( ) :
param . requires_grad = False
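# the pretrained first-stage model is frozen (eval mode, no gradients) and used purely as a feature extractor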
@torch.no_grad()
def encode_with_pretrained ( self , x ) :
c = self . pretrained_model . encode ( x )
if isinstance ( c , DiagonalGaussianDistribution ) :
c = c . mode ( )
return c
def forward ( self , x ) :
z_fs = self . encode_with_pretrained ( x )
z = self . proj_norm ( z_fs )
z = self . proj ( z )
z = nonlinearity ( z )
for submodel , downmodel in zip ( self . model , self . downsampler ) :
z = submodel ( z , temb = None )
z = downmodel ( z )
if self . do_reshape :
z = rearrange(z, "b c h w -> b (h w) c")
return z
return x