|
|
|
@ -12,7 +12,9 @@ from torch import Tensor, device as Device, dtype as DType, nn
|
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
|
|
import imaginairy.vendored.refiners.fluxion.layers as fl
|
|
|
|
|
from imaginairy import config
|
|
|
|
|
from imaginairy.schema import WeightedPrompt
|
|
|
|
|
from imaginairy.utils.downloads import get_cached_url_path
|
|
|
|
|
from imaginairy.utils.feather_tile import rebuild_image, tile_image
|
|
|
|
|
from imaginairy.vendored.refiners.fluxion.layers.attentions import (
|
|
|
|
|
ScaledDotProductAttention,
|
|
|
|
@ -22,6 +24,10 @@ from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpol
|
|
|
|
|
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
|
|
|
|
|
CLIPTextEncoderL,
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
|
|
|
|
|
SD1IPAdapter,
|
|
|
|
|
SDXLIPAdapter,
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
|
|
|
|
|
TLatentDiffusionModel,
|
|
|
|
|
)
|
|
|
|
@ -55,6 +61,13 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusio
|
|
|
|
|
SDXLUNet,
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.weight_management.conversion import cast_weights
|
|
|
|
|
from imaginairy.weight_management.translators import (
|
|
|
|
|
diffusers_ip_adapter_plus_sd15_to_refiners_translator,
|
|
|
|
|
diffusers_ip_adapter_plus_sdxl_to_refiners_translator,
|
|
|
|
|
diffusers_ip_adapter_sd15_to_refiners_translator,
|
|
|
|
|
diffusers_ip_adapter_sdxl_to_refiners_translator,
|
|
|
|
|
transformers_image_encoder_to_refiners_translator,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -106,7 +119,66 @@ class TileModeMixin(nn.Module):
|
|
|
|
|
m.padding_y = (0, 0, rprt[2], rprt[3]) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
|
|
|
|
|
class SD1ImagePromptMixin(nn.Module):
    """Mixin adding image-prompt (IP Adapter) support to an SD 1.x model.

    The methods read ``self.unet`` (its ``device`` and ``dtype`` and the module
    itself as the adapter target), so the host class must provide it.
    """

    def _get_ip_adapter(self, model_type: str):
        """Build an ``SD1IPAdapter`` targeting ``self.unet`` with weights loaded.

        Args:
            model_type: Which adapter variant to load; one of ``"normal"``,
                ``"plus"`` or ``"plus-face"``.

        Returns:
            The adapter, moved to the UNet's device and dtype. The adapter is
            NOT injected here; the caller does that.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
        """
        valid_model_types = ["normal", "plus", "plus-face"]
        if model_type not in valid_model_types:
            msg = f"IP Adapter model_type must be one of {valid_model_types}"
            raise ValueError(msg)

        # Both "plus" and "plus-face" contain "plus"; they share a translator
        # and enable the adapter's fine-grained mode.
        fine_grained = "plus" in model_type

        ip_adapter_weights_path = get_cached_url_path(
            config.IP_ADAPTER_WEIGHT_LOCATIONS["sd15"][model_type]
        )
        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)

        if fine_grained:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_plus_sd15_to_refiners_translator()
            )
        else:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_sd15_to_refiners_translator()
            )
        clip_image_weight_translator = (
            transformers_image_encoder_to_refiners_translator()
        )

        ip_adapter = SD1IPAdapter(
            target=self.unet,
            weights=ip_adapter_weight_translator.load_and_translate_weights(
                ip_adapter_weights_path
            ),
            fine_grained=fine_grained,
        )
        # The CLIP image encoder weights come from a separate checkpoint and
        # are loaded into the encoder owned by the adapter.
        ip_adapter.clip_image_encoder.load_state_dict(
            clip_image_weight_translator.load_and_translate_weights(
                clip_image_weights_path
            ),
            assign=True,
        )
        # Keep the adapter and its image encoder on the same device/dtype as
        # the UNet they are attached to.
        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
        ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
        return ip_adapter

    def set_image_prompt(
        self, images: list[Image.Image], scale: float, model_type: str = "normal"
    ):
        """Condition the UNet on one or more prompt images via an IP Adapter.

        Builds the adapter, injects it, sets its scale, computes one CLIP image
        embedding per image and stores their element-wise mean on the adapter.

        Args:
            images: Prompt images; must contain at least one image.
            scale: Value passed to the adapter's ``set_scale``.
            model_type: Adapter variant, see ``_get_ip_adapter``.

        Raises:
            ValueError: If ``images`` is empty or ``model_type`` is unsupported.
        """
        # Fail before loading weights or patching the UNet: an empty list would
        # otherwise surface as a ZeroDivisionError *after* injection, leaving
        # the UNet modified with no image embedding set.
        if not images:
            msg = "set_image_prompt requires at least one image"
            raise ValueError(msg)

        # NOTE(review): every call builds and injects a fresh adapter;
        # presumably callers invoke this once per model instance - confirm.
        ip_adapter = self._get_ip_adapter(model_type)
        ip_adapter.inject()
        ip_adapter.set_scale(scale)

        image_embeddings = [
            ip_adapter.compute_clip_image_embedding(
                ip_adapter.preprocess_image(image).to(device=self.unet.device)
            )
            for image in images
        ]

        # Mean of the per-image embeddings.
        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)

        ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_1(TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
unet: SD1UNet | None = None,
|
|
|
|
@ -184,7 +256,68 @@ class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
|
|
|
|
|
return conditioning
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
|
|
|
|
|
class SDXLImagePromptMixin(nn.Module):
    """Mixin adding image-prompt (IP Adapter) support to an SDXL model.

    The methods read ``self.unet`` (its ``device`` and ``dtype`` and the module
    itself as the adapter target), so the host class must provide it.
    """

    def _get_ip_adapter(self, model_type: str):
        """Build an ``SDXLIPAdapter`` targeting ``self.unet`` with weights loaded.

        Args:
            model_type: Which adapter variant to load; one of ``"normal"``,
                ``"plus"`` or ``"plus-face"``.

        Returns:
            The adapter, moved to the UNet's device and dtype. The adapter is
            NOT injected here; the caller does that.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
        """
        valid_model_types = ["normal", "plus", "plus-face"]
        if model_type not in valid_model_types:
            msg = f"IP Adapter model_type must be one of {valid_model_types}"
            raise ValueError(msg)

        # Both "plus" and "plus-face" contain "plus"; they share a translator
        # and enable the adapter's fine-grained mode.
        fine_grained = "plus" in model_type

        ip_adapter_weights_path = get_cached_url_path(
            config.IP_ADAPTER_WEIGHT_LOCATIONS["sdxl"][model_type]
        )
        clip_image_weights_path = get_cached_url_path(config.SD21_UNCLIP_WEIGHTS_URL)

        if fine_grained:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_plus_sdxl_to_refiners_translator()
            )
        else:
            ip_adapter_weight_translator = (
                diffusers_ip_adapter_sdxl_to_refiners_translator()
            )
        clip_image_weight_translator = (
            transformers_image_encoder_to_refiners_translator()
        )

        ip_adapter = SDXLIPAdapter(
            target=self.unet,
            weights=ip_adapter_weight_translator.load_and_translate_weights(
                ip_adapter_weights_path
            ),
            fine_grained=fine_grained,
        )
        # The CLIP image encoder weights come from a separate checkpoint and
        # are loaded into the encoder owned by the adapter.
        ip_adapter.clip_image_encoder.load_state_dict(
            clip_image_weight_translator.load_and_translate_weights(
                clip_image_weights_path
            ),
            assign=True,
        )
        # Keep the adapter and its image encoder on the same device/dtype as
        # the UNet they are attached to.
        ip_adapter.to(device=self.unet.device, dtype=self.unet.dtype)
        ip_adapter.clip_image_encoder.to(device=self.unet.device, dtype=self.unet.dtype)
        return ip_adapter

    def set_image_prompt(
        self, images: list[Image.Image], scale: float, model_type: str = "normal"
    ):
        """Condition the UNet on one or more prompt images via an IP Adapter.

        Builds the adapter, injects it, sets its scale, computes one CLIP image
        embedding per image and stores their element-wise mean on the adapter.

        Args:
            images: Prompt images; must contain at least one image.
            scale: Value passed to the adapter's ``set_scale``.
            model_type: Adapter variant, see ``_get_ip_adapter``.

        Raises:
            ValueError: If ``images`` is empty or ``model_type`` is unsupported.
        """
        # Fail before loading weights or patching the UNet: an empty list would
        # otherwise surface as a ZeroDivisionError *after* injection, leaving
        # the UNet modified with no image embedding set.
        if not images:
            msg = "set_image_prompt requires at least one image"
            raise ValueError(msg)

        # NOTE(review): every call builds and injects a fresh adapter;
        # presumably callers invoke this once per model instance - confirm.
        ip_adapter = self._get_ip_adapter(model_type)
        ip_adapter.inject()
        ip_adapter.set_scale(scale)

        image_embeddings = [
            ip_adapter.compute_clip_image_embedding(
                ip_adapter.preprocess_image(image).to(device=self.unet.device)
            )
            for image in images
        ]

        # Mean of the per-image embeddings.
        clip_image_embedding = sum(image_embeddings) / len(image_embeddings)

        ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_XL(
|
|
|
|
|
TileModeMixin, SDXLImagePromptMixin, RefinerStableDiffusion_XL
|
|
|
|
|
):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
unet: SDXLUNet | None = None,
|
|
|
|
@ -324,7 +457,9 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
|
|
|
|
|
return clip_text_embedding, pooled_text_embedding # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
|
|
|
|
|
class StableDiffusion_1_Inpainting(
|
|
|
|
|
TileModeMixin, SD1ImagePromptMixin, RefinerStableDiffusion_1_Inpainting
|
|
|
|
|
):
|
|
|
|
|
def compute_self_attention_guidance(
|
|
|
|
|
self,
|
|
|
|
|
x: Tensor,
|
|
|
|
@ -356,7 +491,17 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
|
|
|
|
|
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
|
|
|
|
|
dim=1,
|
|
|
|
|
)
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
if "ip_adapter" in self.unet.provider.contexts:
|
|
|
|
|
# this implementation is a bit hacky, it should be refactored in the future
|
|
|
|
|
ip_adapter_context = self.unet.use_context("ip_adapter")
|
|
|
|
|
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
|
|
|
|
|
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
|
|
|
|
|
"clip_image_embedding"
|
|
|
|
|
].chunk(2)
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
|
|
|
|
|
else:
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
|
|
|
|
|
return sag.scale * (noise - degraded_noise)
|
|
|
|
|
|
|
|
|
@ -518,7 +663,17 @@ class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
|
|
|
|
|
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
|
|
|
|
|
dim=1,
|
|
|
|
|
)
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
if "ip_adapter" in self.unet.provider.contexts:
|
|
|
|
|
# this implementation is a bit hacky, it should be refactored in the future
|
|
|
|
|
ip_adapter_context = self.unet.use_context("ip_adapter")
|
|
|
|
|
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
|
|
|
|
|
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context[
|
|
|
|
|
"clip_image_embedding"
|
|
|
|
|
].chunk(2)
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
|
|
|
|
|
else:
|
|
|
|
|
degraded_noise = self.unet(x)
|
|
|
|
|
|
|
|
|
|
return sag.scale * (noise - degraded_noise)
|
|
|
|
|
|
|
|
|
|