feature: Stable Diffusion 2.0

working: CUDA and MacOS
working: 512p model with all samplers
working: inpainting with all samplers
working: 768p model with ddim sampler
pull/101/head
Bryce 2 years ago committed by Bryce Drennan
parent 1a866c48e1
commit e67341223b

@ -224,6 +224,13 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
**6.0.0a**
- feature: 🎉🎉🎉 Stable Diffusion 2.0
- Tested on MacOS and Linux
- All samplers working for new 512x512 model
- New inpainting model working
- 768x768 model working for DDIM sampler only
**5.1.0**
- feature: add progress image callback

@ -0,0 +1,68 @@
model:
base_learning_rate: 1.0e-4
target: imaginairy.modules.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: imaginairy.modules.diffusion.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: False
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: imaginairy.modules.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: imaginairy.modules.encoders.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

@ -0,0 +1,67 @@
model:
base_learning_rate: 1.0e-4
target: imaginairy.modules.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: imaginairy.modules.diffusion.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: False
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: imaginairy.modules.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: imaginairy.modules.encoders.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

@ -0,0 +1,158 @@
model:
base_learning_rate: 5.0e-05
target: imaginairy.modules.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: imaginairy.modules.diffusion.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: imaginairy.modules.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: imaginairy.modules.encoders.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

@ -0,0 +1,74 @@
model:
base_learning_rate: 5.0e-07
target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
depth_stage_config:
target: ldm.modules.midas.api.MiDaSInference
params:
model_type: "dpt_hybrid"
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 5
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

@ -0,0 +1,76 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
params:
parameterization: "v"
low_scale_key: "lr"
linear_start: 0.0001
linear_end: 0.02
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 128
channels: 4
cond_stage_trainable: false
conditioning_key: "hybrid-adm"
monitor: val/loss_simple_ema
scale_factor: 0.08333
use_ema: False
low_scale_config:
target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
params:
noise_schedule_config: # image space
linear_start: 0.0001
linear_end: 0.02
max_noise_level: 350
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
image_size: 128
in_channels: 7
out_channels: 4
model_channels: 256
attention_resolutions: [ 2,4,8]
num_res_blocks: 2
channel_mult: [ 1, 2, 2, 4]
disable_self_attentions: [True, True, True, False]
disable_middle_self_attn: False
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
use_linear_in_transformer: True
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
ddconfig:
# attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

@ -28,8 +28,20 @@ MODEL_SHORTCUTS = {
"configs/stable-diffusion-v1-inpaint.yaml",
"https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt",
),
"SD-2.0": (
"configs/stable-diffusion-v2-inference.yaml",
"https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt",
),
"SD-2.0-inpaint": (
"configs/stable-diffusion-v2-inpainting-inference.yaml",
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.ckpt",
),
"SD-2.0-v": (
"configs/stable-diffusion-v2-inference-v.yaml",
"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt",
),
}
DEFAULT_MODEL = "SD-1.5"
DEFAULT_MODEL = "SD-2.0"
LOADED_MODELS = {}
MOST_RECENTLY_LOADED_MODEL = None

@ -9,8 +9,15 @@ from torch import einsum, nn
from imaginairy.modules.diffusion.util import checkpoint
from imaginairy.utils import get_device
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@ -245,7 +252,67 @@ class CrossAttention(nn.Module):
return self.to_out(r2)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
# print(
# f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
# f"{heads} heads."
# )
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = context if context is not None else x
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
if mask is not None:
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention,
}
def __init__(
self,
dim,
@ -254,14 +321,23 @@ class BasicTransformerBlock(nn.Module):
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True, # noqa
checkpoint=True,
disable_self_attn=False,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
@ -280,7 +356,12 @@ class BasicTransformerBlock(nn.Module):
def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == "mps" else x
x = self.attn1(self.norm1(x)) + x
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
@ -293,42 +374,73 @@ class SpatialTransformer(nn.Module):
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True,
):
super().__init__()
if context_dim is not None and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape # noqa
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in

@ -1,10 +1,12 @@
import logging
from contextlib import contextmanager
import pytorch_lightning as pl
import torch
from imaginairy.modules.diffusion.model import Decoder, Encoder
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.modules.ema import LitEma
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
@ -17,13 +19,15 @@ class AutoencoderKL(pl.LightningModule):
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False,
):
super().__init__()
ignore_keys = [] if ignore_keys is None else ignore_keys
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
@ -33,24 +37,50 @@ class AutoencoderKL(pl.LightningModule):
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert isinstance(colorize_nlabels, int)
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0.0 < ema_decay < 1.0
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=None):
ignore_keys = [] if ignore_keys is None else ignore_keys
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
logger.info(f"Deleting key {k} from state_dict.")
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
logger.info(f"Restored from {path}")
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
@ -63,7 +93,7 @@ class AutoencoderKL(pl.LightningModule):
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True): # noqa
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
@ -78,3 +108,166 @@ class AutoencoderKL(pl.LightningModule):
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = (
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters())
)
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(
torch.randn_like(posterior_ema.sample())
)
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
# def to_rgb(self, x):
# assert self.image_key == "segmentation"
# if not hasattr(self, "colorize"):
# self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
# x = F.conv2d(x, weight=self.colorize)
# x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
# return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x

@ -5,24 +5,26 @@ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bb
https://github.com/CompVis/taming-transformers
-- merci
"""
import itertools
import logging
from contextlib import contextmanager
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange
from einops import rearrange, repeat
from torch import nn
from torchvision.utils import make_grid
from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_beta_schedule,
noise_like,
)
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.modules.ema import LitEma
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
@ -42,12 +44,7 @@ def uniform_on_device(r1, r2, shape, device):
class DDPM(pl.LightningModule):
"""
classic DDPM with Gaussian diffusion, in image space
Denoising diffusion probabilistic models
"""
# classic DDPM with Gaussian diffusion, in image space
def __init__(
self,
unet_config,
@ -55,9 +52,10 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
@ -76,18 +74,21 @@ class DDPM(pl.LightningModule):
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.0,
make_it_fit=False,
ucg_training=None,
reset_ema=False,
reset_num_ema_updates=False,
):
super().__init__()
ignore_keys = [] if ignore_keys is None else ignore_keys
assert parameterization in [
"eps",
"x0",
], 'currently only supporting "eps" and "x0"'
"v",
], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
logger.debug(
f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
)
# print(
# f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
# )
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
@ -96,6 +97,11 @@ class DDPM(pl.LightningModule):
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
# count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
# print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
@ -107,10 +113,25 @@ class DDPM(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
self.make_it_fit = make_it_fit
if reset_ema:
assert ckpt_path is not None
if ckpt_path is not None:
self.init_from_ckpt(
ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
)
if reset_ema:
assert self.use_ema
print(
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
)
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(
" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
)
assert self.use_ema
self.model_ema.reset_num_updates()
self.register_schedule(
given_betas=given_betas,
@ -128,6 +149,10 @@ class DDPM(pl.LightningModule):
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.ucg_training = ucg_training or dict()
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
def register_schedule(
self,
given_betas=None,
@ -170,6 +195,15 @@ class DDPM(pl.LightningModule):
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
@ -206,13 +240,404 @@ class DDPM(pl.LightningModule):
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(
self.betas**2
/ (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
)
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
if self.make_it_fit:
n_params = len(
[
name
for name, _ in itertools.chain(
self.named_parameters(), self.named_buffers()
)
]
)
for name, param in tqdm(
itertools.chain(self.named_parameters(), self.named_buffers()),
desc="Fitting old weights to new weights",
total=n_params,
):
if not name in sd:
continue
old_shape = sd[name].shape
new_shape = param.shape
assert len(old_shape) == len(new_shape)
if len(new_shape) > 2:
# we only modify first two axes
assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim
if not new_shape == old_shape:
new_param = param.clone()
old_param = sd[name]
if len(new_shape) == 1:
for i in range(new_param.shape[0]):
new_param[i] = old_param[i % old_shape[0]]
elif len(new_shape) >= 2:
for i in range(new_param.shape[0]):
for j in range(new_param.shape[1]):
new_param[i, j] = old_param[
i % old_shape[0], j % old_shape[1]
]
n_used_old = torch.ones(old_shape[1])
for j in range(new_param.shape[1]):
n_used_old[j % old_shape[1]] += 1
n_used_new = torch.zeros(new_shape[1])
for j in range(new_param.shape[1]):
n_used_new[j] = n_used_old[j % old_shape[1]]
n_used_new = n_used_new[None, :]
while len(n_used_new.shape) < len(new_shape):
n_used_new = n_used_new.unsqueeze(-1)
new_param /= n_used_new
sd[name] = new_param
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
# print(
# f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
# )
# if len(missing) > 0:
# print(f"Missing Keys:\n {missing}")
# if len(unexpected) > 0:
# print(f"\nUnexpected Keys:\n {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
* noise
)
def predict_start_from_z_and_v(self, x_t, t, v):
# self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
# self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_eps_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
* x_t
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t
)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(
x=x, t=t, clip_denoised=clip_denoised
)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(
reversed(range(0, self.num_timesteps)),
desc="Sampling t",
total=self.num_timesteps,
):
img = self.p_sample(
img,
torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised,
)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop(
(batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates,
)
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
def get_loss(self, pred, target, mean=True):
if self.loss_type == "l1":
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == "l2":
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError(
f"Paramterization {self.parameterization} not yet supported"
)
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = "train" if self.training else "val"
loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f"{log_prefix}/loss": loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(
0, self.num_timesteps, (x.shape[0],), device=self.device
).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, "b h w c -> b c h w")
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
for k in self.ucg_training:
p = self.ucg_training[k]["p"]
val = self.ucg_training[k]["val"]
if val is None:
val = ""
for i in range(len(batch[k])):
if self.ucg_prng.choice(2, p=[1 - p, p]):
batch[k][i] = val
loss, loss_dict = self.shared_step(batch)
self.log_dict(
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
)
self.log(
"global_step",
self.global_step,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]["lr"]
self.log(
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(
loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
)
self.log_dict(
loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(
batch_size=N, return_intermediates=True
)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
class LatentDiffusion(DDPM):
"""main class"""
@ -762,9 +1187,9 @@ class LatentDiffusion(DDPM):
else:
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
log_latent(x_recon, "predicted noise")
return x_recon

@ -1,24 +1,27 @@
# pylama:ignore=W0613,W0612
# pytorch_diffusion + derived encoder decoder
import gc
import logging
import math
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import nn
from imaginairy.modules.attention import LinearAttention
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.modules.attention import MemoryEfficientCrossAttention
logger = logging.getLogger(__name__)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
# print("No module 'xformers'. Proceeding without it.")
def get_timestep_embedding(timesteps, embedding_dim):
"""
Matches the implementation in Denoising Diffusion Probabilistic Models:
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
@ -39,11 +42,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
@ -126,30 +125,18 @@ class ResnetBlock(nn.Module):
)
def forward(self, x, temb):
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h5 = self.norm2(h4)
del h4
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@ -157,14 +144,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h8
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
return x + h
class AttnBlock(nn.Module):
@ -187,8 +167,6 @@ class AttnBlock(nn.Module):
)
def forward(self, x):
if get_device() == "cuda":
return self.forward_cuda(x)
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
@ -214,83 +192,276 @@ class AttnBlock(nn.Module):
return x + h_
def forward_cuda(self, x):
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h * w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h * w) # b,c,hw
del k1
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
h_ = torch.zeros_like(k, device=q.device)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x + out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, "b c h w -> b (h w) c")
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
return x + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in [
"vanilla",
"vanilla-xformers",
"memory-efficient-cross-attn",
"linear",
"none",
], f"attn_type {attn_type} unknown"
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
attn_type = "vanilla-xformers"
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
# print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
raise NotImplementedError()
stats = torch.cuda.memory_stats(q.device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
class Model(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c) ** (-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# attend to values
v1 = v.reshape(b, c, h * w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
h_[:, :, i:end] = torch.bmm(
v1, w4
) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
h2 = h_.reshape(b, c, h, w)
del h_
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
h3 = self.proj_out(h2)
del h2
def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
h3 += x
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
return h3
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
logger.debug(
f"making attention of type '{attn_type}' with {in_channels} in_channels"
)
if attn_type == "vanilla":
return AttnBlock(in_channels)
if attn_type == "none":
return nn.Identity(in_channels)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
return LinAttnBlock(in_channels)
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
@ -447,9 +618,11 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logger.debug(
f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
)
# print(
# "Working with z of shape {} = {} dimensions.".format(
# self.z_shape, np.prod(self.z_shape)
# )
# )
# z to block_in
self.conv_in = torch.nn.Conv2d(
@ -503,7 +676,6 @@ class Decoder(nn.Module):
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
self.last_z_shape = None
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
@ -513,53 +685,136 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h1 = self.conv_in(z)
h = self.conv_in(z)
# middle
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
# prepare for up sampling
gc.collect()
torch.cuda.empty_cache()
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
t = h
h = self.up[i_level].attn[i_block](t)
del t
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
t = h
h = self.up[i_level].upsample(t)
del t
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h1 = self.norm_out(h)
del h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
h2 = nonlinearity(h1)
del h1
h = self.conv_out(h2)
del h2
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList(
[
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(
in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0,
),
ResnetBlock(
in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0,
dropout=0.0,
),
ResnetBlock(
in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0,
),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True),
]
)
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.tanh_out:
t = h
h = torch.tanh(t)
del t
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
ch,
num_res_blocks,
resolution,
ch_mult=(2, 2),
dropout=0.0,
):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
@ -619,15 +874,102 @@ class LatentRescaler(nn.Module):
return x
class MergedRescaleEncoder(nn.Module):
def __init__(
self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(
in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(
self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(
out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1.0 + (out_size % in_size)
logger.debug(
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
)
# print(
# f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
# )
self.rescaler = LatentRescaler(
factor=factor_up,
in_channels=in_channels,
@ -657,98 +999,21 @@ class Resize(nn.Module):
self.with_conv = learned
self.mode = mode
if self.with_conv:
logger.info(
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" # noqa
)
raise NotImplementedError()
# assert in_channels is not None
# # no asymmetric padding in torch conv, must do it ourselves
# self.conv = torch.nn.Conv2d(
# in_channels, in_channels, kernel_size=4, stride=2, padding=1
# print(
# f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
# )
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1
)
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
return x
class FirstStagePostProcessor(nn.Module):
def __init__(
self,
ch_mult: list,
in_channels,
pretrained_model: nn.Module = None,
reshape=False,
n_channels=None,
dropout=0.0,
pretrained_config=None,
):
super().__init__()
if pretrained_config is None:
assert (
pretrained_model is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert (
pretrained_config is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
self.proj = nn.Conv2d(
in_channels, n_channels, kernel_size=3, stride=1, padding=1
)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
blocks.append(
ResnetBlock(
in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
)
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
# self.pretrained_model.train = False
for param in self.pretrained_model.parameters():
param.requires_grad = False
@torch.no_grad()
def encode_with_pretrained(self, x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
return c
def forward(self, x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
for submodel, downmodel in zip(self.model, self.downsampler):
z = submodel(z, temb=None)
z = downmodel(z)
if self.do_reshape:
z = rearrange(z, "b c h w -> b (h w) c")
return z
return x

@ -4,7 +4,6 @@ from abc import abstractmethod
import numpy as np
import torch as th
import torch.nn.functional as F
from omegaconf.listconfig import ListConfig
from torch import nn
from imaginairy.modules.attention import SpatialTransformer
@ -12,6 +11,7 @@ from imaginairy.modules.diffusion.util import (
avg_pool_nd,
checkpoint,
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
@ -478,6 +478,10 @@ class UNetModel(nn.Module):
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
@ -489,8 +493,9 @@ class UNetModel(nn.Module):
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig
if isinstance(context_dim, ListConfig):
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
@ -510,7 +515,33 @@ class UNetModel(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(
map(
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
range(len(num_attention_blocks)),
)
)
print(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set."
)
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
@ -525,13 +556,19 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
# print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
@ -545,7 +582,7 @@ class UNetModel(nn.Module):
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
@ -571,23 +608,32 @@ class UNetModel(nn.Module):
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if num_attention_blocks is None or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
)
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
@ -641,12 +687,15 @@ class UNetModel(nn.Module):
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
),
ResBlock(
ch,
@ -661,7 +710,7 @@ class UNetModel(nn.Module):
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
@ -688,24 +737,33 @@ class UNetModel(nn.Module):
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if num_attention_blocks is None or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
)
)
)
if level and i == num_res_blocks:
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
@ -753,7 +811,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): # noqa
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@ -770,7 +828,7 @@ class UNetModel(nn.Module):
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
@ -784,5 +842,5 @@ class UNetModel(nn.Module):
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
return self.out(h)
else:
return self.out(h)

@ -269,6 +269,13 @@ def conv_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""Create a 1D, 2D, or 3D average pooling module."""
if dims == 1:

@ -0,0 +1,88 @@
import torch
from torch import nn
# https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/ema.py
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)

@ -0,0 +1,263 @@
import open_clip
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from imaginairy.utils import get_device
# https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/encoders/modules.py
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0.0 and not disable_dropout:
mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(
input_ids=tokens, output_hidden_states=self.layer == "hidden"
)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
"last",
"penultimate",
]
def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device=None,
max_length=77,
freeze=True,
layer="last",
):
super().__init__()
assert layer in self.LAYERS
if device is None:
device = get_device()
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device("cpu"), pretrained=version
)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if (
self.model.transformer.grad_checkpointing
and not torch.jit.is_scripting()
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(
self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77,
):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
# print(
# f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
# f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
# )
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]

@ -123,6 +123,13 @@ class DDIMSampler:
signal_amplification=guidance_scale,
)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(
noisy_latent, time_encoding, noise_pred
)
else:
e_t = noise_pred
batch_size = noisy_latent.shape[0]
# select parameters corresponding to the currently considered timestep
@ -146,11 +153,15 @@ class DDIMSampler:
schedule.ddim_sqrt_one_minus_alphas[index],
device=noisy_latent.device,
)
noisy_latent, predicted_latent = self._p_sample_ddim_formula(
model=self.model,
noisy_latent=noisy_latent,
noise_pred=noise_pred,
e_t=e_t,
sqrt_one_minus_at=sqrt_one_minus_at,
a_t=a_t,
time_encoding=time_encoding,
sigma_t=sigma_t,
a_prev=a_prev,
noise_dropout=noise_dropout,
@ -161,19 +172,27 @@ class DDIMSampler:
@staticmethod
def _p_sample_ddim_formula(
model,
noisy_latent,
noise_pred,
e_t,
sqrt_one_minus_at,
a_t,
time_encoding,
sigma_t,
a_prev,
noise_dropout,
repeat_noise,
temperature,
):
predicted_latent = (noisy_latent - sqrt_one_minus_at * noise_pred) / a_t.sqrt()
if model.parameterization != "v":
predicted_latent = (noisy_latent - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
predicted_latent = model.predict_start_from_z_and_v(
noisy_latent, time_encoding, noise_pred
)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * noise_pred
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t
* noise_like(noisy_latent.shape, noisy_latent.device, repeat_noise)

@ -48,6 +48,7 @@ setup(
"psutil",
"pytorch-lightning==1.4.2",
"omegaconf==2.1.1",
"open-clip-torch",
"einops==0.3.0",
"timm>=0.4.12", # for vendored blip
"torchdiffeq",

Loading…
Cancel
Save