feature: sliced latent decoding
allows generation of bigger images. tile seams can be noticeable occasionally despite the featheringpull/252/head
parent
ad62496557
commit
8a97213622
@ -0,0 +1,13 @@
|
||||
|
||||
As seen in:
|
||||
- 2023-02-05 - https://p.migdal.pl/blog/2023/02/ai-arts-information-theory/
|
||||
- 2023-01-26 - YouTube - [Live 96: InstructPix2Pix and STABLE DIFFUSION](https://www.youtube.com/watch?v=AyWDYVFuALs&t=485s)
|
||||
- 2023-01-23 - [Gigazine](https://gigazine.net/gsc_news/en/20230123-imaginairy-ai-image/)
|
||||
- 2023-01-23 - Changelog.com - ["ImaginAIry imagines & edits images from text inputs"](https://changelog.com/news/imaginairy-imagines-edits-images-from-text-inputs-QpzQ)
|
||||
- 2022-12-10 - YouTube - [Live 91: TRANSCRIBE Audio with WHISPER and GENERATE images with STABLE DIFFUSION, Locally](https://www.youtube.com/watch?v=CaLmLP2GTEU&t=1785s)
|
||||
- 2023-01-22 - Hacker News - [Show HN: New AI edits images based on text instructions](https://news.ycombinator.com/item?id=34474270)
|
||||
- 2023-01-10 - YouTube - [Build A Free AI Image Generator Bot in 20 minutes!](https://www.youtube.com/watch?v=ufQcDD1kQCI)
|
||||
- 2022-11-24 - Hacker News - [Stable Diffusion 2.0 on Mac and Linux via imaginAIry Python library](https://news.ycombinator.com/item?id=33729694)
|
||||
|
||||
Used by:
|
||||
- [imageeditor.ai](https://imageeditor.ai/)
|
@ -0,0 +1,205 @@
|
||||
# inspired by https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
|
||||
# but with all the bugs fixed and lots of simplifications
|
||||
# MIT License
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def mask_tile(tile, overlap, std_overlap, side="bottom"):
|
||||
h, w = tile.size(2), tile.size(3)
|
||||
top_overlap, bottom_overlap, right_overlap, left_overlap = overlap
|
||||
(
|
||||
std_top_overlap,
|
||||
std_bottom_overlap,
|
||||
std_right_overlap,
|
||||
std_left_overlap,
|
||||
) = std_overlap
|
||||
|
||||
if "left" in side:
|
||||
lin_mask_left = torch.linspace(0, 1, std_left_overlap, device=tile.device)
|
||||
if left_overlap > std_left_overlap:
|
||||
zeros_mask = torch.zeros(
|
||||
left_overlap - std_left_overlap, device=tile.device
|
||||
)
|
||||
lin_mask_left = (
|
||||
torch.cat([zeros_mask, lin_mask_left], 0)
|
||||
.repeat(h, 1)
|
||||
.repeat(3, 1, 1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
|
||||
if "right" in side:
|
||||
lin_mask_right = (
|
||||
torch.linspace(1, 0, right_overlap, device=tile.device)
|
||||
.repeat(h, 1)
|
||||
.repeat(3, 1, 1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
if "top" in side:
|
||||
lin_mask_top = torch.linspace(0, 1, std_top_overlap, device=tile.device)
|
||||
if top_overlap > std_top_overlap:
|
||||
zeros_mask = torch.zeros(top_overlap - std_top_overlap, device=tile.device)
|
||||
lin_mask_top = torch.cat([zeros_mask, lin_mask_top], 0)
|
||||
lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
|
||||
|
||||
if "bottom" in side:
|
||||
lin_mask_bottom = (
|
||||
torch.linspace(1, 0, std_bottom_overlap, device=tile.device)
|
||||
.repeat(w, 1)
|
||||
.rot90(3)
|
||||
.repeat(3, 1, 1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
|
||||
base_mask = torch.ones_like(tile)
|
||||
|
||||
if "right" in side:
|
||||
base_mask[:, :, :, w - right_overlap :] = (
|
||||
base_mask[:, :, :, w - right_overlap :] * lin_mask_right
|
||||
)
|
||||
if "left" in side:
|
||||
base_mask[:, :, :, :left_overlap] = (
|
||||
base_mask[:, :, :, :left_overlap] * lin_mask_left
|
||||
)
|
||||
if "bottom" in side:
|
||||
base_mask[:, :, h - bottom_overlap :, :] = (
|
||||
base_mask[:, :, h - bottom_overlap :, :] * lin_mask_bottom
|
||||
)
|
||||
if "top" in side:
|
||||
base_mask[:, :, :top_overlap, :] = (
|
||||
base_mask[:, :, :top_overlap, :] * lin_mask_top
|
||||
)
|
||||
return tile * base_mask
|
||||
|
||||
|
||||
def get_tile_coords(d, tile_dim, overlap=0):
|
||||
move = int(math.ceil(round(tile_dim * (1 - overlap), 10)))
|
||||
c, tile_start, coords = 1, 0, [0]
|
||||
while tile_start + tile_dim < d:
|
||||
tile_start = move * c
|
||||
if tile_start + tile_dim >= d:
|
||||
coords.append(d - tile_dim)
|
||||
else:
|
||||
coords.append(tile_start)
|
||||
c += 1
|
||||
return coords
|
||||
|
||||
|
||||
def get_tiles(img, tile_coords, tile_size):
|
||||
tile_list = []
|
||||
for y in tile_coords[0]:
|
||||
for x in tile_coords[1]:
|
||||
tile = img[:, :, y : y + tile_size[0], x : x + tile_size[1]]
|
||||
tile_list.append(tile)
|
||||
return tile_list
|
||||
|
||||
|
||||
def final_overlap(tile_coords, tile_size):
|
||||
last_row, last_col = len(tile_coords[0]) - 1, len(tile_coords[1]) - 1
|
||||
|
||||
f_ovlp = [
|
||||
(tile_coords[0][last_row - 1] + tile_size[0]) - (tile_coords[0][last_row]),
|
||||
(tile_coords[1][last_col - 1] + tile_size[1]) - (tile_coords[1][last_col]),
|
||||
]
|
||||
return f_ovlp
|
||||
|
||||
|
||||
def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
|
||||
f_ovlp = final_overlap(tile_coords, tile_size)
|
||||
h, w = tiles[0].size(2), tiles[0].size(3)
|
||||
if f_ovlp[0] == h:
|
||||
f_ovlp[0] = 0
|
||||
|
||||
if f_ovlp[1] == w:
|
||||
f_ovlp[1] = 0
|
||||
|
||||
t = 0
|
||||
column, row, = (
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
for y in tile_coords[0]:
|
||||
for x in tile_coords[1]:
|
||||
mask_sides = ""
|
||||
c_overlap = overlap.copy()
|
||||
if row == 0:
|
||||
mask_sides += "bottom"
|
||||
elif 0 < row < len(tile_coords[0]) - 2:
|
||||
mask_sides += "bottom,top"
|
||||
elif row == len(tile_coords[0]) - 2:
|
||||
mask_sides += "bottom,top"
|
||||
elif row == len(tile_coords[0]) - 1:
|
||||
mask_sides += "top"
|
||||
if f_ovlp[0] > 0:
|
||||
c_overlap[0] = f_ovlp[0] # Change top overlap
|
||||
|
||||
if column == 0:
|
||||
mask_sides += ",right"
|
||||
elif 0 < column < len(tile_coords[1]) - 2:
|
||||
mask_sides += ",right,left"
|
||||
elif column == len(tile_coords[1]) - 2:
|
||||
mask_sides += ",right,left"
|
||||
elif column == len(tile_coords[1]) - 1:
|
||||
mask_sides += ",left"
|
||||
if f_ovlp[1] > 0:
|
||||
c_overlap[3] = f_ovlp[1] # Change left overlap
|
||||
|
||||
# print(f"mask_tile: tile.shape={tiles[t].shape}, overlap={c_overlap}, side={mask_sides} col={column}, row={row}")
|
||||
|
||||
tile = mask_tile(tiles[t], c_overlap, std_overlap=overlap, side=mask_sides)
|
||||
# torch_img_to_pillow_img(tile).show()
|
||||
base_img[:, :, y : y + tile_size[0], x : x + tile_size[1]] = (
|
||||
base_img[:, :, y : y + tile_size[0], x : x + tile_size[1]] + tile
|
||||
)
|
||||
# torch_img_to_pillow_img(base_img).show()
|
||||
t += 1
|
||||
column += 1
|
||||
|
||||
row += 1
|
||||
# if row >= 2:
|
||||
# exit()
|
||||
column = 0
|
||||
return base_img
|
||||
|
||||
|
||||
def tile_setup(tile_size, overlap_percent, base_size):
|
||||
if not isinstance(tile_size, (tuple, list)):
|
||||
tile_size = (tile_size, tile_size)
|
||||
if not isinstance(overlap_percent, (tuple, list)):
|
||||
overlap_percent = (overlap_percent, overlap_percent)
|
||||
if min(tile_size) < 1:
|
||||
raise ValueError("tile_size must be at least 1")
|
||||
|
||||
if max(overlap_percent) > 0.5:
|
||||
raise ValueError("overlap_percent must not be greater than 0.5")
|
||||
|
||||
x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
|
||||
y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
|
||||
y_ovlp = int(math.floor(round(tile_size[0] * overlap_percent[0], 10)))
|
||||
x_ovlp = int(math.floor(round(tile_size[1] * overlap_percent[1], 10)))
|
||||
if len(x_coords) == 1:
|
||||
x_ovlp = 0
|
||||
if len(y_coords) == 1:
|
||||
y_ovlp = 0
|
||||
|
||||
return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]
|
||||
|
||||
|
||||
def tile_image(img, tile_size, overlap_percent):
|
||||
tile_coords, tile_size, _ = tile_setup(
|
||||
tile_size, overlap_percent, (img.size(2), img.size(3))
|
||||
)
|
||||
|
||||
return get_tiles(img, tile_coords, tile_size)
|
||||
|
||||
|
||||
def rebuild_image(tiles, base_img, tile_size, overlap_percent):
|
||||
if len(tiles) == 1:
|
||||
return tiles[0]
|
||||
base_img = torch.zeros_like(base_img)
|
||||
tile_coords, tile_size, overlap = tile_setup(
|
||||
tile_size, overlap_percent, (base_img.size(2), base_img.size(3))
|
||||
)
|
||||
return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)
|
Binary file not shown.
Before Width: | Height: | Size: 324 KiB After Width: | Height: | Size: 370 KiB |
@ -0,0 +1,132 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
|
||||
from imaginairy import LazyLoadingImage
|
||||
from imaginairy.feather_tile import rebuild_image, tile_image, tile_setup
|
||||
from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
img_ratios = [0.2, 0.242, 0.3, 0.33333333, 0.5, 0.75, 1, 4 / 3.0, 16 / 9.0, 2, 21 / 9.0]
|
||||
pcts = [
|
||||
0,
|
||||
0.09,
|
||||
0.1,
|
||||
0.2,
|
||||
0.25,
|
||||
0.3,
|
||||
1 / 3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.75,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
]
|
||||
initial_sizes = [512]
|
||||
flip = [True, False]
|
||||
|
||||
cases = [
|
||||
(1, 256, 0),
|
||||
(1, 256, 0.125),
|
||||
(1, 256, 0.25),
|
||||
(1, 256, 0.5),
|
||||
(1, 128, 0),
|
||||
(1, 128, 0.125),
|
||||
(1, 128, 0.25),
|
||||
(1, 128, 0.5),
|
||||
(1, 512, 0),
|
||||
(0.2, 46, 0.09),
|
||||
(0.2, 46, 0.1),
|
||||
(0.242, 46, 0.2),
|
||||
(0.2, 51, 1 / 3.0),
|
||||
(0.2, 102, 0.09), # tile size same as width of image
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("img_ratio, tile_size, overlap_pct", cases)
|
||||
def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
|
||||
img = pillow_img_to_torch_image(
|
||||
LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
|
||||
)
|
||||
img = img[:, :, : img.shape[2], : int(img.shape[3] * img_ratio)]
|
||||
img_sum = img.sum()
|
||||
tiles = tile_image(img, tile_size=tile_size, overlap_percent=overlap_pct)
|
||||
tile_coords, tile_size, overlap = tile_setup(
|
||||
tile_size, overlap_pct, (img.size(2), img.size(3))
|
||||
)
|
||||
|
||||
print(
|
||||
f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
|
||||
)
|
||||
|
||||
rebuilt = rebuild_image(
|
||||
tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_pct
|
||||
)
|
||||
assert rebuilt.shape == img.shape
|
||||
diff = abs(float(rebuilt.sum()) - float(img_sum))
|
||||
if diff >= 1:
|
||||
torch_img_to_pillow_img(img).show()
|
||||
torch_img_to_pillow_img(rebuilt).show()
|
||||
torch_img_to_pillow_img(rebuilt - img).show()
|
||||
|
||||
assert diff < 1
|
||||
|
||||
|
||||
def test_feather_tile_brute():
|
||||
source_img = pillow_img_to_torch_image(
|
||||
LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
|
||||
)
|
||||
|
||||
def tile_untile(img, tile_size, overlap_percent):
|
||||
img_sum = img.sum()
|
||||
tiles = tile_image(img, tile_size=tile_size, overlap_percent=overlap_percent)
|
||||
tile_coords, tile_size, overlap = tile_setup(
|
||||
tile_size, overlap_percent, (img.size(2), img.size(3))
|
||||
)
|
||||
print(
|
||||
f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
|
||||
)
|
||||
|
||||
rebuilt = rebuild_image(
|
||||
tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_percent
|
||||
)
|
||||
assert rebuilt.shape == img.shape
|
||||
diff = abs(float(rebuilt.sum()) - float(img_sum))
|
||||
if diff > 1:
|
||||
torch_img_to_pillow_img(img).show()
|
||||
torch_img_to_pillow_img(rebuilt).show()
|
||||
torch_img_to_pillow_img((rebuilt - img) * 20).show()
|
||||
|
||||
status = "🚫 FAILED"
|
||||
|
||||
else:
|
||||
status = "✅ PASSED"
|
||||
print(
|
||||
f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}"
|
||||
)
|
||||
assert diff < 1
|
||||
|
||||
for tile_size_pct, overlap_percent, img_ratio, flip_ratio in itertools.product(
|
||||
pcts, pcts, img_ratios, flip
|
||||
):
|
||||
if flip_ratio:
|
||||
img = source_img.clone()[:, :, :, : int(source_img.shape[3] * img_ratio)]
|
||||
else:
|
||||
img = source_img.clone()[:, :, : int(source_img.shape[2] * img_ratio), :]
|
||||
tile_size = int(source_img.shape[3] * tile_size_pct)
|
||||
if not tile_size:
|
||||
continue
|
||||
|
||||
if overlap_percent >= 0.5:
|
||||
continue
|
||||
|
||||
print(
|
||||
f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}"
|
||||
)
|
||||
tile_untile(img, tile_size=tile_size, overlap_percent=overlap_percent)
|
||||
del img
|
||||
|
||||
# tile_untile(img, tile_size=256, overlap_percent=0.25)
|
Loading…
Reference in New Issue