feature: sliced latent decoding

allows generation of bigger images. tile seams can be noticeable occasionally despite the feathering
pull/252/head
Bryce 1 year ago committed by Bryce Drennan
parent ad62496557
commit 8a97213622

@ -298,6 +298,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
- feature: sliced latent decoding - now possible to make much bigger images. 8 MP (3840x2160) on 11 GB GPU.
**9.0.2**
- fix: edit interface was broken

@ -0,0 +1,13 @@
As seen in:
- 2023-02-05 - https://p.migdal.pl/blog/2023/02/ai-arts-information-theory/
- 2023-01-26 - YouTube - [Live 96: InstructPix2Pix and STABLE DIFFUSION](https://www.youtube.com/watch?v=AyWDYVFuALs&t=485s)
- 2023-01-23 - [Gigazine](https://gigazine.net/gsc_news/en/20230123-imaginairy-ai-image/)
- 2023-01-23 - Changelog.com - ["ImaginAIry imagines & edits images from text inputs"](https://changelog.com/news/imaginairy-imagines-edits-images-from-text-inputs-QpzQ)
- 2022-12-10 - YouTube - [Live 91: TRANSCRIBE Audio with WHISPER and GENERATE images with STABLE DIFFUSION, Locally](https://www.youtube.com/watch?v=CaLmLP2GTEU&t=1785s)
- 2023-01-22 - Hacker News - [Show HN: New AI edits images based on text instructions](https://news.ycombinator.com/item?id=34474270)
- 2023-01-10 - YouTube - [Build A Free AI Image Generator Bot in 20 minutes!](https://www.youtube.com/watch?v=ufQcDD1kQCI)
- 2022-11-24 - Hacker News - [Stable Diffusion 2.0 on Mac and Linux via imaginAIry Python library](https://news.ycombinator.com/item?id=33729694)
Used by:
- [imageeditor.ai](https://imageeditor.ai/)

@ -0,0 +1,205 @@
# inspired by https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
# but with all the bugs fixed and lots of simplifications
# MIT License
import math
import torch
def mask_tile(tile, overlap, std_overlap, side="bottom"):
h, w = tile.size(2), tile.size(3)
top_overlap, bottom_overlap, right_overlap, left_overlap = overlap
(
std_top_overlap,
std_bottom_overlap,
std_right_overlap,
std_left_overlap,
) = std_overlap
if "left" in side:
lin_mask_left = torch.linspace(0, 1, std_left_overlap, device=tile.device)
if left_overlap > std_left_overlap:
zeros_mask = torch.zeros(
left_overlap - std_left_overlap, device=tile.device
)
lin_mask_left = (
torch.cat([zeros_mask, lin_mask_left], 0)
.repeat(h, 1)
.repeat(3, 1, 1)
.unsqueeze(0)
)
if "right" in side:
lin_mask_right = (
torch.linspace(1, 0, right_overlap, device=tile.device)
.repeat(h, 1)
.repeat(3, 1, 1)
.unsqueeze(0)
)
if "top" in side:
lin_mask_top = torch.linspace(0, 1, std_top_overlap, device=tile.device)
if top_overlap > std_top_overlap:
zeros_mask = torch.zeros(top_overlap - std_top_overlap, device=tile.device)
lin_mask_top = torch.cat([zeros_mask, lin_mask_top], 0)
lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
if "bottom" in side:
lin_mask_bottom = (
torch.linspace(1, 0, std_bottom_overlap, device=tile.device)
.repeat(w, 1)
.rot90(3)
.repeat(3, 1, 1)
.unsqueeze(0)
)
base_mask = torch.ones_like(tile)
if "right" in side:
base_mask[:, :, :, w - right_overlap :] = (
base_mask[:, :, :, w - right_overlap :] * lin_mask_right
)
if "left" in side:
base_mask[:, :, :, :left_overlap] = (
base_mask[:, :, :, :left_overlap] * lin_mask_left
)
if "bottom" in side:
base_mask[:, :, h - bottom_overlap :, :] = (
base_mask[:, :, h - bottom_overlap :, :] * lin_mask_bottom
)
if "top" in side:
base_mask[:, :, :top_overlap, :] = (
base_mask[:, :, :top_overlap, :] * lin_mask_top
)
return tile * base_mask
def get_tile_coords(d, tile_dim, overlap=0):
move = int(math.ceil(round(tile_dim * (1 - overlap), 10)))
c, tile_start, coords = 1, 0, [0]
while tile_start + tile_dim < d:
tile_start = move * c
if tile_start + tile_dim >= d:
coords.append(d - tile_dim)
else:
coords.append(tile_start)
c += 1
return coords
def get_tiles(img, tile_coords, tile_size):
tile_list = []
for y in tile_coords[0]:
for x in tile_coords[1]:
tile = img[:, :, y : y + tile_size[0], x : x + tile_size[1]]
tile_list.append(tile)
return tile_list
def final_overlap(tile_coords, tile_size):
last_row, last_col = len(tile_coords[0]) - 1, len(tile_coords[1]) - 1
f_ovlp = [
(tile_coords[0][last_row - 1] + tile_size[0]) - (tile_coords[0][last_row]),
(tile_coords[1][last_col - 1] + tile_size[1]) - (tile_coords[1][last_col]),
]
return f_ovlp
def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
f_ovlp = final_overlap(tile_coords, tile_size)
h, w = tiles[0].size(2), tiles[0].size(3)
if f_ovlp[0] == h:
f_ovlp[0] = 0
if f_ovlp[1] == w:
f_ovlp[1] = 0
t = 0
column, row, = (
0,
0,
)
for y in tile_coords[0]:
for x in tile_coords[1]:
mask_sides = ""
c_overlap = overlap.copy()
if row == 0:
mask_sides += "bottom"
elif 0 < row < len(tile_coords[0]) - 2:
mask_sides += "bottom,top"
elif row == len(tile_coords[0]) - 2:
mask_sides += "bottom,top"
elif row == len(tile_coords[0]) - 1:
mask_sides += "top"
if f_ovlp[0] > 0:
c_overlap[0] = f_ovlp[0] # Change top overlap
if column == 0:
mask_sides += ",right"
elif 0 < column < len(tile_coords[1]) - 2:
mask_sides += ",right,left"
elif column == len(tile_coords[1]) - 2:
mask_sides += ",right,left"
elif column == len(tile_coords[1]) - 1:
mask_sides += ",left"
if f_ovlp[1] > 0:
c_overlap[3] = f_ovlp[1] # Change left overlap
# print(f"mask_tile: tile.shape={tiles[t].shape}, overlap={c_overlap}, side={mask_sides} col={column}, row={row}")
tile = mask_tile(tiles[t], c_overlap, std_overlap=overlap, side=mask_sides)
# torch_img_to_pillow_img(tile).show()
base_img[:, :, y : y + tile_size[0], x : x + tile_size[1]] = (
base_img[:, :, y : y + tile_size[0], x : x + tile_size[1]] + tile
)
# torch_img_to_pillow_img(base_img).show()
t += 1
column += 1
row += 1
# if row >= 2:
# exit()
column = 0
return base_img
def tile_setup(tile_size, overlap_percent, base_size):
if not isinstance(tile_size, (tuple, list)):
tile_size = (tile_size, tile_size)
if not isinstance(overlap_percent, (tuple, list)):
overlap_percent = (overlap_percent, overlap_percent)
if min(tile_size) < 1:
raise ValueError("tile_size must be at least 1")
if max(overlap_percent) > 0.5:
raise ValueError("overlap_percent must not be greater than 0.5")
x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
y_ovlp = int(math.floor(round(tile_size[0] * overlap_percent[0], 10)))
x_ovlp = int(math.floor(round(tile_size[1] * overlap_percent[1], 10)))
if len(x_coords) == 1:
x_ovlp = 0
if len(y_coords) == 1:
y_ovlp = 0
return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]
def tile_image(img, tile_size, overlap_percent):
tile_coords, tile_size, _ = tile_setup(
tile_size, overlap_percent, (img.size(2), img.size(3))
)
return get_tiles(img, tile_coords, tile_size)
def rebuild_image(tiles, base_img, tile_size, overlap_percent):
if len(tiles) == 1:
return tiles[0]
base_img = torch.zeros_like(base_img)
tile_coords, tile_size, overlap = tile_setup(
tile_size, overlap_percent, (base_img.size(2), base_img.size(3))
)
return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)

@ -42,6 +42,14 @@ def pillow_img_to_torch_image(img: PIL.Image.Image):
return 2.0 * img - 1.0
def torch_img_to_pillow_img(img: torch.Tensor):
img = rearrange(img, "b c h w -> b h w c")
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
img = (255.0 * img).cpu().numpy().astype(np.uint8)
img = Image.fromarray(img[0])
return img
def pillow_img_to_opencv_img(img: PIL.Image.Image):
open_cv_image = np.array(img)
# Convert RGB to BGR

@ -4,7 +4,9 @@ from contextlib import contextmanager
import pytorch_lightning as pl
import torch
from torch.cuda import OutOfMemoryError
from imaginairy.feather_tile import rebuild_image, tile_image
from imaginairy.modules.diffusion.model import Decoder, Encoder
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.modules.ema import LitEma
@ -91,10 +93,49 @@ class AutoencoderKL(pl.LightningModule):
return posterior
def decode(self, z):
try:
return self.decode_all_at_once(z)
except OutOfMemoryError:
# Out of memory, trying sliced decoding.
try:
return self.decode_sliced(z, chunk_size=128)
except OutOfMemoryError:
return self.decode_sliced(z, chunk_size=64)
def decode_all_at_once(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def decode_sliced(self, z, chunk_size=128):
"""
decodes the tensor in slices.
This results in images that don't exactly match, so we overlap, feather, and merge to reduce
(but not completely elminate) impact.
"""
b, c, h, w = z.size()
final_tensor = torch.zeros([1, 3, h * 8, w * 8], device=z.device)
for z_latent in z.split(1):
decoded_chunks = []
overlap_pct = 0.5
chunks = tile_image(
z_latent, tile_size=chunk_size, overlap_percent=overlap_pct
)
for latent_chunk in chunks:
latent_chunk = self.post_quant_conv(latent_chunk)
dec = self.decoder(latent_chunk)
decoded_chunks.append(dec)
final_tensor = rebuild_image(
decoded_chunks,
base_img=final_tensor,
tile_size=chunk_size * 8,
overlap_percent=overlap_pct,
)
return final_tensor
def forward(self, input, sample_posterior=True): # noqa
posterior = self.encode(input)
if sample_posterior:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 324 KiB

After

Width:  |  Height:  |  Size: 370 KiB

@ -54,7 +54,7 @@ def test_model_versions(filename_base_for_orig_outputs, model_version):
)
)
threshold = 24000
threshold = 33000
for i, result in enumerate(imagine(prompts)):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
@ -257,7 +257,7 @@ def test_cliptext_inpainting_pearl_doctor(
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}_orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=12000)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@ -267,11 +267,11 @@ def test_tile_mode(filename_base_for_outputs):
prompt_text,
width=400,
height=400,
steps=5,
steps=15,
seed=1,
tile_mode="xy",
)
result = next(imagine(prompt))
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1000)
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=22000)

@ -63,7 +63,7 @@ def test_clip_masking(filename_base_for_outputs):
result = next(imagine(prompt))
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1100)
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1200)
boolean_mask_test_cases = [

@ -0,0 +1,132 @@
import itertools
import pytest
from imaginairy import LazyLoadingImage
from imaginairy.feather_tile import rebuild_image, tile_image, tile_setup
from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img
from tests import TESTS_FOLDER
img_ratios = [0.2, 0.242, 0.3, 0.33333333, 0.5, 0.75, 1, 4 / 3.0, 16 / 9.0, 2, 21 / 9.0]
pcts = [
0,
0.09,
0.1,
0.2,
0.25,
0.3,
1 / 3,
0.4,
0.5,
0.6,
0.7,
0.75,
0.8,
0.9,
1.0,
]
initial_sizes = [512]
flip = [True, False]
cases = [
(1, 256, 0),
(1, 256, 0.125),
(1, 256, 0.25),
(1, 256, 0.5),
(1, 128, 0),
(1, 128, 0.125),
(1, 128, 0.25),
(1, 128, 0.5),
(1, 512, 0),
(0.2, 46, 0.09),
(0.2, 46, 0.1),
(0.242, 46, 0.2),
(0.2, 51, 1 / 3.0),
(0.2, 102, 0.09), # tile size same as width of image
]
@pytest.mark.parametrize("img_ratio, tile_size, overlap_pct", cases)
def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
img = pillow_img_to_torch_image(
LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
)
img = img[:, :, : img.shape[2], : int(img.shape[3] * img_ratio)]
img_sum = img.sum()
tiles = tile_image(img, tile_size=tile_size, overlap_percent=overlap_pct)
tile_coords, tile_size, overlap = tile_setup(
tile_size, overlap_pct, (img.size(2), img.size(3))
)
print(
f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
)
rebuilt = rebuild_image(
tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_pct
)
assert rebuilt.shape == img.shape
diff = abs(float(rebuilt.sum()) - float(img_sum))
if diff >= 1:
torch_img_to_pillow_img(img).show()
torch_img_to_pillow_img(rebuilt).show()
torch_img_to_pillow_img(rebuilt - img).show()
assert diff < 1
def test_feather_tile_brute():
source_img = pillow_img_to_torch_image(
LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
)
def tile_untile(img, tile_size, overlap_percent):
img_sum = img.sum()
tiles = tile_image(img, tile_size=tile_size, overlap_percent=overlap_percent)
tile_coords, tile_size, overlap = tile_setup(
tile_size, overlap_percent, (img.size(2), img.size(3))
)
print(
f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
)
rebuilt = rebuild_image(
tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_percent
)
assert rebuilt.shape == img.shape
diff = abs(float(rebuilt.sum()) - float(img_sum))
if diff > 1:
torch_img_to_pillow_img(img).show()
torch_img_to_pillow_img(rebuilt).show()
torch_img_to_pillow_img((rebuilt - img) * 20).show()
status = "🚫 FAILED"
else:
status = "✅ PASSED"
print(
f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}"
)
assert diff < 1
for tile_size_pct, overlap_percent, img_ratio, flip_ratio in itertools.product(
pcts, pcts, img_ratios, flip
):
if flip_ratio:
img = source_img.clone()[:, :, :, : int(source_img.shape[3] * img_ratio)]
else:
img = source_img.clone()[:, :, : int(source_img.shape[2] * img_ratio), :]
tile_size = int(source_img.shape[3] * tile_size_pct)
if not tile_size:
continue
if overlap_percent >= 0.5:
continue
print(
f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}"
)
tile_untile(img, tile_size=tile_size, overlap_percent=overlap_percent)
del img
# tile_untile(img, tile_size=256, overlap_percent=0.25)
Loading…
Cancel
Save