nikunjkdtechnoland
commited on
Commit
•
4b98c85
1
Parent(s):
e041d7d
some more add more files
Browse files- iopaint/file_manager/utils.py +65 -0
- iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py +81 -0
- iopaint/model/anytext/ldm/modules/diffusionmodules/util.py +271 -0
- iopaint/model/anytext/ldm/util.py +197 -0
- iopaint/model/anytext/utils.py +151 -0
- iopaint/model/original_sd_configs/v1-inference.yaml +70 -0
- iopaint/model/original_sd_configs/v2-inference-v.yaml +68 -0
- iopaint/model/utils.py +1033 -0
- iopaint/model/zits.py +476 -0
- iopaint/plugins/segment_anything/modeling/tiny_vit_sam.py +822 -0
- iopaint/plugins/segment_anything/modeling/transformer.py +240 -0
- iopaint/plugins/segment_anything/utils/transforms.py +112 -0
- iopaint/tests/test_sdxl.py +172 -0
- iopaint/tests/utils.py +77 -0
- iopaint/web_config.py +307 -0
- pretrained-model/version.txt +1 -0
- pretrained-model/version_diffusers_cache.txt +1 -0
- utils/tools.py +505 -0
iopaint/file_manager/utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
|
2 |
+
import hashlib
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
|
8 |
+
def generate_filename(directory: Path, original_filename, *options) -> str:
|
9 |
+
text = str(directory.absolute()) + original_filename
|
10 |
+
for v in options:
|
11 |
+
text += "%s" % v
|
12 |
+
md5_hash = hashlib.md5()
|
13 |
+
md5_hash.update(text.encode("utf-8"))
|
14 |
+
return md5_hash.hexdigest() + ".jpg"
|
15 |
+
|
16 |
+
|
17 |
+
def parse_size(size):
|
18 |
+
if isinstance(size, int):
|
19 |
+
# If the size parameter is a single number, assume square aspect.
|
20 |
+
return [size, size]
|
21 |
+
|
22 |
+
if isinstance(size, (tuple, list)):
|
23 |
+
if len(size) == 1:
|
24 |
+
# If single value tuple/list is provided, exand it to two elements
|
25 |
+
return size + type(size)(size)
|
26 |
+
return size
|
27 |
+
|
28 |
+
try:
|
29 |
+
thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
|
30 |
+
except ValueError:
|
31 |
+
raise ValueError( # pylint: disable=raise-missing-from
|
32 |
+
"Bad thumbnail size format. Valid format is INTxINT."
|
33 |
+
)
|
34 |
+
|
35 |
+
if len(thumbnail_size) == 1:
|
36 |
+
# If the size parameter only contains a single integer, assume square aspect.
|
37 |
+
thumbnail_size.append(thumbnail_size[0])
|
38 |
+
|
39 |
+
return thumbnail_size
|
40 |
+
|
41 |
+
|
42 |
+
def aspect_to_string(size):
|
43 |
+
if isinstance(size, str):
|
44 |
+
return size
|
45 |
+
|
46 |
+
return "x".join(map(str, size))
|
47 |
+
|
48 |
+
|
49 |
+
IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
|
50 |
+
|
51 |
+
|
52 |
+
def glob_img(p: Union[Path, str], recursive: bool = False):
|
53 |
+
p = Path(p)
|
54 |
+
if p.is_file() and p.suffix in IMG_SUFFIX:
|
55 |
+
yield p
|
56 |
+
else:
|
57 |
+
if recursive:
|
58 |
+
files = Path(p).glob("**/*.*")
|
59 |
+
else:
|
60 |
+
files = Path(p).glob("*.*")
|
61 |
+
|
62 |
+
for it in files:
|
63 |
+
if it.suffix not in IMG_SUFFIX:
|
64 |
+
continue
|
65 |
+
yield it
|
iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
7 |
+
from iopaint.model.anytext.ldm.util import default
|
8 |
+
|
9 |
+
|
10 |
+
class AbstractLowScaleModel(nn.Module):
|
11 |
+
# for concatenating a downsampled image to the latent representation
|
12 |
+
def __init__(self, noise_schedule_config=None):
|
13 |
+
super(AbstractLowScaleModel, self).__init__()
|
14 |
+
if noise_schedule_config is not None:
|
15 |
+
self.register_schedule(**noise_schedule_config)
|
16 |
+
|
17 |
+
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
18 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
19 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
20 |
+
cosine_s=cosine_s)
|
21 |
+
alphas = 1. - betas
|
22 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
23 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
24 |
+
|
25 |
+
timesteps, = betas.shape
|
26 |
+
self.num_timesteps = int(timesteps)
|
27 |
+
self.linear_start = linear_start
|
28 |
+
self.linear_end = linear_end
|
29 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
30 |
+
|
31 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
32 |
+
|
33 |
+
self.register_buffer('betas', to_torch(betas))
|
34 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
35 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
36 |
+
|
37 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
38 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
39 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
40 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
41 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
42 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
43 |
+
|
44 |
+
def q_sample(self, x_start, t, noise=None):
|
45 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
46 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
47 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return x, None
|
51 |
+
|
52 |
+
def decode(self, x):
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class SimpleImageConcat(AbstractLowScaleModel):
|
57 |
+
# no noise level conditioning
|
58 |
+
def __init__(self):
|
59 |
+
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
|
60 |
+
self.max_noise_level = 0
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
# fix to constant noise level
|
64 |
+
return x, torch.zeros(x.shape[0], device=x.device).long()
|
65 |
+
|
66 |
+
|
67 |
+
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
68 |
+
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
|
69 |
+
super().__init__(noise_schedule_config=noise_schedule_config)
|
70 |
+
self.max_noise_level = max_noise_level
|
71 |
+
|
72 |
+
def forward(self, x, noise_level=None):
|
73 |
+
if noise_level is None:
|
74 |
+
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
75 |
+
else:
|
76 |
+
assert isinstance(noise_level, torch.Tensor)
|
77 |
+
z = self.q_sample(x, noise_level)
|
78 |
+
return z, noise_level
|
79 |
+
|
80 |
+
|
81 |
+
|
iopaint/model/anytext/ldm/modules/diffusionmodules/util.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import os
|
12 |
+
import math
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import numpy as np
|
16 |
+
from einops import repeat
|
17 |
+
|
18 |
+
from iopaint.model.anytext.ldm.util import instantiate_from_config
|
19 |
+
|
20 |
+
|
21 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
22 |
+
if schedule == "linear":
|
23 |
+
betas = (
|
24 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
25 |
+
)
|
26 |
+
|
27 |
+
elif schedule == "cosine":
|
28 |
+
timesteps = (
|
29 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
30 |
+
)
|
31 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
32 |
+
alphas = torch.cos(alphas).pow(2)
|
33 |
+
alphas = alphas / alphas[0]
|
34 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
35 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
36 |
+
|
37 |
+
elif schedule == "sqrt_linear":
|
38 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
39 |
+
elif schedule == "sqrt":
|
40 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
41 |
+
else:
|
42 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
43 |
+
return betas.numpy()
|
44 |
+
|
45 |
+
|
46 |
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
47 |
+
if ddim_discr_method == 'uniform':
|
48 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
49 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
50 |
+
elif ddim_discr_method == 'quad':
|
51 |
+
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
52 |
+
else:
|
53 |
+
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
54 |
+
|
55 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
56 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
57 |
+
steps_out = ddim_timesteps + 1
|
58 |
+
if verbose:
|
59 |
+
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
60 |
+
return steps_out
|
61 |
+
|
62 |
+
|
63 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
64 |
+
# select alphas for computing the variance schedule
|
65 |
+
alphas = alphacums[ddim_timesteps]
|
66 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
67 |
+
|
68 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
69 |
+
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
70 |
+
if verbose:
|
71 |
+
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
72 |
+
print(f'For the chosen value of eta, which is {eta}, '
|
73 |
+
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
74 |
+
return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32)
|
75 |
+
|
76 |
+
|
77 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
78 |
+
"""
|
79 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
80 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
81 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
82 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
83 |
+
produces the cumulative product of (1-beta) up to that
|
84 |
+
part of the diffusion process.
|
85 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
86 |
+
prevent singularities.
|
87 |
+
"""
|
88 |
+
betas = []
|
89 |
+
for i in range(num_diffusion_timesteps):
|
90 |
+
t1 = i / num_diffusion_timesteps
|
91 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
92 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
93 |
+
return np.array(betas)
|
94 |
+
|
95 |
+
|
96 |
+
def extract_into_tensor(a, t, x_shape):
|
97 |
+
b, *_ = t.shape
|
98 |
+
out = a.gather(-1, t)
|
99 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
100 |
+
|
101 |
+
|
102 |
+
def checkpoint(func, inputs, params, flag):
|
103 |
+
"""
|
104 |
+
Evaluate a function without caching intermediate activations, allowing for
|
105 |
+
reduced memory at the expense of extra compute in the backward pass.
|
106 |
+
:param func: the function to evaluate.
|
107 |
+
:param inputs: the argument sequence to pass to `func`.
|
108 |
+
:param params: a sequence of parameters `func` depends on but does not
|
109 |
+
explicitly take as arguments.
|
110 |
+
:param flag: if False, disable gradient checkpointing.
|
111 |
+
"""
|
112 |
+
if flag:
|
113 |
+
args = tuple(inputs) + tuple(params)
|
114 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
115 |
+
else:
|
116 |
+
return func(*inputs)
|
117 |
+
|
118 |
+
|
119 |
+
class CheckpointFunction(torch.autograd.Function):
|
120 |
+
@staticmethod
|
121 |
+
def forward(ctx, run_function, length, *args):
|
122 |
+
ctx.run_function = run_function
|
123 |
+
ctx.input_tensors = list(args[:length])
|
124 |
+
ctx.input_params = list(args[length:])
|
125 |
+
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
|
126 |
+
"dtype": torch.get_autocast_gpu_dtype(),
|
127 |
+
"cache_enabled": torch.is_autocast_cache_enabled()}
|
128 |
+
with torch.no_grad():
|
129 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
130 |
+
return output_tensors
|
131 |
+
|
132 |
+
@staticmethod
|
133 |
+
def backward(ctx, *output_grads):
|
134 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
135 |
+
with torch.enable_grad(), \
|
136 |
+
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
137 |
+
# Fixes a bug where the first op in run_function modifies the
|
138 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
139 |
+
# Tensors.
|
140 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
141 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
142 |
+
input_grads = torch.autograd.grad(
|
143 |
+
output_tensors,
|
144 |
+
ctx.input_tensors + ctx.input_params,
|
145 |
+
output_grads,
|
146 |
+
allow_unused=True,
|
147 |
+
)
|
148 |
+
del ctx.input_tensors
|
149 |
+
del ctx.input_params
|
150 |
+
del output_tensors
|
151 |
+
return (None, None) + input_grads
|
152 |
+
|
153 |
+
|
154 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
155 |
+
"""
|
156 |
+
Create sinusoidal timestep embeddings.
|
157 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
158 |
+
These may be fractional.
|
159 |
+
:param dim: the dimension of the output.
|
160 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
161 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
162 |
+
"""
|
163 |
+
if not repeat_only:
|
164 |
+
half = dim // 2
|
165 |
+
freqs = torch.exp(
|
166 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
167 |
+
).to(device=timesteps.device)
|
168 |
+
args = timesteps[:, None].float() * freqs[None]
|
169 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
170 |
+
if dim % 2:
|
171 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
172 |
+
else:
|
173 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
174 |
+
return embedding
|
175 |
+
|
176 |
+
|
177 |
+
def zero_module(module):
|
178 |
+
"""
|
179 |
+
Zero out the parameters of a module and return it.
|
180 |
+
"""
|
181 |
+
for p in module.parameters():
|
182 |
+
p.detach().zero_()
|
183 |
+
return module
|
184 |
+
|
185 |
+
|
186 |
+
def scale_module(module, scale):
|
187 |
+
"""
|
188 |
+
Scale the parameters of a module and return it.
|
189 |
+
"""
|
190 |
+
for p in module.parameters():
|
191 |
+
p.detach().mul_(scale)
|
192 |
+
return module
|
193 |
+
|
194 |
+
|
195 |
+
def mean_flat(tensor):
|
196 |
+
"""
|
197 |
+
Take the mean over all non-batch dimensions.
|
198 |
+
"""
|
199 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
200 |
+
|
201 |
+
|
202 |
+
def normalization(channels):
|
203 |
+
"""
|
204 |
+
Make a standard normalization layer.
|
205 |
+
:param channels: number of input channels.
|
206 |
+
:return: an nn.Module for normalization.
|
207 |
+
"""
|
208 |
+
return GroupNorm32(32, channels)
|
209 |
+
|
210 |
+
|
211 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
212 |
+
class SiLU(nn.Module):
|
213 |
+
def forward(self, x):
|
214 |
+
return x * torch.sigmoid(x)
|
215 |
+
|
216 |
+
|
217 |
+
class GroupNorm32(nn.GroupNorm):
|
218 |
+
def forward(self, x):
|
219 |
+
# return super().forward(x.float()).type(x.dtype)
|
220 |
+
return super().forward(x).type(x.dtype)
|
221 |
+
|
222 |
+
def conv_nd(dims, *args, **kwargs):
|
223 |
+
"""
|
224 |
+
Create a 1D, 2D, or 3D convolution module.
|
225 |
+
"""
|
226 |
+
if dims == 1:
|
227 |
+
return nn.Conv1d(*args, **kwargs)
|
228 |
+
elif dims == 2:
|
229 |
+
return nn.Conv2d(*args, **kwargs)
|
230 |
+
elif dims == 3:
|
231 |
+
return nn.Conv3d(*args, **kwargs)
|
232 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
233 |
+
|
234 |
+
|
235 |
+
def linear(*args, **kwargs):
|
236 |
+
"""
|
237 |
+
Create a linear module.
|
238 |
+
"""
|
239 |
+
return nn.Linear(*args, **kwargs)
|
240 |
+
|
241 |
+
|
242 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
243 |
+
"""
|
244 |
+
Create a 1D, 2D, or 3D average pooling module.
|
245 |
+
"""
|
246 |
+
if dims == 1:
|
247 |
+
return nn.AvgPool1d(*args, **kwargs)
|
248 |
+
elif dims == 2:
|
249 |
+
return nn.AvgPool2d(*args, **kwargs)
|
250 |
+
elif dims == 3:
|
251 |
+
return nn.AvgPool3d(*args, **kwargs)
|
252 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
253 |
+
|
254 |
+
|
255 |
+
class HybridConditioner(nn.Module):
|
256 |
+
|
257 |
+
def __init__(self, c_concat_config, c_crossattn_config):
|
258 |
+
super().__init__()
|
259 |
+
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
260 |
+
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
261 |
+
|
262 |
+
def forward(self, c_concat, c_crossattn):
|
263 |
+
c_concat = self.concat_conditioner(c_concat)
|
264 |
+
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
265 |
+
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
266 |
+
|
267 |
+
|
268 |
+
def noise_like(shape, device, repeat=False):
|
269 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
270 |
+
noise = lambda: torch.randn(shape, device=device)
|
271 |
+
return repeat_noise() if repeat else noise()
|
iopaint/model/anytext/ldm/util.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import optim
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inspect import isfunction
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
|
10 |
+
|
11 |
+
def log_txt_as_img(wh, xc, size=10):
|
12 |
+
# wh a tuple of (width, height)
|
13 |
+
# xc a list of captions to plot
|
14 |
+
b = len(xc)
|
15 |
+
txts = list()
|
16 |
+
for bi in range(b):
|
17 |
+
txt = Image.new("RGB", wh, color="white")
|
18 |
+
draw = ImageDraw.Draw(txt)
|
19 |
+
font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
|
20 |
+
nc = int(32 * (wh[0] / 256))
|
21 |
+
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
22 |
+
|
23 |
+
try:
|
24 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
25 |
+
except UnicodeEncodeError:
|
26 |
+
print("Cant encode string for logging. Skipping.")
|
27 |
+
|
28 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
29 |
+
txts.append(txt)
|
30 |
+
txts = np.stack(txts)
|
31 |
+
txts = torch.tensor(txts)
|
32 |
+
return txts
|
33 |
+
|
34 |
+
|
35 |
+
def ismap(x):
|
36 |
+
if not isinstance(x, torch.Tensor):
|
37 |
+
return False
|
38 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
39 |
+
|
40 |
+
|
41 |
+
def isimage(x):
|
42 |
+
if not isinstance(x,torch.Tensor):
|
43 |
+
return False
|
44 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
45 |
+
|
46 |
+
|
47 |
+
def exists(x):
|
48 |
+
return x is not None
|
49 |
+
|
50 |
+
|
51 |
+
def default(val, d):
|
52 |
+
if exists(val):
|
53 |
+
return val
|
54 |
+
return d() if isfunction(d) else d
|
55 |
+
|
56 |
+
|
57 |
+
def mean_flat(tensor):
|
58 |
+
"""
|
59 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
60 |
+
Take the mean over all non-batch dimensions.
|
61 |
+
"""
|
62 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
63 |
+
|
64 |
+
|
65 |
+
def count_params(model, verbose=False):
|
66 |
+
total_params = sum(p.numel() for p in model.parameters())
|
67 |
+
if verbose:
|
68 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
69 |
+
return total_params
|
70 |
+
|
71 |
+
|
72 |
+
def instantiate_from_config(config, **kwargs):
|
73 |
+
if "target" not in config:
|
74 |
+
if config == '__is_first_stage__':
|
75 |
+
return None
|
76 |
+
elif config == "__is_unconditional__":
|
77 |
+
return None
|
78 |
+
raise KeyError("Expected key `target` to instantiate.")
|
79 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
|
80 |
+
|
81 |
+
|
82 |
+
def get_obj_from_str(string, reload=False):
|
83 |
+
module, cls = string.rsplit(".", 1)
|
84 |
+
if reload:
|
85 |
+
module_imp = importlib.import_module(module)
|
86 |
+
importlib.reload(module_imp)
|
87 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
88 |
+
|
89 |
+
|
90 |
+
class AdamWwithEMAandWings(optim.Optimizer):
|
91 |
+
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
92 |
+
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
93 |
+
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
94 |
+
ema_power=1., param_names=()):
|
95 |
+
"""AdamW that saves EMA versions of the parameters."""
|
96 |
+
if not 0.0 <= lr:
|
97 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
98 |
+
if not 0.0 <= eps:
|
99 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
100 |
+
if not 0.0 <= betas[0] < 1.0:
|
101 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
102 |
+
if not 0.0 <= betas[1] < 1.0:
|
103 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
104 |
+
if not 0.0 <= weight_decay:
|
105 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
106 |
+
if not 0.0 <= ema_decay <= 1.0:
|
107 |
+
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
108 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
109 |
+
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
110 |
+
ema_power=ema_power, param_names=param_names)
|
111 |
+
super().__init__(params, defaults)
|
112 |
+
|
113 |
+
def __setstate__(self, state):
|
114 |
+
super().__setstate__(state)
|
115 |
+
for group in self.param_groups:
|
116 |
+
group.setdefault('amsgrad', False)
|
117 |
+
|
118 |
+
@torch.no_grad()
|
119 |
+
def step(self, closure=None):
|
120 |
+
"""Performs a single optimization step.
|
121 |
+
Args:
|
122 |
+
closure (callable, optional): A closure that reevaluates the model
|
123 |
+
and returns the loss.
|
124 |
+
"""
|
125 |
+
loss = None
|
126 |
+
if closure is not None:
|
127 |
+
with torch.enable_grad():
|
128 |
+
loss = closure()
|
129 |
+
|
130 |
+
for group in self.param_groups:
|
131 |
+
params_with_grad = []
|
132 |
+
grads = []
|
133 |
+
exp_avgs = []
|
134 |
+
exp_avg_sqs = []
|
135 |
+
ema_params_with_grad = []
|
136 |
+
state_sums = []
|
137 |
+
max_exp_avg_sqs = []
|
138 |
+
state_steps = []
|
139 |
+
amsgrad = group['amsgrad']
|
140 |
+
beta1, beta2 = group['betas']
|
141 |
+
ema_decay = group['ema_decay']
|
142 |
+
ema_power = group['ema_power']
|
143 |
+
|
144 |
+
for p in group['params']:
|
145 |
+
if p.grad is None:
|
146 |
+
continue
|
147 |
+
params_with_grad.append(p)
|
148 |
+
if p.grad.is_sparse:
|
149 |
+
raise RuntimeError('AdamW does not support sparse gradients')
|
150 |
+
grads.append(p.grad)
|
151 |
+
|
152 |
+
state = self.state[p]
|
153 |
+
|
154 |
+
# State initialization
|
155 |
+
if len(state) == 0:
|
156 |
+
state['step'] = 0
|
157 |
+
# Exponential moving average of gradient values
|
158 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
159 |
+
# Exponential moving average of squared gradient values
|
160 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
161 |
+
if amsgrad:
|
162 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
163 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
164 |
+
# Exponential moving average of parameter values
|
165 |
+
state['param_exp_avg'] = p.detach().float().clone()
|
166 |
+
|
167 |
+
exp_avgs.append(state['exp_avg'])
|
168 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
169 |
+
ema_params_with_grad.append(state['param_exp_avg'])
|
170 |
+
|
171 |
+
if amsgrad:
|
172 |
+
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
173 |
+
|
174 |
+
# update the steps for each param group update
|
175 |
+
state['step'] += 1
|
176 |
+
# record the step after step update
|
177 |
+
state_steps.append(state['step'])
|
178 |
+
|
179 |
+
optim._functional.adamw(params_with_grad,
|
180 |
+
grads,
|
181 |
+
exp_avgs,
|
182 |
+
exp_avg_sqs,
|
183 |
+
max_exp_avg_sqs,
|
184 |
+
state_steps,
|
185 |
+
amsgrad=amsgrad,
|
186 |
+
beta1=beta1,
|
187 |
+
beta2=beta2,
|
188 |
+
lr=group['lr'],
|
189 |
+
weight_decay=group['weight_decay'],
|
190 |
+
eps=group['eps'],
|
191 |
+
maximize=False)
|
192 |
+
|
193 |
+
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
194 |
+
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
195 |
+
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
196 |
+
|
197 |
+
return loss
|
iopaint/model/anytext/utils.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datetime
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image, ImageDraw
|
6 |
+
|
7 |
+
|
8 |
+
def save_images(img_list, folder):
|
9 |
+
if not os.path.exists(folder):
|
10 |
+
os.makedirs(folder)
|
11 |
+
now = datetime.datetime.now()
|
12 |
+
date_str = now.strftime("%Y-%m-%d")
|
13 |
+
folder_path = os.path.join(folder, date_str)
|
14 |
+
if not os.path.exists(folder_path):
|
15 |
+
os.makedirs(folder_path)
|
16 |
+
time_str = now.strftime("%H_%M_%S")
|
17 |
+
for idx, img in enumerate(img_list):
|
18 |
+
image_number = idx + 1
|
19 |
+
filename = f"{time_str}_{image_number}.jpg"
|
20 |
+
save_path = os.path.join(folder_path, filename)
|
21 |
+
cv2.imwrite(save_path, img[..., ::-1])
|
22 |
+
|
23 |
+
|
24 |
+
def check_channels(image):
|
25 |
+
channels = image.shape[2] if len(image.shape) == 3 else 1
|
26 |
+
if channels == 1:
|
27 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
28 |
+
elif channels > 3:
|
29 |
+
image = image[:, :, :3]
|
30 |
+
return image
|
31 |
+
|
32 |
+
|
33 |
+
def resize_image(img, max_length=768):
|
34 |
+
height, width = img.shape[:2]
|
35 |
+
max_dimension = max(height, width)
|
36 |
+
|
37 |
+
if max_dimension > max_length:
|
38 |
+
scale_factor = max_length / max_dimension
|
39 |
+
new_width = int(round(width * scale_factor))
|
40 |
+
new_height = int(round(height * scale_factor))
|
41 |
+
new_size = (new_width, new_height)
|
42 |
+
img = cv2.resize(img, new_size)
|
43 |
+
height, width = img.shape[:2]
|
44 |
+
img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
|
45 |
+
return img
|
46 |
+
|
47 |
+
|
48 |
+
def insert_spaces(string, nSpace):
|
49 |
+
if nSpace == 0:
|
50 |
+
return string
|
51 |
+
new_string = ""
|
52 |
+
for char in string:
|
53 |
+
new_string += char + " " * nSpace
|
54 |
+
return new_string[:-nSpace]
|
55 |
+
|
56 |
+
|
57 |
+
def draw_glyph(font, text):
|
58 |
+
g_size = 50
|
59 |
+
W, H = (512, 80)
|
60 |
+
new_font = font.font_variant(size=g_size)
|
61 |
+
img = Image.new(mode="1", size=(W, H), color=0)
|
62 |
+
draw = ImageDraw.Draw(img)
|
63 |
+
left, top, right, bottom = new_font.getbbox(text)
|
64 |
+
text_width = max(right - left, 5)
|
65 |
+
text_height = max(bottom - top, 5)
|
66 |
+
ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
|
67 |
+
new_font = font.font_variant(size=int(g_size * ratio))
|
68 |
+
|
69 |
+
text_width, text_height = new_font.getsize(text)
|
70 |
+
offset_x, offset_y = new_font.getoffset(text)
|
71 |
+
x = (img.width - text_width) // 2
|
72 |
+
y = (img.height - text_height) // 2 - offset_y // 2
|
73 |
+
draw.text((x, y), text, font=new_font, fill="white")
|
74 |
+
img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
|
75 |
+
return img
|
76 |
+
|
77 |
+
|
78 |
+
def draw_glyph2(
|
79 |
+
font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True
|
80 |
+
):
|
81 |
+
enlarge_polygon = polygon * scale
|
82 |
+
rect = cv2.minAreaRect(enlarge_polygon)
|
83 |
+
box = cv2.boxPoints(rect)
|
84 |
+
box = np.int0(box)
|
85 |
+
w, h = rect[1]
|
86 |
+
angle = rect[2]
|
87 |
+
if angle < -45:
|
88 |
+
angle += 90
|
89 |
+
angle = -angle
|
90 |
+
if w < h:
|
91 |
+
angle += 90
|
92 |
+
|
93 |
+
vert = False
|
94 |
+
if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
|
95 |
+
_w = max(box[:, 0]) - min(box[:, 0])
|
96 |
+
_h = max(box[:, 1]) - min(box[:, 1])
|
97 |
+
if _h >= _w:
|
98 |
+
vert = True
|
99 |
+
angle = 0
|
100 |
+
|
101 |
+
img = np.zeros((height * scale, width * scale, 3), np.uint8)
|
102 |
+
img = Image.fromarray(img)
|
103 |
+
|
104 |
+
# infer font size
|
105 |
+
image4ratio = Image.new("RGB", img.size, "white")
|
106 |
+
draw = ImageDraw.Draw(image4ratio)
|
107 |
+
_, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
|
108 |
+
text_w = min(w, h) * (_tw / _th)
|
109 |
+
if text_w <= max(w, h):
|
110 |
+
# add space
|
111 |
+
if len(text) > 1 and not vert and add_space:
|
112 |
+
for i in range(1, 100):
|
113 |
+
text_space = insert_spaces(text, i)
|
114 |
+
_, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
|
115 |
+
if min(w, h) * (_tw2 / _th2) > max(w, h):
|
116 |
+
break
|
117 |
+
text = insert_spaces(text, i - 1)
|
118 |
+
font_size = min(w, h) * 0.80
|
119 |
+
else:
|
120 |
+
shrink = 0.75 if vert else 0.85
|
121 |
+
font_size = min(w, h) / (text_w / max(w, h)) * shrink
|
122 |
+
new_font = font.font_variant(size=int(font_size))
|
123 |
+
|
124 |
+
left, top, right, bottom = new_font.getbbox(text)
|
125 |
+
text_width = right - left
|
126 |
+
text_height = bottom - top
|
127 |
+
|
128 |
+
layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
129 |
+
draw = ImageDraw.Draw(layer)
|
130 |
+
if not vert:
|
131 |
+
draw.text(
|
132 |
+
(rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
|
133 |
+
text,
|
134 |
+
font=new_font,
|
135 |
+
fill=(255, 255, 255, 255),
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
x_s = min(box[:, 0]) + _w // 2 - text_height // 2
|
139 |
+
y_s = min(box[:, 1])
|
140 |
+
for c in text:
|
141 |
+
draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
|
142 |
+
_, _t, _, _b = new_font.getbbox(c)
|
143 |
+
y_s += _b
|
144 |
+
|
145 |
+
rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
|
146 |
+
|
147 |
+
x_offset = int((img.width - rotated_layer.width) / 2)
|
148 |
+
y_offset = int((img.height - rotated_layer.height) / 2)
|
149 |
+
img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
|
150 |
+
img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
|
151 |
+
return img
|
iopaint/model/original_sd_configs/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
iopaint/model/original_sd_configs/v2-inference-v.yaml
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-4
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
parameterization: "v"
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.0120
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: "jpg"
|
12 |
+
cond_stage_key: "txt"
|
13 |
+
image_size: 64
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false
|
16 |
+
conditioning_key: crossattn
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: False # we set this to false because this is an inference only config
|
20 |
+
|
21 |
+
unet_config:
|
22 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
23 |
+
params:
|
24 |
+
use_checkpoint: True
|
25 |
+
use_fp16: True
|
26 |
+
image_size: 32 # unused
|
27 |
+
in_channels: 4
|
28 |
+
out_channels: 4
|
29 |
+
model_channels: 320
|
30 |
+
attention_resolutions: [ 4, 2, 1 ]
|
31 |
+
num_res_blocks: 2
|
32 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
33 |
+
num_head_channels: 64 # need to fix for flash-attn
|
34 |
+
use_spatial_transformer: True
|
35 |
+
use_linear_in_transformer: True
|
36 |
+
transformer_depth: 1
|
37 |
+
context_dim: 1024
|
38 |
+
legacy: False
|
39 |
+
|
40 |
+
first_stage_config:
|
41 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
42 |
+
params:
|
43 |
+
embed_dim: 4
|
44 |
+
monitor: val/rec_loss
|
45 |
+
ddconfig:
|
46 |
+
#attn_type: "vanilla-xformers"
|
47 |
+
double_z: true
|
48 |
+
z_channels: 4
|
49 |
+
resolution: 256
|
50 |
+
in_channels: 3
|
51 |
+
out_ch: 3
|
52 |
+
ch: 128
|
53 |
+
ch_mult:
|
54 |
+
- 1
|
55 |
+
- 2
|
56 |
+
- 4
|
57 |
+
- 4
|
58 |
+
num_res_blocks: 2
|
59 |
+
attn_resolutions: []
|
60 |
+
dropout: 0.0
|
61 |
+
lossconfig:
|
62 |
+
target: torch.nn.Identity
|
63 |
+
|
64 |
+
cond_stage_config:
|
65 |
+
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
66 |
+
params:
|
67 |
+
freeze: True
|
68 |
+
layer: "penultimate"
|
iopaint/model/utils.py
ADDED
@@ -0,0 +1,1033 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import traceback
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import collections
|
10 |
+
from itertools import repeat
|
11 |
+
|
12 |
+
from diffusers import (
|
13 |
+
DDIMScheduler,
|
14 |
+
PNDMScheduler,
|
15 |
+
LMSDiscreteScheduler,
|
16 |
+
EulerDiscreteScheduler,
|
17 |
+
EulerAncestralDiscreteScheduler,
|
18 |
+
DPMSolverMultistepScheduler,
|
19 |
+
UniPCMultistepScheduler,
|
20 |
+
LCMScheduler,
|
21 |
+
DPMSolverSinglestepScheduler,
|
22 |
+
KDPM2DiscreteScheduler,
|
23 |
+
KDPM2AncestralDiscreteScheduler,
|
24 |
+
HeunDiscreteScheduler,
|
25 |
+
)
|
26 |
+
from loguru import logger
|
27 |
+
|
28 |
+
from iopaint.schema import SDSampler
|
29 |
+
from torch import conv2d, conv_transpose2d
|
30 |
+
|
31 |
+
|
32 |
+
def make_beta_schedule(
|
33 |
+
device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
34 |
+
):
|
35 |
+
if schedule == "linear":
|
36 |
+
betas = (
|
37 |
+
torch.linspace(
|
38 |
+
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
39 |
+
)
|
40 |
+
** 2
|
41 |
+
)
|
42 |
+
|
43 |
+
elif schedule == "cosine":
|
44 |
+
timesteps = (
|
45 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
46 |
+
).to(device)
|
47 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
48 |
+
alphas = torch.cos(alphas).pow(2).to(device)
|
49 |
+
alphas = alphas / alphas[0]
|
50 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
51 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
52 |
+
|
53 |
+
elif schedule == "sqrt_linear":
|
54 |
+
betas = torch.linspace(
|
55 |
+
linear_start, linear_end, n_timestep, dtype=torch.float64
|
56 |
+
)
|
57 |
+
elif schedule == "sqrt":
|
58 |
+
betas = (
|
59 |
+
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
60 |
+
** 0.5
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
64 |
+
return betas.numpy()
|
65 |
+
|
66 |
+
|
67 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
68 |
+
# select alphas for computing the variance schedule
|
69 |
+
alphas = alphacums[ddim_timesteps]
|
70 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
71 |
+
|
72 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
73 |
+
sigmas = eta * np.sqrt(
|
74 |
+
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
|
75 |
+
)
|
76 |
+
if verbose:
|
77 |
+
print(
|
78 |
+
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
|
79 |
+
)
|
80 |
+
print(
|
81 |
+
f"For the chosen value of eta, which is {eta}, "
|
82 |
+
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
|
83 |
+
)
|
84 |
+
return sigmas, alphas, alphas_prev
|
85 |
+
|
86 |
+
|
87 |
+
def make_ddim_timesteps(
|
88 |
+
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
|
89 |
+
):
|
90 |
+
if ddim_discr_method == "uniform":
|
91 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
92 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
93 |
+
elif ddim_discr_method == "quad":
|
94 |
+
ddim_timesteps = (
|
95 |
+
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
|
96 |
+
).astype(int)
|
97 |
+
else:
|
98 |
+
raise NotImplementedError(
|
99 |
+
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
100 |
+
)
|
101 |
+
|
102 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
103 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
104 |
+
steps_out = ddim_timesteps + 1
|
105 |
+
if verbose:
|
106 |
+
print(f"Selected timesteps for ddim sampler: {steps_out}")
|
107 |
+
return steps_out
|
108 |
+
|
109 |
+
|
110 |
+
def noise_like(shape, device, repeat=False):
|
111 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
|
112 |
+
shape[0], *((1,) * (len(shape) - 1))
|
113 |
+
)
|
114 |
+
noise = lambda: torch.randn(shape, device=device)
|
115 |
+
return repeat_noise() if repeat else noise()
|
116 |
+
|
117 |
+
|
118 |
+
def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
|
119 |
+
"""
|
120 |
+
Create sinusoidal timestep embeddings.
|
121 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
122 |
+
These may be fractional.
|
123 |
+
:param dim: the dimension of the output.
|
124 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
125 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
126 |
+
"""
|
127 |
+
half = dim // 2
|
128 |
+
freqs = torch.exp(
|
129 |
+
-math.log(max_period)
|
130 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
131 |
+
/ half
|
132 |
+
).to(device=device)
|
133 |
+
|
134 |
+
args = timesteps[:, None].float() * freqs[None]
|
135 |
+
|
136 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
137 |
+
if dim % 2:
|
138 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
139 |
+
return embedding
|
140 |
+
|
141 |
+
|
142 |
+
###### MAT and FcF #######
|
143 |
+
|
144 |
+
|
145 |
+
def normalize_2nd_moment(x, dim=1):
|
146 |
+
return (
|
147 |
+
x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt()
|
148 |
+
)
|
149 |
+
|
150 |
+
|
151 |
+
class EasyDict(dict):
|
152 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
153 |
+
|
154 |
+
def __getattr__(self, name: str) -> Any:
|
155 |
+
try:
|
156 |
+
return self[name]
|
157 |
+
except KeyError:
|
158 |
+
raise AttributeError(name)
|
159 |
+
|
160 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
161 |
+
self[name] = value
|
162 |
+
|
163 |
+
def __delattr__(self, name: str) -> None:
|
164 |
+
del self[name]
|
165 |
+
|
166 |
+
|
167 |
+
def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
|
168 |
+
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
|
169 |
+
assert isinstance(x, torch.Tensor)
|
170 |
+
assert clamp is None or clamp >= 0
|
171 |
+
spec = activation_funcs[act]
|
172 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
173 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
174 |
+
clamp = float(clamp if clamp is not None else -1)
|
175 |
+
|
176 |
+
# Add bias.
|
177 |
+
if b is not None:
|
178 |
+
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
179 |
+
assert 0 <= dim < x.ndim
|
180 |
+
assert b.shape[0] == x.shape[dim]
|
181 |
+
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
182 |
+
|
183 |
+
# Evaluate activation function.
|
184 |
+
alpha = float(alpha)
|
185 |
+
x = spec.func(x, alpha=alpha)
|
186 |
+
|
187 |
+
# Scale by gain.
|
188 |
+
gain = float(gain)
|
189 |
+
if gain != 1:
|
190 |
+
x = x * gain
|
191 |
+
|
192 |
+
# Clamp.
|
193 |
+
if clamp >= 0:
|
194 |
+
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
195 |
+
return x
|
196 |
+
|
197 |
+
|
198 |
+
def bias_act(
|
199 |
+
x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
|
200 |
+
):
|
201 |
+
r"""Fused bias and activation function.
|
202 |
+
|
203 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
204 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
205 |
+
the fused op is considerably more efficient than performing the same calculation
|
206 |
+
using standard PyTorch ops. It supports first and second order gradients,
|
207 |
+
but not third order gradients.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
x: Input activation tensor. Can be of any shape.
|
211 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
212 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
213 |
+
corresponding to `dim`.
|
214 |
+
dim: The dimension in `x` corresponding to the elements of `b`.
|
215 |
+
The value of `dim` is ignored if `b` is not specified.
|
216 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
217 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
218 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
219 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
220 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
221 |
+
See `activation_funcs` for the default scaling of each activation function.
|
222 |
+
If unsure, consider specifying 1.
|
223 |
+
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
224 |
+
the clamping (default).
|
225 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
Tensor of the same shape and datatype as `x`.
|
229 |
+
"""
|
230 |
+
assert isinstance(x, torch.Tensor)
|
231 |
+
assert impl in ["ref", "cuda"]
|
232 |
+
return _bias_act_ref(
|
233 |
+
x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
|
234 |
+
)
|
235 |
+
|
236 |
+
|
237 |
+
def _get_filter_size(f):
|
238 |
+
if f is None:
|
239 |
+
return 1, 1
|
240 |
+
|
241 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
242 |
+
fw = f.shape[-1]
|
243 |
+
fh = f.shape[0]
|
244 |
+
|
245 |
+
fw = int(fw)
|
246 |
+
fh = int(fh)
|
247 |
+
assert fw >= 1 and fh >= 1
|
248 |
+
return fw, fh
|
249 |
+
|
250 |
+
|
251 |
+
def _get_weight_shape(w):
|
252 |
+
shape = [int(sz) for sz in w.shape]
|
253 |
+
return shape
|
254 |
+
|
255 |
+
|
256 |
+
def _parse_scaling(scaling):
|
257 |
+
if isinstance(scaling, int):
|
258 |
+
scaling = [scaling, scaling]
|
259 |
+
assert isinstance(scaling, (list, tuple))
|
260 |
+
assert all(isinstance(x, int) for x in scaling)
|
261 |
+
sx, sy = scaling
|
262 |
+
assert sx >= 1 and sy >= 1
|
263 |
+
return sx, sy
|
264 |
+
|
265 |
+
|
266 |
+
def _parse_padding(padding):
|
267 |
+
if isinstance(padding, int):
|
268 |
+
padding = [padding, padding]
|
269 |
+
assert isinstance(padding, (list, tuple))
|
270 |
+
assert all(isinstance(x, int) for x in padding)
|
271 |
+
if len(padding) == 2:
|
272 |
+
padx, pady = padding
|
273 |
+
padding = [padx, padx, pady, pady]
|
274 |
+
padx0, padx1, pady0, pady1 = padding
|
275 |
+
return padx0, padx1, pady0, pady1
|
276 |
+
|
277 |
+
|
278 |
+
def setup_filter(
|
279 |
+
f,
|
280 |
+
device=torch.device("cpu"),
|
281 |
+
normalize=True,
|
282 |
+
flip_filter=False,
|
283 |
+
gain=1,
|
284 |
+
separable=None,
|
285 |
+
):
|
286 |
+
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
f: Torch tensor, numpy array, or python list of the shape
|
290 |
+
`[filter_height, filter_width]` (non-separable),
|
291 |
+
`[filter_taps]` (separable),
|
292 |
+
`[]` (impulse), or
|
293 |
+
`None` (identity).
|
294 |
+
device: Result device (default: cpu).
|
295 |
+
normalize: Normalize the filter so that it retains the magnitude
|
296 |
+
for constant input signal (DC)? (default: True).
|
297 |
+
flip_filter: Flip the filter? (default: False).
|
298 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
299 |
+
separable: Return a separable filter? (default: select automatically).
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
Float32 tensor of the shape
|
303 |
+
`[filter_height, filter_width]` (non-separable) or
|
304 |
+
`[filter_taps]` (separable).
|
305 |
+
"""
|
306 |
+
# Validate.
|
307 |
+
if f is None:
|
308 |
+
f = 1
|
309 |
+
f = torch.as_tensor(f, dtype=torch.float32)
|
310 |
+
assert f.ndim in [0, 1, 2]
|
311 |
+
assert f.numel() > 0
|
312 |
+
if f.ndim == 0:
|
313 |
+
f = f[np.newaxis]
|
314 |
+
|
315 |
+
# Separable?
|
316 |
+
if separable is None:
|
317 |
+
separable = f.ndim == 1 and f.numel() >= 8
|
318 |
+
if f.ndim == 1 and not separable:
|
319 |
+
f = f.ger(f)
|
320 |
+
assert f.ndim == (1 if separable else 2)
|
321 |
+
|
322 |
+
# Apply normalize, flip, gain, and device.
|
323 |
+
if normalize:
|
324 |
+
f /= f.sum()
|
325 |
+
if flip_filter:
|
326 |
+
f = f.flip(list(range(f.ndim)))
|
327 |
+
f = f * (gain ** (f.ndim / 2))
|
328 |
+
f = f.to(device=device)
|
329 |
+
return f
|
330 |
+
|
331 |
+
|
332 |
+
def _ntuple(n):
|
333 |
+
def parse(x):
|
334 |
+
if isinstance(x, collections.abc.Iterable):
|
335 |
+
return x
|
336 |
+
return tuple(repeat(x, n))
|
337 |
+
|
338 |
+
return parse
|
339 |
+
|
340 |
+
|
341 |
+
to_2tuple = _ntuple(2)
|
342 |
+
|
343 |
+
activation_funcs = {
|
344 |
+
"linear": EasyDict(
|
345 |
+
func=lambda x, **_: x,
|
346 |
+
def_alpha=0,
|
347 |
+
def_gain=1,
|
348 |
+
cuda_idx=1,
|
349 |
+
ref="",
|
350 |
+
has_2nd_grad=False,
|
351 |
+
),
|
352 |
+
"relu": EasyDict(
|
353 |
+
func=lambda x, **_: torch.nn.functional.relu(x),
|
354 |
+
def_alpha=0,
|
355 |
+
def_gain=np.sqrt(2),
|
356 |
+
cuda_idx=2,
|
357 |
+
ref="y",
|
358 |
+
has_2nd_grad=False,
|
359 |
+
),
|
360 |
+
"lrelu": EasyDict(
|
361 |
+
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
|
362 |
+
def_alpha=0.2,
|
363 |
+
def_gain=np.sqrt(2),
|
364 |
+
cuda_idx=3,
|
365 |
+
ref="y",
|
366 |
+
has_2nd_grad=False,
|
367 |
+
),
|
368 |
+
"tanh": EasyDict(
|
369 |
+
func=lambda x, **_: torch.tanh(x),
|
370 |
+
def_alpha=0,
|
371 |
+
def_gain=1,
|
372 |
+
cuda_idx=4,
|
373 |
+
ref="y",
|
374 |
+
has_2nd_grad=True,
|
375 |
+
),
|
376 |
+
"sigmoid": EasyDict(
|
377 |
+
func=lambda x, **_: torch.sigmoid(x),
|
378 |
+
def_alpha=0,
|
379 |
+
def_gain=1,
|
380 |
+
cuda_idx=5,
|
381 |
+
ref="y",
|
382 |
+
has_2nd_grad=True,
|
383 |
+
),
|
384 |
+
"elu": EasyDict(
|
385 |
+
func=lambda x, **_: torch.nn.functional.elu(x),
|
386 |
+
def_alpha=0,
|
387 |
+
def_gain=1,
|
388 |
+
cuda_idx=6,
|
389 |
+
ref="y",
|
390 |
+
has_2nd_grad=True,
|
391 |
+
),
|
392 |
+
"selu": EasyDict(
|
393 |
+
func=lambda x, **_: torch.nn.functional.selu(x),
|
394 |
+
def_alpha=0,
|
395 |
+
def_gain=1,
|
396 |
+
cuda_idx=7,
|
397 |
+
ref="y",
|
398 |
+
has_2nd_grad=True,
|
399 |
+
),
|
400 |
+
"softplus": EasyDict(
|
401 |
+
func=lambda x, **_: torch.nn.functional.softplus(x),
|
402 |
+
def_alpha=0,
|
403 |
+
def_gain=1,
|
404 |
+
cuda_idx=8,
|
405 |
+
ref="y",
|
406 |
+
has_2nd_grad=True,
|
407 |
+
),
|
408 |
+
"swish": EasyDict(
|
409 |
+
func=lambda x, **_: torch.sigmoid(x) * x,
|
410 |
+
def_alpha=0,
|
411 |
+
def_gain=np.sqrt(2),
|
412 |
+
cuda_idx=9,
|
413 |
+
ref="x",
|
414 |
+
has_2nd_grad=True,
|
415 |
+
),
|
416 |
+
}
|
417 |
+
|
418 |
+
|
419 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
420 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
421 |
+
|
422 |
+
Performs the following sequence of operations for each channel:
|
423 |
+
|
424 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
425 |
+
|
426 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
427 |
+
Negative padding corresponds to cropping the image.
|
428 |
+
|
429 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
430 |
+
so that the footprint of all output pixels lies within the input image.
|
431 |
+
|
432 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
433 |
+
|
434 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
435 |
+
The fused op is considerably more efficient than performing the same calculation
|
436 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
x: Float32/float64/float16 input tensor of the shape
|
440 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
441 |
+
f: Float32 FIR filter of the shape
|
442 |
+
`[filter_height, filter_width]` (non-separable),
|
443 |
+
`[filter_taps]` (separable), or
|
444 |
+
`None` (identity).
|
445 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
446 |
+
`[x, y]` (default: 1).
|
447 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
448 |
+
`[x, y]` (default: 1).
|
449 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
450 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
451 |
+
(default: 0).
|
452 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
453 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
454 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
455 |
+
|
456 |
+
Returns:
|
457 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
458 |
+
"""
|
459 |
+
# assert isinstance(x, torch.Tensor)
|
460 |
+
# assert impl in ['ref', 'cuda']
|
461 |
+
return _upfirdn2d_ref(
|
462 |
+
x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
|
463 |
+
)
|
464 |
+
|
465 |
+
|
466 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
467 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
|
468 |
+
# Validate arguments.
|
469 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
470 |
+
if f is None:
|
471 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
472 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
473 |
+
assert not f.requires_grad
|
474 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
475 |
+
# upx, upy = _parse_scaling(up)
|
476 |
+
# downx, downy = _parse_scaling(down)
|
477 |
+
|
478 |
+
upx, upy = up, up
|
479 |
+
downx, downy = down, down
|
480 |
+
|
481 |
+
# padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
482 |
+
padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
|
483 |
+
|
484 |
+
# Upsample by inserting zeros.
|
485 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
486 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
487 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
488 |
+
|
489 |
+
# Pad or crop.
|
490 |
+
x = torch.nn.functional.pad(
|
491 |
+
x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
|
492 |
+
)
|
493 |
+
x = x[
|
494 |
+
:,
|
495 |
+
:,
|
496 |
+
max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
|
497 |
+
max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
|
498 |
+
]
|
499 |
+
|
500 |
+
# Setup filter.
|
501 |
+
f = f * (gain ** (f.ndim / 2))
|
502 |
+
f = f.to(x.dtype)
|
503 |
+
if not flip_filter:
|
504 |
+
f = f.flip(list(range(f.ndim)))
|
505 |
+
|
506 |
+
# Convolve with the filter.
|
507 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
508 |
+
if f.ndim == 4:
|
509 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
510 |
+
else:
|
511 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
512 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
513 |
+
|
514 |
+
# Downsample by throwing away pixels.
|
515 |
+
x = x[:, :, ::downy, ::downx]
|
516 |
+
return x
|
517 |
+
|
518 |
+
|
519 |
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
520 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
521 |
+
|
522 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
523 |
+
User-specified padding is applied on top of that, with negative values
|
524 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
525 |
+
|
526 |
+
Args:
|
527 |
+
x: Float32/float64/float16 input tensor of the shape
|
528 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
529 |
+
f: Float32 FIR filter of the shape
|
530 |
+
`[filter_height, filter_width]` (non-separable),
|
531 |
+
`[filter_taps]` (separable), or
|
532 |
+
`None` (identity).
|
533 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
534 |
+
`[x, y]` (default: 1).
|
535 |
+
padding: Padding with respect to the input. Can be a single number or a
|
536 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
537 |
+
(default: 0).
|
538 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
539 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
540 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
541 |
+
|
542 |
+
Returns:
|
543 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
544 |
+
"""
|
545 |
+
downx, downy = _parse_scaling(down)
|
546 |
+
# padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
547 |
+
padx0, padx1, pady0, pady1 = padding, padding, padding, padding
|
548 |
+
|
549 |
+
fw, fh = _get_filter_size(f)
|
550 |
+
p = [
|
551 |
+
padx0 + (fw - downx + 1) // 2,
|
552 |
+
padx1 + (fw - downx) // 2,
|
553 |
+
pady0 + (fh - downy + 1) // 2,
|
554 |
+
pady1 + (fh - downy) // 2,
|
555 |
+
]
|
556 |
+
return upfirdn2d(
|
557 |
+
x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
|
558 |
+
)
|
559 |
+
|
560 |
+
|
561 |
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
562 |
+
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
563 |
+
|
564 |
+
By default, the result is padded so that its shape is a multiple of the input.
|
565 |
+
User-specified padding is applied on top of that, with negative values
|
566 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
567 |
+
|
568 |
+
Args:
|
569 |
+
x: Float32/float64/float16 input tensor of the shape
|
570 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
571 |
+
f: Float32 FIR filter of the shape
|
572 |
+
`[filter_height, filter_width]` (non-separable),
|
573 |
+
`[filter_taps]` (separable), or
|
574 |
+
`None` (identity).
|
575 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
576 |
+
`[x, y]` (default: 1).
|
577 |
+
padding: Padding with respect to the output. Can be a single number or a
|
578 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
579 |
+
(default: 0).
|
580 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
581 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
582 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
583 |
+
|
584 |
+
Returns:
|
585 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
586 |
+
"""
|
587 |
+
upx, upy = _parse_scaling(up)
|
588 |
+
# upx, upy = up, up
|
589 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
590 |
+
# padx0, padx1, pady0, pady1 = padding, padding, padding, padding
|
591 |
+
fw, fh = _get_filter_size(f)
|
592 |
+
p = [
|
593 |
+
padx0 + (fw + upx - 1) // 2,
|
594 |
+
padx1 + (fw - upx) // 2,
|
595 |
+
pady0 + (fh + upy - 1) // 2,
|
596 |
+
pady1 + (fh - upy) // 2,
|
597 |
+
]
|
598 |
+
return upfirdn2d(
|
599 |
+
x,
|
600 |
+
f,
|
601 |
+
up=up,
|
602 |
+
padding=p,
|
603 |
+
flip_filter=flip_filter,
|
604 |
+
gain=gain * upx * upy,
|
605 |
+
impl=impl,
|
606 |
+
)
|
607 |
+
|
608 |
+
|
609 |
+
class MinibatchStdLayer(torch.nn.Module):
|
610 |
+
def __init__(self, group_size, num_channels=1):
|
611 |
+
super().__init__()
|
612 |
+
self.group_size = group_size
|
613 |
+
self.num_channels = num_channels
|
614 |
+
|
615 |
+
def forward(self, x):
|
616 |
+
N, C, H, W = x.shape
|
617 |
+
G = (
|
618 |
+
torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
|
619 |
+
if self.group_size is not None
|
620 |
+
else N
|
621 |
+
)
|
622 |
+
F = self.num_channels
|
623 |
+
c = C // F
|
624 |
+
|
625 |
+
y = x.reshape(
|
626 |
+
G, -1, F, c, H, W
|
627 |
+
) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
|
628 |
+
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
|
629 |
+
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
|
630 |
+
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
|
631 |
+
y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
|
632 |
+
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
|
633 |
+
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
|
634 |
+
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
|
635 |
+
return x
|
636 |
+
|
637 |
+
|
638 |
+
class FullyConnectedLayer(torch.nn.Module):
|
639 |
+
def __init__(
|
640 |
+
self,
|
641 |
+
in_features, # Number of input features.
|
642 |
+
out_features, # Number of output features.
|
643 |
+
bias=True, # Apply additive bias before the activation function?
|
644 |
+
activation="linear", # Activation function: 'relu', 'lrelu', etc.
|
645 |
+
lr_multiplier=1, # Learning rate multiplier.
|
646 |
+
bias_init=0, # Initial value for the additive bias.
|
647 |
+
):
|
648 |
+
super().__init__()
|
649 |
+
self.weight = torch.nn.Parameter(
|
650 |
+
torch.randn([out_features, in_features]) / lr_multiplier
|
651 |
+
)
|
652 |
+
self.bias = (
|
653 |
+
torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
|
654 |
+
if bias
|
655 |
+
else None
|
656 |
+
)
|
657 |
+
self.activation = activation
|
658 |
+
|
659 |
+
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
660 |
+
self.bias_gain = lr_multiplier
|
661 |
+
|
662 |
+
def forward(self, x):
|
663 |
+
w = self.weight * self.weight_gain
|
664 |
+
b = self.bias
|
665 |
+
if b is not None and self.bias_gain != 1:
|
666 |
+
b = b * self.bias_gain
|
667 |
+
|
668 |
+
if self.activation == "linear" and b is not None:
|
669 |
+
# out = torch.addmm(b.unsqueeze(0), x, w.t())
|
670 |
+
x = x.matmul(w.t())
|
671 |
+
out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
|
672 |
+
else:
|
673 |
+
x = x.matmul(w.t())
|
674 |
+
out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
|
675 |
+
return out
|
676 |
+
|
677 |
+
|
678 |
+
def _conv2d_wrapper(
|
679 |
+
x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
|
680 |
+
):
|
681 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
|
682 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
683 |
+
|
684 |
+
# Flip weight if requested.
|
685 |
+
if (
|
686 |
+
not flip_weight
|
687 |
+
): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
688 |
+
w = w.flip([2, 3])
|
689 |
+
|
690 |
+
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
691 |
+
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
692 |
+
if (
|
693 |
+
kw == 1
|
694 |
+
and kh == 1
|
695 |
+
and stride == 1
|
696 |
+
and padding in [0, [0, 0], (0, 0)]
|
697 |
+
and not transpose
|
698 |
+
):
|
699 |
+
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
700 |
+
if out_channels <= 4 and groups == 1:
|
701 |
+
in_shape = x.shape
|
702 |
+
x = w.squeeze(3).squeeze(2) @ x.reshape(
|
703 |
+
[in_shape[0], in_channels_per_group, -1]
|
704 |
+
)
|
705 |
+
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
706 |
+
else:
|
707 |
+
x = x.to(memory_format=torch.contiguous_format)
|
708 |
+
w = w.to(memory_format=torch.contiguous_format)
|
709 |
+
x = conv2d(x, w, groups=groups)
|
710 |
+
return x.to(memory_format=torch.channels_last)
|
711 |
+
|
712 |
+
# Otherwise => execute using conv2d_gradfix.
|
713 |
+
op = conv_transpose2d if transpose else conv2d
|
714 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
715 |
+
|
716 |
+
|
717 |
+
def conv2d_resample(
|
718 |
+
x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
|
719 |
+
):
|
720 |
+
r"""2D convolution with optional up/downsampling.
|
721 |
+
|
722 |
+
Padding is performed only once at the beginning, not between the operations.
|
723 |
+
|
724 |
+
Args:
|
725 |
+
x: Input tensor of shape
|
726 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
727 |
+
w: Weight tensor of shape
|
728 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
729 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
730 |
+
calling setup_filter(). None = identity (default).
|
731 |
+
up: Integer upsampling factor (default: 1).
|
732 |
+
down: Integer downsampling factor (default: 1).
|
733 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
734 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
735 |
+
(default: 0).
|
736 |
+
groups: Split input channels into N groups (default: 1).
|
737 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
738 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
739 |
+
|
740 |
+
Returns:
|
741 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
742 |
+
"""
|
743 |
+
# Validate arguments.
|
744 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
745 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
746 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2])
|
747 |
+
assert isinstance(up, int) and (up >= 1)
|
748 |
+
assert isinstance(down, int) and (down >= 1)
|
749 |
+
# assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
|
750 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
751 |
+
fw, fh = _get_filter_size(f)
|
752 |
+
# px0, px1, py0, py1 = _parse_padding(padding)
|
753 |
+
px0, px1, py0, py1 = padding, padding, padding, padding
|
754 |
+
|
755 |
+
# Adjust padding to account for up/downsampling.
|
756 |
+
if up > 1:
|
757 |
+
px0 += (fw + up - 1) // 2
|
758 |
+
px1 += (fw - up) // 2
|
759 |
+
py0 += (fh + up - 1) // 2
|
760 |
+
py1 += (fh - up) // 2
|
761 |
+
if down > 1:
|
762 |
+
px0 += (fw - down + 1) // 2
|
763 |
+
px1 += (fw - down) // 2
|
764 |
+
py0 += (fh - down + 1) // 2
|
765 |
+
py1 += (fh - down) // 2
|
766 |
+
|
767 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
768 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
769 |
+
x = upfirdn2d(
|
770 |
+
x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
|
771 |
+
)
|
772 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
773 |
+
return x
|
774 |
+
|
775 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
776 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
777 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
778 |
+
x = upfirdn2d(
|
779 |
+
x=x,
|
780 |
+
f=f,
|
781 |
+
up=up,
|
782 |
+
padding=[px0, px1, py0, py1],
|
783 |
+
gain=up**2,
|
784 |
+
flip_filter=flip_filter,
|
785 |
+
)
|
786 |
+
return x
|
787 |
+
|
788 |
+
# Fast path: downsampling only => use strided convolution.
|
789 |
+
if down > 1 and up == 1:
|
790 |
+
x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
|
791 |
+
x = _conv2d_wrapper(
|
792 |
+
x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
|
793 |
+
)
|
794 |
+
return x
|
795 |
+
|
796 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
797 |
+
if up > 1:
|
798 |
+
if groups == 1:
|
799 |
+
w = w.transpose(0, 1)
|
800 |
+
else:
|
801 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
802 |
+
w = w.transpose(1, 2)
|
803 |
+
w = w.reshape(
|
804 |
+
groups * in_channels_per_group, out_channels // groups, kh, kw
|
805 |
+
)
|
806 |
+
px0 -= kw - 1
|
807 |
+
px1 -= kw - up
|
808 |
+
py0 -= kh - 1
|
809 |
+
py1 -= kh - up
|
810 |
+
pxt = max(min(-px0, -px1), 0)
|
811 |
+
pyt = max(min(-py0, -py1), 0)
|
812 |
+
x = _conv2d_wrapper(
|
813 |
+
x=x,
|
814 |
+
w=w,
|
815 |
+
stride=up,
|
816 |
+
padding=[pyt, pxt],
|
817 |
+
groups=groups,
|
818 |
+
transpose=True,
|
819 |
+
flip_weight=(not flip_weight),
|
820 |
+
)
|
821 |
+
x = upfirdn2d(
|
822 |
+
x=x,
|
823 |
+
f=f,
|
824 |
+
padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
|
825 |
+
gain=up**2,
|
826 |
+
flip_filter=flip_filter,
|
827 |
+
)
|
828 |
+
if down > 1:
|
829 |
+
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
830 |
+
return x
|
831 |
+
|
832 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
833 |
+
if up == 1 and down == 1:
|
834 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
835 |
+
return _conv2d_wrapper(
|
836 |
+
x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
|
837 |
+
)
|
838 |
+
|
839 |
+
# Fallback: Generic reference implementation.
|
840 |
+
x = upfirdn2d(
|
841 |
+
x=x,
|
842 |
+
f=(f if up > 1 else None),
|
843 |
+
up=up,
|
844 |
+
padding=[px0, px1, py0, py1],
|
845 |
+
gain=up**2,
|
846 |
+
flip_filter=flip_filter,
|
847 |
+
)
|
848 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
849 |
+
if down > 1:
|
850 |
+
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
851 |
+
return x
|
852 |
+
|
853 |
+
|
854 |
+
class Conv2dLayer(torch.nn.Module):
|
855 |
+
def __init__(
|
856 |
+
self,
|
857 |
+
in_channels, # Number of input channels.
|
858 |
+
out_channels, # Number of output channels.
|
859 |
+
kernel_size, # Width and height of the convolution kernel.
|
860 |
+
bias=True, # Apply additive bias before the activation function?
|
861 |
+
activation="linear", # Activation function: 'relu', 'lrelu', etc.
|
862 |
+
up=1, # Integer upsampling factor.
|
863 |
+
down=1, # Integer downsampling factor.
|
864 |
+
resample_filter=[
|
865 |
+
1,
|
866 |
+
3,
|
867 |
+
3,
|
868 |
+
1,
|
869 |
+
], # Low-pass filter to apply when resampling activations.
|
870 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
871 |
+
channels_last=False, # Expect the input to have memory_format=channels_last?
|
872 |
+
trainable=True, # Update the weights of this layer during training?
|
873 |
+
):
|
874 |
+
super().__init__()
|
875 |
+
self.activation = activation
|
876 |
+
self.up = up
|
877 |
+
self.down = down
|
878 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
879 |
+
self.conv_clamp = conv_clamp
|
880 |
+
self.padding = kernel_size // 2
|
881 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
|
882 |
+
self.act_gain = activation_funcs[activation].def_gain
|
883 |
+
|
884 |
+
memory_format = (
|
885 |
+
torch.channels_last if channels_last else torch.contiguous_format
|
886 |
+
)
|
887 |
+
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
|
888 |
+
memory_format=memory_format
|
889 |
+
)
|
890 |
+
bias = torch.zeros([out_channels]) if bias else None
|
891 |
+
if trainable:
|
892 |
+
self.weight = torch.nn.Parameter(weight)
|
893 |
+
self.bias = torch.nn.Parameter(bias) if bias is not None else None
|
894 |
+
else:
|
895 |
+
self.register_buffer("weight", weight)
|
896 |
+
if bias is not None:
|
897 |
+
self.register_buffer("bias", bias)
|
898 |
+
else:
|
899 |
+
self.bias = None
|
900 |
+
|
901 |
+
def forward(self, x, gain=1):
|
902 |
+
w = self.weight * self.weight_gain
|
903 |
+
x = conv2d_resample(
|
904 |
+
x=x,
|
905 |
+
w=w,
|
906 |
+
f=self.resample_filter,
|
907 |
+
up=self.up,
|
908 |
+
down=self.down,
|
909 |
+
padding=self.padding,
|
910 |
+
)
|
911 |
+
|
912 |
+
act_gain = self.act_gain * gain
|
913 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
914 |
+
out = bias_act(
|
915 |
+
x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
|
916 |
+
)
|
917 |
+
return out
|
918 |
+
|
919 |
+
|
920 |
+
def torch_gc():
|
921 |
+
if torch.cuda.is_available():
|
922 |
+
torch.cuda.empty_cache()
|
923 |
+
torch.cuda.ipc_collect()
|
924 |
+
gc.collect()
|
925 |
+
|
926 |
+
|
927 |
+
def set_seed(seed: int):
|
928 |
+
random.seed(seed)
|
929 |
+
np.random.seed(seed)
|
930 |
+
torch.manual_seed(seed)
|
931 |
+
torch.cuda.manual_seed_all(seed)
|
932 |
+
|
933 |
+
|
934 |
+
def get_scheduler(sd_sampler, scheduler_config):
|
935 |
+
# https://github.com/huggingface/diffusers/issues/4167
|
936 |
+
keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
|
937 |
+
scheduler_config = dict(scheduler_config)
|
938 |
+
for it in keys_to_pop:
|
939 |
+
scheduler_config.pop(it, None)
|
940 |
+
|
941 |
+
# fmt: off
|
942 |
+
samplers = {
|
943 |
+
SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
|
944 |
+
SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
|
945 |
+
SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
|
946 |
+
SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
|
947 |
+
SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
|
948 |
+
SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
|
949 |
+
SDSampler.dpm2: [KDPM2DiscreteScheduler],
|
950 |
+
SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
|
951 |
+
SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
|
952 |
+
SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
|
953 |
+
SDSampler.euler: [EulerDiscreteScheduler],
|
954 |
+
SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
|
955 |
+
SDSampler.heun: [HeunDiscreteScheduler],
|
956 |
+
SDSampler.lms: [LMSDiscreteScheduler],
|
957 |
+
SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
|
958 |
+
SDSampler.ddim: [DDIMScheduler],
|
959 |
+
SDSampler.pndm: [PNDMScheduler],
|
960 |
+
SDSampler.uni_pc: [UniPCMultistepScheduler],
|
961 |
+
SDSampler.lcm: [LCMScheduler],
|
962 |
+
}
|
963 |
+
# fmt: on
|
964 |
+
if sd_sampler in samplers:
|
965 |
+
if len(samplers[sd_sampler]) == 2:
|
966 |
+
scheduler_cls, kwargs = samplers[sd_sampler]
|
967 |
+
else:
|
968 |
+
scheduler_cls, kwargs = samplers[sd_sampler][0], {}
|
969 |
+
return scheduler_cls.from_config(scheduler_config, **kwargs)
|
970 |
+
else:
|
971 |
+
raise ValueError(sd_sampler)
|
972 |
+
|
973 |
+
|
974 |
+
def is_local_files_only(**kwargs) -> bool:
|
975 |
+
from huggingface_hub.constants import HF_HUB_OFFLINE
|
976 |
+
|
977 |
+
return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
|
978 |
+
|
979 |
+
|
980 |
+
def handle_from_pretrained_exceptions(func, **kwargs):
|
981 |
+
try:
|
982 |
+
return func(**kwargs)
|
983 |
+
except ValueError as e:
|
984 |
+
if "You are trying to load the model files of the `variant=fp16`" in str(e):
|
985 |
+
logger.info("variant=fp16 not found, try revision=fp16")
|
986 |
+
try:
|
987 |
+
return func(**{**kwargs, "variant": None, "revision": "fp16"})
|
988 |
+
except Exception as e:
|
989 |
+
logger.info("revision=fp16 not found, try revision=main")
|
990 |
+
return func(**{**kwargs, "variant": None, "revision": "main"})
|
991 |
+
raise e
|
992 |
+
except OSError as e:
|
993 |
+
previous_traceback = traceback.format_exc()
|
994 |
+
if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
|
995 |
+
logger.info("revision=fp16 not found, try revision=main")
|
996 |
+
return func(**{**kwargs, "variant": None, "revision": "main"})
|
997 |
+
elif "Max retries exceeded" in previous_traceback:
|
998 |
+
logger.exception(
|
999 |
+
"Fetching model from HuggingFace failed. "
|
1000 |
+
"If this is your first time downloading the model, you may need to set up proxy in terminal."
|
1001 |
+
"If the model has already been downloaded, you can add --local-files-only when starting."
|
1002 |
+
)
|
1003 |
+
exit(-1)
|
1004 |
+
raise e
|
1005 |
+
except Exception as e:
|
1006 |
+
raise e
|
1007 |
+
|
1008 |
+
|
1009 |
+
def get_torch_dtype(device, no_half: bool):
|
1010 |
+
device = str(device)
|
1011 |
+
use_fp16 = not no_half
|
1012 |
+
use_gpu = device == "cuda"
|
1013 |
+
# https://github.com/huggingface/diffusers/issues/4480
|
1014 |
+
# pipe.enable_attention_slicing and float16 will cause black output on mps
|
1015 |
+
# if device in ["cuda", "mps"] and use_fp16:
|
1016 |
+
if device in ["cuda"] and use_fp16:
|
1017 |
+
return use_gpu, torch.float16
|
1018 |
+
return use_gpu, torch.float32
|
1019 |
+
|
1020 |
+
|
1021 |
+
def enable_low_mem(pipe, enable: bool):
|
1022 |
+
if torch.backends.mps.is_available():
|
1023 |
+
# https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
|
1024 |
+
# CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
|
1025 |
+
if enable:
|
1026 |
+
pipe.enable_attention_slicing("max")
|
1027 |
+
else:
|
1028 |
+
# https://huggingface.co/docs/diffusers/optimization/mps
|
1029 |
+
# Devices with less than 64GB of memory are recommended to use enable_attention_slicing
|
1030 |
+
pipe.enable_attention_slicing()
|
1031 |
+
|
1032 |
+
if enable:
|
1033 |
+
pipe.vae.enable_tiling()
|
iopaint/model/zits.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model
|
9 |
+
from iopaint.schema import InpaintRequest
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from .base import InpaintModel
|
13 |
+
|
14 |
+
ZITS_INPAINT_MODEL_URL = os.environ.get(
|
15 |
+
"ZITS_INPAINT_MODEL_URL",
|
16 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
|
17 |
+
)
|
18 |
+
ZITS_INPAINT_MODEL_MD5 = os.environ.get(
|
19 |
+
"ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154"
|
20 |
+
)
|
21 |
+
|
22 |
+
ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
|
23 |
+
"ZITS_EDGE_LINE_MODEL_URL",
|
24 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
|
25 |
+
)
|
26 |
+
ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get(
|
27 |
+
"ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f"
|
28 |
+
)
|
29 |
+
|
30 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
|
31 |
+
"ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
|
32 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
|
33 |
+
)
|
34 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get(
|
35 |
+
"ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927"
|
36 |
+
)
|
37 |
+
|
38 |
+
ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
|
39 |
+
"ZITS_WIRE_FRAME_MODEL_URL",
|
40 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
|
41 |
+
)
|
42 |
+
ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get(
|
43 |
+
"ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b"
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def resize(img, height, width, center_crop=False):
|
48 |
+
imgh, imgw = img.shape[0:2]
|
49 |
+
|
50 |
+
if center_crop and imgh != imgw:
|
51 |
+
# center crop
|
52 |
+
side = np.minimum(imgh, imgw)
|
53 |
+
j = (imgh - side) // 2
|
54 |
+
i = (imgw - side) // 2
|
55 |
+
img = img[j : j + side, i : i + side, ...]
|
56 |
+
|
57 |
+
if imgh > height and imgw > width:
|
58 |
+
inter = cv2.INTER_AREA
|
59 |
+
else:
|
60 |
+
inter = cv2.INTER_LINEAR
|
61 |
+
img = cv2.resize(img, (height, width), interpolation=inter)
|
62 |
+
|
63 |
+
return img
|
64 |
+
|
65 |
+
|
66 |
+
def to_tensor(img, scale=True, norm=False):
|
67 |
+
if img.ndim == 2:
|
68 |
+
img = img[:, :, np.newaxis]
|
69 |
+
c = img.shape[-1]
|
70 |
+
|
71 |
+
if scale:
|
72 |
+
img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255)
|
73 |
+
else:
|
74 |
+
img_t = torch.from_numpy(img).permute(2, 0, 1).float()
|
75 |
+
|
76 |
+
if norm:
|
77 |
+
mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
|
78 |
+
std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
|
79 |
+
img_t = (img_t - mean) / std
|
80 |
+
return img_t
|
81 |
+
|
82 |
+
|
83 |
+
def load_masked_position_encoding(mask):
|
84 |
+
ones_filter = np.ones((3, 3), dtype=np.float32)
|
85 |
+
d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32)
|
86 |
+
d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32)
|
87 |
+
d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32)
|
88 |
+
d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32)
|
89 |
+
str_size = 256
|
90 |
+
pos_num = 128
|
91 |
+
|
92 |
+
ori_mask = mask.copy()
|
93 |
+
ori_h, ori_w = ori_mask.shape[0:2]
|
94 |
+
ori_mask = ori_mask / 255
|
95 |
+
mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
|
96 |
+
mask[mask > 0] = 255
|
97 |
+
h, w = mask.shape[0:2]
|
98 |
+
mask3 = mask.copy()
|
99 |
+
mask3 = 1.0 - (mask3 / 255.0)
|
100 |
+
pos = np.zeros((h, w), dtype=np.int32)
|
101 |
+
direct = np.zeros((h, w, 4), dtype=np.int32)
|
102 |
+
i = 0
|
103 |
+
while np.sum(1 - mask3) > 0:
|
104 |
+
i += 1
|
105 |
+
mask3_ = cv2.filter2D(mask3, -1, ones_filter)
|
106 |
+
mask3_[mask3_ > 0] = 1
|
107 |
+
sub_mask = mask3_ - mask3
|
108 |
+
pos[sub_mask == 1] = i
|
109 |
+
|
110 |
+
m = cv2.filter2D(mask3, -1, d_filter1)
|
111 |
+
m[m > 0] = 1
|
112 |
+
m = m - mask3
|
113 |
+
direct[m == 1, 0] = 1
|
114 |
+
|
115 |
+
m = cv2.filter2D(mask3, -1, d_filter2)
|
116 |
+
m[m > 0] = 1
|
117 |
+
m = m - mask3
|
118 |
+
direct[m == 1, 1] = 1
|
119 |
+
|
120 |
+
m = cv2.filter2D(mask3, -1, d_filter3)
|
121 |
+
m[m > 0] = 1
|
122 |
+
m = m - mask3
|
123 |
+
direct[m == 1, 2] = 1
|
124 |
+
|
125 |
+
m = cv2.filter2D(mask3, -1, d_filter4)
|
126 |
+
m[m > 0] = 1
|
127 |
+
m = m - mask3
|
128 |
+
direct[m == 1, 3] = 1
|
129 |
+
|
130 |
+
mask3 = mask3_
|
131 |
+
|
132 |
+
abs_pos = pos.copy()
|
133 |
+
rel_pos = pos / (str_size / 2) # to 0~1 maybe larger than 1
|
134 |
+
rel_pos = (rel_pos * pos_num).astype(np.int32)
|
135 |
+
rel_pos = np.clip(rel_pos, 0, pos_num - 1)
|
136 |
+
|
137 |
+
if ori_w != w or ori_h != h:
|
138 |
+
rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
|
139 |
+
rel_pos[ori_mask == 0] = 0
|
140 |
+
direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
|
141 |
+
direct[ori_mask == 0, :] = 0
|
142 |
+
|
143 |
+
return rel_pos, abs_pos, direct
|
144 |
+
|
145 |
+
|
146 |
+
def load_image(img, mask, device, sigma256=3.0):
|
147 |
+
"""
|
148 |
+
Args:
|
149 |
+
img: [H, W, C] RGB
|
150 |
+
mask: [H, W] 255 为 masks 区域
|
151 |
+
sigma256:
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
|
155 |
+
"""
|
156 |
+
h, w, _ = img.shape
|
157 |
+
imgh, imgw = img.shape[0:2]
|
158 |
+
img_256 = resize(img, 256, 256)
|
159 |
+
|
160 |
+
mask = (mask > 127).astype(np.uint8) * 255
|
161 |
+
mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
|
162 |
+
mask_256[mask_256 > 0] = 255
|
163 |
+
|
164 |
+
mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
|
165 |
+
mask_512[mask_512 > 0] = 255
|
166 |
+
|
167 |
+
# original skimage implemention
|
168 |
+
# https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
|
169 |
+
# low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max.
|
170 |
+
# high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max.
|
171 |
+
|
172 |
+
try:
|
173 |
+
import skimage
|
174 |
+
|
175 |
+
gray_256 = skimage.color.rgb2gray(img_256)
|
176 |
+
edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float)
|
177 |
+
# cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8))
|
178 |
+
# cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8))
|
179 |
+
except:
|
180 |
+
gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
|
181 |
+
gray_256_blured = cv2.GaussianBlur(
|
182 |
+
gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256
|
183 |
+
)
|
184 |
+
edge_256 = cv2.Canny(
|
185 |
+
gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2)
|
186 |
+
)
|
187 |
+
|
188 |
+
# cv2.imwrite("opencv_edge.jpg", edge_256)
|
189 |
+
|
190 |
+
# line
|
191 |
+
img_512 = resize(img, 512, 512)
|
192 |
+
|
193 |
+
rel_pos, abs_pos, direct = load_masked_position_encoding(mask)
|
194 |
+
|
195 |
+
batch = dict()
|
196 |
+
batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device)
|
197 |
+
batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device)
|
198 |
+
batch["masks"] = to_tensor(mask).unsqueeze(0).to(device)
|
199 |
+
batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device)
|
200 |
+
batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device)
|
201 |
+
batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device)
|
202 |
+
batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device)
|
203 |
+
batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device)
|
204 |
+
batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device)
|
205 |
+
batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device)
|
206 |
+
batch["h"] = imgh
|
207 |
+
batch["w"] = imgw
|
208 |
+
|
209 |
+
return batch
|
210 |
+
|
211 |
+
|
212 |
+
def to_device(data, device):
|
213 |
+
if isinstance(data, torch.Tensor):
|
214 |
+
return data.to(device)
|
215 |
+
if isinstance(data, dict):
|
216 |
+
for key in data:
|
217 |
+
if isinstance(data[key], torch.Tensor):
|
218 |
+
data[key] = data[key].to(device)
|
219 |
+
return data
|
220 |
+
if isinstance(data, list):
|
221 |
+
return [to_device(d, device) for d in data]
|
222 |
+
|
223 |
+
|
224 |
+
class ZITS(InpaintModel):
|
225 |
+
name = "zits"
|
226 |
+
min_size = 256
|
227 |
+
pad_mod = 32
|
228 |
+
pad_to_square = True
|
229 |
+
is_erase_model = True
|
230 |
+
|
231 |
+
def __init__(self, device, **kwargs):
|
232 |
+
"""
|
233 |
+
|
234 |
+
Args:
|
235 |
+
device:
|
236 |
+
"""
|
237 |
+
super().__init__(device)
|
238 |
+
self.device = device
|
239 |
+
self.sample_edge_line_iterations = 1
|
240 |
+
|
241 |
+
def init_model(self, device, **kwargs):
|
242 |
+
self.wireframe = load_jit_model(
|
243 |
+
ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5
|
244 |
+
)
|
245 |
+
self.edge_line = load_jit_model(
|
246 |
+
ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5
|
247 |
+
)
|
248 |
+
self.structure_upsample = load_jit_model(
|
249 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
|
250 |
+
)
|
251 |
+
self.inpaint = load_jit_model(
|
252 |
+
ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5
|
253 |
+
)
|
254 |
+
|
255 |
+
@staticmethod
|
256 |
+
def download():
|
257 |
+
download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5)
|
258 |
+
download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5)
|
259 |
+
download_model(
|
260 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
|
261 |
+
)
|
262 |
+
download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5)
|
263 |
+
|
264 |
+
@staticmethod
|
265 |
+
def is_downloaded() -> bool:
|
266 |
+
model_paths = [
|
267 |
+
get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
|
268 |
+
get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
|
269 |
+
get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
|
270 |
+
get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
|
271 |
+
]
|
272 |
+
return all([os.path.exists(it) for it in model_paths])
|
273 |
+
|
274 |
+
def wireframe_edge_and_line(self, items, enable: bool):
|
275 |
+
# 最终向 items 中添加 edge 和 line key
|
276 |
+
if not enable:
|
277 |
+
items["edge"] = torch.zeros_like(items["masks"])
|
278 |
+
items["line"] = torch.zeros_like(items["masks"])
|
279 |
+
return
|
280 |
+
|
281 |
+
start = time.time()
|
282 |
+
try:
|
283 |
+
line_256 = self.wireframe_forward(
|
284 |
+
items["img_512"],
|
285 |
+
h=256,
|
286 |
+
w=256,
|
287 |
+
masks=items["mask_512"],
|
288 |
+
mask_th=0.85,
|
289 |
+
)
|
290 |
+
except:
|
291 |
+
line_256 = torch.zeros_like(items["mask_256"])
|
292 |
+
|
293 |
+
print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms")
|
294 |
+
|
295 |
+
# np_line = (line[0][0].numpy() * 255).astype(np.uint8)
|
296 |
+
# cv2.imwrite("line.jpg", np_line)
|
297 |
+
|
298 |
+
start = time.time()
|
299 |
+
edge_pred, line_pred = self.sample_edge_line_logits(
|
300 |
+
context=[items["img_256"], items["edge_256"], line_256],
|
301 |
+
mask=items["mask_256"].clone(),
|
302 |
+
iterations=self.sample_edge_line_iterations,
|
303 |
+
add_v=0.05,
|
304 |
+
mul_v=4,
|
305 |
+
)
|
306 |
+
print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms")
|
307 |
+
|
308 |
+
# np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
|
309 |
+
# cv2.imwrite("edge_pred.jpg", np_edge_pred)
|
310 |
+
# np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
|
311 |
+
# cv2.imwrite("line_pred.jpg", np_line_pred)
|
312 |
+
# exit()
|
313 |
+
|
314 |
+
input_size = min(items["h"], items["w"])
|
315 |
+
if input_size != 256 and input_size > 256:
|
316 |
+
while edge_pred.shape[2] < input_size:
|
317 |
+
edge_pred = self.structure_upsample(edge_pred)
|
318 |
+
edge_pred = torch.sigmoid((edge_pred + 2) * 2)
|
319 |
+
|
320 |
+
line_pred = self.structure_upsample(line_pred)
|
321 |
+
line_pred = torch.sigmoid((line_pred + 2) * 2)
|
322 |
+
|
323 |
+
edge_pred = F.interpolate(
|
324 |
+
edge_pred,
|
325 |
+
size=(input_size, input_size),
|
326 |
+
mode="bilinear",
|
327 |
+
align_corners=False,
|
328 |
+
)
|
329 |
+
line_pred = F.interpolate(
|
330 |
+
line_pred,
|
331 |
+
size=(input_size, input_size),
|
332 |
+
mode="bilinear",
|
333 |
+
align_corners=False,
|
334 |
+
)
|
335 |
+
|
336 |
+
# np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
|
337 |
+
# cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
|
338 |
+
# np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
|
339 |
+
# cv2.imwrite("line_pred_upsample.jpg", np_line_pred)
|
340 |
+
# exit()
|
341 |
+
|
342 |
+
items["edge"] = edge_pred.detach()
|
343 |
+
items["line"] = line_pred.detach()
|
344 |
+
|
345 |
+
@torch.no_grad()
|
346 |
+
def forward(self, image, mask, config: InpaintRequest):
|
347 |
+
"""Input images and output images have same size
|
348 |
+
images: [H, W, C] RGB
|
349 |
+
masks: [H, W]
|
350 |
+
return: BGR IMAGE
|
351 |
+
"""
|
352 |
+
mask = mask[:, :, 0]
|
353 |
+
items = load_image(image, mask, device=self.device)
|
354 |
+
|
355 |
+
self.wireframe_edge_and_line(items, config.zits_wireframe)
|
356 |
+
|
357 |
+
inpainted_image = self.inpaint(
|
358 |
+
items["images"],
|
359 |
+
items["masks"],
|
360 |
+
items["edge"],
|
361 |
+
items["line"],
|
362 |
+
items["rel_pos"],
|
363 |
+
items["direct"],
|
364 |
+
)
|
365 |
+
|
366 |
+
inpainted_image = inpainted_image * 255.0
|
367 |
+
inpainted_image = (
|
368 |
+
inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
|
369 |
+
)
|
370 |
+
inpainted_image = inpainted_image[:, :, ::-1]
|
371 |
+
|
372 |
+
# cv2.imwrite("inpainted.jpg", inpainted_image)
|
373 |
+
# exit()
|
374 |
+
|
375 |
+
return inpainted_image
|
376 |
+
|
377 |
+
def wireframe_forward(self, images, h, w, masks, mask_th=0.925):
|
378 |
+
lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1)
|
379 |
+
lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1)
|
380 |
+
images = images * 255.0
|
381 |
+
# the masks value of lcnn is 127.5
|
382 |
+
masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5
|
383 |
+
masked_images = (masked_images - lcnn_mean) / lcnn_std
|
384 |
+
|
385 |
+
def to_int(x):
|
386 |
+
return tuple(map(int, x))
|
387 |
+
|
388 |
+
lines_tensor = []
|
389 |
+
lmap = np.zeros((h, w))
|
390 |
+
|
391 |
+
output_masked = self.wireframe(masked_images)
|
392 |
+
|
393 |
+
output_masked = to_device(output_masked, "cpu")
|
394 |
+
if output_masked["num_proposals"] == 0:
|
395 |
+
lines_masked = []
|
396 |
+
scores_masked = []
|
397 |
+
else:
|
398 |
+
lines_masked = output_masked["lines_pred"].numpy()
|
399 |
+
lines_masked = [
|
400 |
+
[line[1] * h, line[0] * w, line[3] * h, line[2] * w]
|
401 |
+
for line in lines_masked
|
402 |
+
]
|
403 |
+
scores_masked = output_masked["lines_score"].numpy()
|
404 |
+
|
405 |
+
for line, score in zip(lines_masked, scores_masked):
|
406 |
+
if score > mask_th:
|
407 |
+
try:
|
408 |
+
import skimage
|
409 |
+
|
410 |
+
rr, cc, value = skimage.draw.line_aa(
|
411 |
+
*to_int(line[0:2]), *to_int(line[2:4])
|
412 |
+
)
|
413 |
+
lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
|
414 |
+
except:
|
415 |
+
cv2.line(
|
416 |
+
lmap,
|
417 |
+
to_int(line[0:2][::-1]),
|
418 |
+
to_int(line[2:4][::-1]),
|
419 |
+
(1, 1, 1),
|
420 |
+
1,
|
421 |
+
cv2.LINE_AA,
|
422 |
+
)
|
423 |
+
|
424 |
+
lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
|
425 |
+
lines_tensor.append(to_tensor(lmap).unsqueeze(0))
|
426 |
+
|
427 |
+
lines_tensor = torch.cat(lines_tensor, dim=0)
|
428 |
+
return lines_tensor.detach().to(self.device)
|
429 |
+
|
430 |
+
def sample_edge_line_logits(
|
431 |
+
self, context, mask=None, iterations=1, add_v=0, mul_v=4
|
432 |
+
):
|
433 |
+
[img, edge, line] = context
|
434 |
+
|
435 |
+
img = img * (1 - mask)
|
436 |
+
edge = edge * (1 - mask)
|
437 |
+
line = line * (1 - mask)
|
438 |
+
|
439 |
+
for i in range(iterations):
|
440 |
+
edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask)
|
441 |
+
|
442 |
+
edge_pred = torch.sigmoid(edge_logits)
|
443 |
+
line_pred = torch.sigmoid((line_logits + add_v) * mul_v)
|
444 |
+
edge = edge + edge_pred * mask
|
445 |
+
edge[edge >= 0.25] = 1
|
446 |
+
edge[edge < 0.25] = 0
|
447 |
+
line = line + line_pred * mask
|
448 |
+
|
449 |
+
b, _, h, w = edge_pred.shape
|
450 |
+
edge_pred = edge_pred.reshape(b, -1, 1)
|
451 |
+
line_pred = line_pred.reshape(b, -1, 1)
|
452 |
+
mask = mask.reshape(b, -1)
|
453 |
+
|
454 |
+
edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1)
|
455 |
+
line_probs = torch.cat([1 - line_pred, line_pred], dim=-1)
|
456 |
+
edge_probs[:, :, 1] += 0.5
|
457 |
+
line_probs[:, :, 1] += 0.5
|
458 |
+
edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
|
459 |
+
line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)
|
460 |
+
|
461 |
+
indices = torch.sort(
|
462 |
+
edge_max_probs + line_max_probs, dim=-1, descending=True
|
463 |
+
)[1]
|
464 |
+
|
465 |
+
for ii in range(b):
|
466 |
+
keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))
|
467 |
+
|
468 |
+
assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!"
|
469 |
+
mask[ii][indices[ii, :keep]] = 0
|
470 |
+
|
471 |
+
mask = mask.reshape(b, 1, h, w)
|
472 |
+
edge = edge * (1 - mask)
|
473 |
+
line = line * (1 - mask)
|
474 |
+
|
475 |
+
edge, line = edge.to(torch.float32), line.to(torch.float32)
|
476 |
+
return edge, line
|
iopaint/plugins/segment_anything/modeling/tiny_vit_sam.py
ADDED
@@ -0,0 +1,822 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# TinyViT Model Architecture
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Adapted from LeViT and Swin Transformer
|
5 |
+
# LeViT: (https://github.com/facebookresearch/levit)
|
6 |
+
# Swin: (https://github.com/microsoft/swin-transformer)
|
7 |
+
# Build the TinyViT Model
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import collections
|
11 |
+
import itertools
|
12 |
+
import math
|
13 |
+
import warnings
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.utils.checkpoint as checkpoint
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
|
21 |
+
def _ntuple(n):
|
22 |
+
def parse(x):
|
23 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
24 |
+
return x
|
25 |
+
return tuple(itertools.repeat(x, n))
|
26 |
+
|
27 |
+
return parse
|
28 |
+
|
29 |
+
|
30 |
+
to_2tuple = _ntuple(2)
|
31 |
+
|
32 |
+
|
33 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
34 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
35 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
36 |
+
def norm_cdf(x):
|
37 |
+
# Computes standard normal cumulative distribution function
|
38 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
39 |
+
|
40 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
41 |
+
warnings.warn(
|
42 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
43 |
+
"The distribution of values may be incorrect.",
|
44 |
+
stacklevel=2,
|
45 |
+
)
|
46 |
+
|
47 |
+
# Values are generated by using a truncated uniform distribution and
|
48 |
+
# then using the inverse CDF for the normal distribution.
|
49 |
+
# Get upper and lower cdf values
|
50 |
+
l = norm_cdf((a - mean) / std)
|
51 |
+
u = norm_cdf((b - mean) / std)
|
52 |
+
|
53 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
54 |
+
# [2l-1, 2u-1].
|
55 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
56 |
+
|
57 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
58 |
+
# standard normal
|
59 |
+
tensor.erfinv_()
|
60 |
+
|
61 |
+
# Transform to proper mean, std
|
62 |
+
tensor.mul_(std * math.sqrt(2.0))
|
63 |
+
tensor.add_(mean)
|
64 |
+
|
65 |
+
# Clamp to ensure it's in the proper range
|
66 |
+
tensor.clamp_(min=a, max=b)
|
67 |
+
return tensor
|
68 |
+
|
69 |
+
|
70 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
71 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
72 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
73 |
+
normal distribution. The values are effectively drawn from the
|
74 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
75 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
76 |
+
the bounds. The method used for generating the random values works
|
77 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
78 |
+
|
79 |
+
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
|
80 |
+
applied while sampling the normal with mean/std applied, therefore a, b args
|
81 |
+
should be adjusted to match the range of mean, std args.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
tensor: an n-dimensional `torch.Tensor`
|
85 |
+
mean: the mean of the normal distribution
|
86 |
+
std: the standard deviation of the normal distribution
|
87 |
+
a: the minimum cutoff value
|
88 |
+
b: the maximum cutoff value
|
89 |
+
Examples:
|
90 |
+
>>> w = torch.empty(3, 5)
|
91 |
+
>>> nn.init.trunc_normal_(w)
|
92 |
+
"""
|
93 |
+
with torch.no_grad():
|
94 |
+
return _trunc_normal_(tensor, mean, std, a, b)
|
95 |
+
|
96 |
+
|
97 |
+
def drop_path(
|
98 |
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
99 |
+
):
|
100 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
101 |
+
|
102 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
103 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
104 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
105 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
106 |
+
'survival rate' as the argument.
|
107 |
+
|
108 |
+
"""
|
109 |
+
if drop_prob == 0.0 or not training:
|
110 |
+
return x
|
111 |
+
keep_prob = 1 - drop_prob
|
112 |
+
shape = (x.shape[0],) + (1,) * (
|
113 |
+
x.ndim - 1
|
114 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
115 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
116 |
+
if keep_prob > 0.0 and scale_by_keep:
|
117 |
+
random_tensor.div_(keep_prob)
|
118 |
+
return x * random_tensor
|
119 |
+
|
120 |
+
|
121 |
+
class TimmDropPath(nn.Module):
|
122 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
123 |
+
|
124 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
125 |
+
super(TimmDropPath, self).__init__()
|
126 |
+
self.drop_prob = drop_prob
|
127 |
+
self.scale_by_keep = scale_by_keep
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
131 |
+
|
132 |
+
def extra_repr(self):
|
133 |
+
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
134 |
+
|
135 |
+
|
136 |
+
class Conv2d_BN(torch.nn.Sequential):
|
137 |
+
def __init__(
|
138 |
+
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1
|
139 |
+
):
|
140 |
+
super().__init__()
|
141 |
+
self.add_module(
|
142 |
+
"c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
|
143 |
+
)
|
144 |
+
bn = torch.nn.BatchNorm2d(b)
|
145 |
+
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
146 |
+
torch.nn.init.constant_(bn.bias, 0)
|
147 |
+
self.add_module("bn", bn)
|
148 |
+
|
149 |
+
@torch.no_grad()
|
150 |
+
def fuse(self):
|
151 |
+
c, bn = self._modules.values()
|
152 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
153 |
+
w = c.weight * w[:, None, None, None]
|
154 |
+
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
155 |
+
m = torch.nn.Conv2d(
|
156 |
+
w.size(1) * self.c.groups,
|
157 |
+
w.size(0),
|
158 |
+
w.shape[2:],
|
159 |
+
stride=self.c.stride,
|
160 |
+
padding=self.c.padding,
|
161 |
+
dilation=self.c.dilation,
|
162 |
+
groups=self.c.groups,
|
163 |
+
)
|
164 |
+
m.weight.data.copy_(w)
|
165 |
+
m.bias.data.copy_(b)
|
166 |
+
return m
|
167 |
+
|
168 |
+
|
169 |
+
class DropPath(TimmDropPath):
|
170 |
+
def __init__(self, drop_prob=None):
|
171 |
+
super().__init__(drop_prob=drop_prob)
|
172 |
+
self.drop_prob = drop_prob
|
173 |
+
|
174 |
+
def __repr__(self):
|
175 |
+
msg = super().__repr__()
|
176 |
+
msg += f"(drop_prob={self.drop_prob})"
|
177 |
+
return msg
|
178 |
+
|
179 |
+
|
180 |
+
class PatchEmbed(nn.Module):
|
181 |
+
def __init__(self, in_chans, embed_dim, resolution, activation):
|
182 |
+
super().__init__()
|
183 |
+
img_size: Tuple[int, int] = to_2tuple(resolution)
|
184 |
+
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
185 |
+
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
186 |
+
self.in_chans = in_chans
|
187 |
+
self.embed_dim = embed_dim
|
188 |
+
n = embed_dim
|
189 |
+
self.seq = nn.Sequential(
|
190 |
+
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
191 |
+
activation(),
|
192 |
+
Conv2d_BN(n // 2, n, 3, 2, 1),
|
193 |
+
)
|
194 |
+
|
195 |
+
def forward(self, x):
|
196 |
+
return self.seq(x)
|
197 |
+
|
198 |
+
|
199 |
+
class MBConv(nn.Module):
|
200 |
+
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
201 |
+
super().__init__()
|
202 |
+
self.in_chans = in_chans
|
203 |
+
self.hidden_chans = int(in_chans * expand_ratio)
|
204 |
+
self.out_chans = out_chans
|
205 |
+
|
206 |
+
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
207 |
+
self.act1 = activation()
|
208 |
+
|
209 |
+
self.conv2 = Conv2d_BN(
|
210 |
+
self.hidden_chans,
|
211 |
+
self.hidden_chans,
|
212 |
+
ks=3,
|
213 |
+
stride=1,
|
214 |
+
pad=1,
|
215 |
+
groups=self.hidden_chans,
|
216 |
+
)
|
217 |
+
self.act2 = activation()
|
218 |
+
|
219 |
+
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
220 |
+
self.act3 = activation()
|
221 |
+
|
222 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
shortcut = x
|
226 |
+
|
227 |
+
x = self.conv1(x)
|
228 |
+
x = self.act1(x)
|
229 |
+
|
230 |
+
x = self.conv2(x)
|
231 |
+
x = self.act2(x)
|
232 |
+
|
233 |
+
x = self.conv3(x)
|
234 |
+
|
235 |
+
x = self.drop_path(x)
|
236 |
+
|
237 |
+
x += shortcut
|
238 |
+
x = self.act3(x)
|
239 |
+
|
240 |
+
return x
|
241 |
+
|
242 |
+
|
243 |
+
class PatchMerging(nn.Module):
|
244 |
+
def __init__(self, input_resolution, dim, out_dim, activation):
|
245 |
+
super().__init__()
|
246 |
+
|
247 |
+
self.input_resolution = input_resolution
|
248 |
+
self.dim = dim
|
249 |
+
self.out_dim = out_dim
|
250 |
+
self.act = activation()
|
251 |
+
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
252 |
+
stride_c = 2
|
253 |
+
if out_dim == 320 or out_dim == 448 or out_dim == 576:
|
254 |
+
stride_c = 1
|
255 |
+
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
256 |
+
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
257 |
+
|
258 |
+
def forward(self, x):
|
259 |
+
if x.ndim == 3:
|
260 |
+
H, W = self.input_resolution
|
261 |
+
B = len(x)
|
262 |
+
# (B, C, H, W)
|
263 |
+
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
264 |
+
|
265 |
+
x = self.conv1(x)
|
266 |
+
x = self.act(x)
|
267 |
+
|
268 |
+
x = self.conv2(x)
|
269 |
+
x = self.act(x)
|
270 |
+
x = self.conv3(x)
|
271 |
+
x = x.flatten(2).transpose(1, 2)
|
272 |
+
return x
|
273 |
+
|
274 |
+
|
275 |
+
class ConvLayer(nn.Module):
|
276 |
+
def __init__(
|
277 |
+
self,
|
278 |
+
dim,
|
279 |
+
input_resolution,
|
280 |
+
depth,
|
281 |
+
activation,
|
282 |
+
drop_path=0.0,
|
283 |
+
downsample=None,
|
284 |
+
use_checkpoint=False,
|
285 |
+
out_dim=None,
|
286 |
+
conv_expand_ratio=4.0,
|
287 |
+
):
|
288 |
+
super().__init__()
|
289 |
+
self.dim = dim
|
290 |
+
self.input_resolution = input_resolution
|
291 |
+
self.depth = depth
|
292 |
+
self.use_checkpoint = use_checkpoint
|
293 |
+
|
294 |
+
# build blocks
|
295 |
+
self.blocks = nn.ModuleList(
|
296 |
+
[
|
297 |
+
MBConv(
|
298 |
+
dim,
|
299 |
+
dim,
|
300 |
+
conv_expand_ratio,
|
301 |
+
activation,
|
302 |
+
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
303 |
+
)
|
304 |
+
for i in range(depth)
|
305 |
+
]
|
306 |
+
)
|
307 |
+
|
308 |
+
# patch merging layer
|
309 |
+
if downsample is not None:
|
310 |
+
self.downsample = downsample(
|
311 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
self.downsample = None
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
for blk in self.blocks:
|
318 |
+
if self.use_checkpoint:
|
319 |
+
x = checkpoint.checkpoint(blk, x)
|
320 |
+
else:
|
321 |
+
x = blk(x)
|
322 |
+
if self.downsample is not None:
|
323 |
+
x = self.downsample(x)
|
324 |
+
return x
|
325 |
+
|
326 |
+
|
327 |
+
class Mlp(nn.Module):
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
in_features,
|
331 |
+
hidden_features=None,
|
332 |
+
out_features=None,
|
333 |
+
act_layer=nn.GELU,
|
334 |
+
drop=0.0,
|
335 |
+
):
|
336 |
+
super().__init__()
|
337 |
+
out_features = out_features or in_features
|
338 |
+
hidden_features = hidden_features or in_features
|
339 |
+
self.norm = nn.LayerNorm(in_features)
|
340 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
341 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
342 |
+
self.act = act_layer()
|
343 |
+
self.drop = nn.Dropout(drop)
|
344 |
+
|
345 |
+
def forward(self, x):
|
346 |
+
x = self.norm(x)
|
347 |
+
|
348 |
+
x = self.fc1(x)
|
349 |
+
x = self.act(x)
|
350 |
+
x = self.drop(x)
|
351 |
+
x = self.fc2(x)
|
352 |
+
x = self.drop(x)
|
353 |
+
return x
|
354 |
+
|
355 |
+
|
356 |
+
class Attention(torch.nn.Module):
|
357 |
+
def __init__(
|
358 |
+
self,
|
359 |
+
dim,
|
360 |
+
key_dim,
|
361 |
+
num_heads=8,
|
362 |
+
attn_ratio=4,
|
363 |
+
resolution=(14, 14),
|
364 |
+
):
|
365 |
+
super().__init__()
|
366 |
+
# (h, w)
|
367 |
+
assert isinstance(resolution, tuple) and len(resolution) == 2
|
368 |
+
self.num_heads = num_heads
|
369 |
+
self.scale = key_dim**-0.5
|
370 |
+
self.key_dim = key_dim
|
371 |
+
self.nh_kd = nh_kd = key_dim * num_heads
|
372 |
+
self.d = int(attn_ratio * key_dim)
|
373 |
+
self.dh = int(attn_ratio * key_dim) * num_heads
|
374 |
+
self.attn_ratio = attn_ratio
|
375 |
+
h = self.dh + nh_kd * 2
|
376 |
+
|
377 |
+
self.norm = nn.LayerNorm(dim)
|
378 |
+
self.qkv = nn.Linear(dim, h)
|
379 |
+
self.proj = nn.Linear(self.dh, dim)
|
380 |
+
|
381 |
+
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
|
382 |
+
N = len(points)
|
383 |
+
attention_offsets = {}
|
384 |
+
idxs = []
|
385 |
+
for p1 in points:
|
386 |
+
for p2 in points:
|
387 |
+
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
388 |
+
if offset not in attention_offsets:
|
389 |
+
attention_offsets[offset] = len(attention_offsets)
|
390 |
+
idxs.append(attention_offsets[offset])
|
391 |
+
self.attention_biases = torch.nn.Parameter(
|
392 |
+
torch.zeros(num_heads, len(attention_offsets))
|
393 |
+
)
|
394 |
+
self.register_buffer(
|
395 |
+
"attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False
|
396 |
+
)
|
397 |
+
|
398 |
+
@torch.no_grad()
|
399 |
+
def train(self, mode=True):
|
400 |
+
super().train(mode)
|
401 |
+
if mode and hasattr(self, "ab"):
|
402 |
+
del self.ab
|
403 |
+
else:
|
404 |
+
self.register_buffer(
|
405 |
+
"ab",
|
406 |
+
self.attention_biases[:, self.attention_bias_idxs],
|
407 |
+
persistent=False,
|
408 |
+
)
|
409 |
+
|
410 |
+
def forward(self, x): # x (B,N,C)
|
411 |
+
B, N, _ = x.shape
|
412 |
+
|
413 |
+
# Normalization
|
414 |
+
x = self.norm(x)
|
415 |
+
|
416 |
+
qkv = self.qkv(x)
|
417 |
+
# (B, N, num_heads, d)
|
418 |
+
q, k, v = qkv.view(B, N, self.num_heads, -1).split(
|
419 |
+
[self.key_dim, self.key_dim, self.d], dim=3
|
420 |
+
)
|
421 |
+
# (B, num_heads, N, d)
|
422 |
+
q = q.permute(0, 2, 1, 3)
|
423 |
+
k = k.permute(0, 2, 1, 3)
|
424 |
+
v = v.permute(0, 2, 1, 3)
|
425 |
+
|
426 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale + (
|
427 |
+
self.attention_biases[:, self.attention_bias_idxs]
|
428 |
+
if self.training
|
429 |
+
else self.ab
|
430 |
+
)
|
431 |
+
attn = attn.softmax(dim=-1)
|
432 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
433 |
+
x = self.proj(x)
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
class TinyViTBlock(nn.Module):
|
438 |
+
r"""TinyViT Block.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
dim (int): Number of input channels.
|
442 |
+
input_resolution (tuple[int, int]): Input resolution.
|
443 |
+
num_heads (int): Number of attention heads.
|
444 |
+
window_size (int): Window size.
|
445 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
446 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
447 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
448 |
+
local_conv_size (int): the kernel size of the convolution between
|
449 |
+
Attention and MLP. Default: 3
|
450 |
+
activation: the activation function. Default: nn.GELU
|
451 |
+
"""
|
452 |
+
|
453 |
+
def __init__(
|
454 |
+
self,
|
455 |
+
dim,
|
456 |
+
input_resolution,
|
457 |
+
num_heads,
|
458 |
+
window_size=7,
|
459 |
+
mlp_ratio=4.0,
|
460 |
+
drop=0.0,
|
461 |
+
drop_path=0.0,
|
462 |
+
local_conv_size=3,
|
463 |
+
activation=nn.GELU,
|
464 |
+
):
|
465 |
+
super().__init__()
|
466 |
+
self.dim = dim
|
467 |
+
self.input_resolution = input_resolution
|
468 |
+
self.num_heads = num_heads
|
469 |
+
assert window_size > 0, "window_size must be greater than 0"
|
470 |
+
self.window_size = window_size
|
471 |
+
self.mlp_ratio = mlp_ratio
|
472 |
+
|
473 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
474 |
+
|
475 |
+
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
476 |
+
head_dim = dim // num_heads
|
477 |
+
|
478 |
+
window_resolution = (window_size, window_size)
|
479 |
+
self.attn = Attention(
|
480 |
+
dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution
|
481 |
+
)
|
482 |
+
|
483 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
484 |
+
mlp_activation = activation
|
485 |
+
self.mlp = Mlp(
|
486 |
+
in_features=dim,
|
487 |
+
hidden_features=mlp_hidden_dim,
|
488 |
+
act_layer=mlp_activation,
|
489 |
+
drop=drop,
|
490 |
+
)
|
491 |
+
|
492 |
+
pad = local_conv_size // 2
|
493 |
+
self.local_conv = Conv2d_BN(
|
494 |
+
dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim
|
495 |
+
)
|
496 |
+
|
497 |
+
def forward(self, x):
|
498 |
+
H, W = self.input_resolution
|
499 |
+
B, L, C = x.shape
|
500 |
+
assert L == H * W, "input feature has wrong size"
|
501 |
+
res_x = x
|
502 |
+
if H == self.window_size and W == self.window_size:
|
503 |
+
x = self.attn(x)
|
504 |
+
else:
|
505 |
+
x = x.view(B, H, W, C)
|
506 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
507 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
508 |
+
padding = pad_b > 0 or pad_r > 0
|
509 |
+
|
510 |
+
if padding:
|
511 |
+
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
512 |
+
|
513 |
+
pH, pW = H + pad_b, W + pad_r
|
514 |
+
nH = pH // self.window_size
|
515 |
+
nW = pW // self.window_size
|
516 |
+
# window partition
|
517 |
+
x = (
|
518 |
+
x.view(B, nH, self.window_size, nW, self.window_size, C)
|
519 |
+
.transpose(2, 3)
|
520 |
+
.reshape(B * nH * nW, self.window_size * self.window_size, C)
|
521 |
+
)
|
522 |
+
x = self.attn(x)
|
523 |
+
# window reverse
|
524 |
+
x = (
|
525 |
+
x.view(B, nH, nW, self.window_size, self.window_size, C)
|
526 |
+
.transpose(2, 3)
|
527 |
+
.reshape(B, pH, pW, C)
|
528 |
+
)
|
529 |
+
|
530 |
+
if padding:
|
531 |
+
x = x[:, :H, :W].contiguous()
|
532 |
+
|
533 |
+
x = x.view(B, L, C)
|
534 |
+
|
535 |
+
x = res_x + self.drop_path(x)
|
536 |
+
|
537 |
+
x = x.transpose(1, 2).reshape(B, C, H, W)
|
538 |
+
x = self.local_conv(x)
|
539 |
+
x = x.view(B, C, L).transpose(1, 2)
|
540 |
+
|
541 |
+
x = x + self.drop_path(self.mlp(x))
|
542 |
+
return x
|
543 |
+
|
544 |
+
def extra_repr(self) -> str:
|
545 |
+
return (
|
546 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
547 |
+
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
class BasicLayer(nn.Module):
|
552 |
+
"""A basic TinyViT layer for one stage.
|
553 |
+
|
554 |
+
Args:
|
555 |
+
dim (int): Number of input channels.
|
556 |
+
input_resolution (tuple[int]): Input resolution.
|
557 |
+
depth (int): Number of blocks.
|
558 |
+
num_heads (int): Number of attention heads.
|
559 |
+
window_size (int): Local window size.
|
560 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
561 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
562 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
563 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
564 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
565 |
+
local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
566 |
+
activation: the activation function. Default: nn.GELU
|
567 |
+
out_dim: the output dimension of the layer. Default: dim
|
568 |
+
"""
|
569 |
+
|
570 |
+
def __init__(
|
571 |
+
self,
|
572 |
+
dim,
|
573 |
+
input_resolution,
|
574 |
+
depth,
|
575 |
+
num_heads,
|
576 |
+
window_size,
|
577 |
+
mlp_ratio=4.0,
|
578 |
+
drop=0.0,
|
579 |
+
drop_path=0.0,
|
580 |
+
downsample=None,
|
581 |
+
use_checkpoint=False,
|
582 |
+
local_conv_size=3,
|
583 |
+
activation=nn.GELU,
|
584 |
+
out_dim=None,
|
585 |
+
):
|
586 |
+
super().__init__()
|
587 |
+
self.dim = dim
|
588 |
+
self.input_resolution = input_resolution
|
589 |
+
self.depth = depth
|
590 |
+
self.use_checkpoint = use_checkpoint
|
591 |
+
|
592 |
+
# build blocks
|
593 |
+
self.blocks = nn.ModuleList(
|
594 |
+
[
|
595 |
+
TinyViTBlock(
|
596 |
+
dim=dim,
|
597 |
+
input_resolution=input_resolution,
|
598 |
+
num_heads=num_heads,
|
599 |
+
window_size=window_size,
|
600 |
+
mlp_ratio=mlp_ratio,
|
601 |
+
drop=drop,
|
602 |
+
drop_path=drop_path[i]
|
603 |
+
if isinstance(drop_path, list)
|
604 |
+
else drop_path,
|
605 |
+
local_conv_size=local_conv_size,
|
606 |
+
activation=activation,
|
607 |
+
)
|
608 |
+
for i in range(depth)
|
609 |
+
]
|
610 |
+
)
|
611 |
+
|
612 |
+
# patch merging layer
|
613 |
+
if downsample is not None:
|
614 |
+
self.downsample = downsample(
|
615 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation
|
616 |
+
)
|
617 |
+
else:
|
618 |
+
self.downsample = None
|
619 |
+
|
620 |
+
def forward(self, x):
|
621 |
+
for blk in self.blocks:
|
622 |
+
if self.use_checkpoint:
|
623 |
+
x = checkpoint.checkpoint(blk, x)
|
624 |
+
else:
|
625 |
+
x = blk(x)
|
626 |
+
if self.downsample is not None:
|
627 |
+
x = self.downsample(x)
|
628 |
+
return x
|
629 |
+
|
630 |
+
def extra_repr(self) -> str:
|
631 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
632 |
+
|
633 |
+
|
634 |
+
class LayerNorm2d(nn.Module):
|
635 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
636 |
+
super().__init__()
|
637 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
638 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
639 |
+
self.eps = eps
|
640 |
+
|
641 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
642 |
+
u = x.mean(1, keepdim=True)
|
643 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
644 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
645 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
646 |
+
return x
|
647 |
+
|
648 |
+
|
649 |
+
class TinyViT(nn.Module):
|
650 |
+
def __init__(
|
651 |
+
self,
|
652 |
+
img_size=224,
|
653 |
+
in_chans=3,
|
654 |
+
num_classes=1000,
|
655 |
+
embed_dims=[96, 192, 384, 768],
|
656 |
+
depths=[2, 2, 6, 2],
|
657 |
+
num_heads=[3, 6, 12, 24],
|
658 |
+
window_sizes=[7, 7, 14, 7],
|
659 |
+
mlp_ratio=4.0,
|
660 |
+
drop_rate=0.0,
|
661 |
+
drop_path_rate=0.1,
|
662 |
+
use_checkpoint=False,
|
663 |
+
mbconv_expand_ratio=4.0,
|
664 |
+
local_conv_size=3,
|
665 |
+
layer_lr_decay=1.0,
|
666 |
+
):
|
667 |
+
super().__init__()
|
668 |
+
self.img_size = img_size
|
669 |
+
self.num_classes = num_classes
|
670 |
+
self.depths = depths
|
671 |
+
self.num_layers = len(depths)
|
672 |
+
self.mlp_ratio = mlp_ratio
|
673 |
+
|
674 |
+
activation = nn.GELU
|
675 |
+
|
676 |
+
self.patch_embed = PatchEmbed(
|
677 |
+
in_chans=in_chans,
|
678 |
+
embed_dim=embed_dims[0],
|
679 |
+
resolution=img_size,
|
680 |
+
activation=activation,
|
681 |
+
)
|
682 |
+
|
683 |
+
patches_resolution = self.patch_embed.patches_resolution
|
684 |
+
self.patches_resolution = patches_resolution
|
685 |
+
|
686 |
+
# stochastic depth
|
687 |
+
dpr = [
|
688 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
689 |
+
] # stochastic depth decay rule
|
690 |
+
|
691 |
+
# build layers
|
692 |
+
self.layers = nn.ModuleList()
|
693 |
+
for i_layer in range(self.num_layers):
|
694 |
+
kwargs = dict(
|
695 |
+
dim=embed_dims[i_layer],
|
696 |
+
input_resolution=(
|
697 |
+
patches_resolution[0]
|
698 |
+
// (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
699 |
+
patches_resolution[1]
|
700 |
+
// (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
701 |
+
),
|
702 |
+
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
703 |
+
# patches_resolution[1] // (2 ** i_layer)),
|
704 |
+
depth=depths[i_layer],
|
705 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
706 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
707 |
+
use_checkpoint=use_checkpoint,
|
708 |
+
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
|
709 |
+
activation=activation,
|
710 |
+
)
|
711 |
+
if i_layer == 0:
|
712 |
+
layer = ConvLayer(
|
713 |
+
conv_expand_ratio=mbconv_expand_ratio,
|
714 |
+
**kwargs,
|
715 |
+
)
|
716 |
+
else:
|
717 |
+
layer = BasicLayer(
|
718 |
+
num_heads=num_heads[i_layer],
|
719 |
+
window_size=window_sizes[i_layer],
|
720 |
+
mlp_ratio=self.mlp_ratio,
|
721 |
+
drop=drop_rate,
|
722 |
+
local_conv_size=local_conv_size,
|
723 |
+
**kwargs,
|
724 |
+
)
|
725 |
+
self.layers.append(layer)
|
726 |
+
|
727 |
+
# Classifier head
|
728 |
+
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
729 |
+
self.head = (
|
730 |
+
nn.Linear(embed_dims[-1], num_classes)
|
731 |
+
if num_classes > 0
|
732 |
+
else torch.nn.Identity()
|
733 |
+
)
|
734 |
+
|
735 |
+
# init weights
|
736 |
+
self.apply(self._init_weights)
|
737 |
+
self.set_layer_lr_decay(layer_lr_decay)
|
738 |
+
self.neck = nn.Sequential(
|
739 |
+
nn.Conv2d(
|
740 |
+
embed_dims[-1],
|
741 |
+
256,
|
742 |
+
kernel_size=1,
|
743 |
+
bias=False,
|
744 |
+
),
|
745 |
+
LayerNorm2d(256),
|
746 |
+
nn.Conv2d(
|
747 |
+
256,
|
748 |
+
256,
|
749 |
+
kernel_size=3,
|
750 |
+
padding=1,
|
751 |
+
bias=False,
|
752 |
+
),
|
753 |
+
LayerNorm2d(256),
|
754 |
+
)
|
755 |
+
|
756 |
+
def set_layer_lr_decay(self, layer_lr_decay):
|
757 |
+
decay_rate = layer_lr_decay
|
758 |
+
|
759 |
+
# layers -> blocks (depth)
|
760 |
+
depth = sum(self.depths)
|
761 |
+
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
762 |
+
# print("LR SCALES:", lr_scales)
|
763 |
+
|
764 |
+
def _set_lr_scale(m, scale):
|
765 |
+
for p in m.parameters():
|
766 |
+
p.lr_scale = scale
|
767 |
+
|
768 |
+
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
|
769 |
+
i = 0
|
770 |
+
for layer in self.layers:
|
771 |
+
for block in layer.blocks:
|
772 |
+
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
|
773 |
+
i += 1
|
774 |
+
if layer.downsample is not None:
|
775 |
+
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
|
776 |
+
assert i == depth
|
777 |
+
for m in [self.norm_head, self.head]:
|
778 |
+
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
|
779 |
+
|
780 |
+
for k, p in self.named_parameters():
|
781 |
+
p.param_name = k
|
782 |
+
|
783 |
+
def _check_lr_scale(m):
|
784 |
+
for p in m.parameters():
|
785 |
+
assert hasattr(p, "lr_scale"), p.param_name
|
786 |
+
|
787 |
+
self.apply(_check_lr_scale)
|
788 |
+
|
789 |
+
def _init_weights(self, m):
|
790 |
+
if isinstance(m, nn.Linear):
|
791 |
+
trunc_normal_(m.weight, std=0.02)
|
792 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
793 |
+
nn.init.constant_(m.bias, 0)
|
794 |
+
elif isinstance(m, nn.LayerNorm):
|
795 |
+
nn.init.constant_(m.bias, 0)
|
796 |
+
nn.init.constant_(m.weight, 1.0)
|
797 |
+
|
798 |
+
@torch.jit.ignore
|
799 |
+
def no_weight_decay_keywords(self):
|
800 |
+
return {"attention_biases"}
|
801 |
+
|
802 |
+
def forward_features(self, x):
|
803 |
+
# x: (N, C, H, W)
|
804 |
+
x = self.patch_embed(x)
|
805 |
+
|
806 |
+
x = self.layers[0](x)
|
807 |
+
start_i = 1
|
808 |
+
|
809 |
+
for i in range(start_i, len(self.layers)):
|
810 |
+
layer = self.layers[i]
|
811 |
+
x = layer(x)
|
812 |
+
B, _, C = x.size()
|
813 |
+
x = x.view(B, 64, 64, C)
|
814 |
+
x = x.permute(0, 3, 1, 2)
|
815 |
+
x = self.neck(x)
|
816 |
+
return x
|
817 |
+
|
818 |
+
def forward(self, x):
|
819 |
+
x = self.forward_features(x)
|
820 |
+
# x = self.norm_head(x)
|
821 |
+
# x = self.head(x)
|
822 |
+
return x
|
iopaint/plugins/segment_anything/modeling/transformer.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import Tensor, nn
|
9 |
+
|
10 |
+
import math
|
11 |
+
from typing import Tuple, Type
|
12 |
+
|
13 |
+
from .common import MLPBlock
|
14 |
+
|
15 |
+
|
16 |
+
class TwoWayTransformer(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
depth: int,
|
20 |
+
embedding_dim: int,
|
21 |
+
num_heads: int,
|
22 |
+
mlp_dim: int,
|
23 |
+
activation: Type[nn.Module] = nn.ReLU,
|
24 |
+
attention_downsample_rate: int = 2,
|
25 |
+
) -> None:
|
26 |
+
"""
|
27 |
+
A transformer decoder that attends to an input image using
|
28 |
+
queries whose positional embedding is supplied.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
depth (int): number of layers in the transformer
|
32 |
+
embedding_dim (int): the channel dimension for the input embeddings
|
33 |
+
num_heads (int): the number of heads for multihead attention. Must
|
34 |
+
divide embedding_dim
|
35 |
+
mlp_dim (int): the channel dimension internal to the MLP block
|
36 |
+
activation (nn.Module): the activation to use in the MLP block
|
37 |
+
"""
|
38 |
+
super().__init__()
|
39 |
+
self.depth = depth
|
40 |
+
self.embedding_dim = embedding_dim
|
41 |
+
self.num_heads = num_heads
|
42 |
+
self.mlp_dim = mlp_dim
|
43 |
+
self.layers = nn.ModuleList()
|
44 |
+
|
45 |
+
for i in range(depth):
|
46 |
+
self.layers.append(
|
47 |
+
TwoWayAttentionBlock(
|
48 |
+
embedding_dim=embedding_dim,
|
49 |
+
num_heads=num_heads,
|
50 |
+
mlp_dim=mlp_dim,
|
51 |
+
activation=activation,
|
52 |
+
attention_downsample_rate=attention_downsample_rate,
|
53 |
+
skip_first_layer_pe=(i == 0),
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
self.final_attn_token_to_image = Attention(
|
58 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
59 |
+
)
|
60 |
+
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
61 |
+
|
62 |
+
def forward(
|
63 |
+
self,
|
64 |
+
image_embedding: Tensor,
|
65 |
+
image_pe: Tensor,
|
66 |
+
point_embedding: Tensor,
|
67 |
+
) -> Tuple[Tensor, Tensor]:
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
image_embedding (torch.Tensor): image to attend to. Should be shape
|
71 |
+
B x embedding_dim x h x w for any h and w.
|
72 |
+
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
73 |
+
have the same shape as image_embedding.
|
74 |
+
point_embedding (torch.Tensor): the embedding to add to the query points.
|
75 |
+
Must have shape B x N_points x embedding_dim for any N_points.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
torch.Tensor: the processed point_embedding
|
79 |
+
torch.Tensor: the processed image_embedding
|
80 |
+
"""
|
81 |
+
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
82 |
+
bs, c, h, w = image_embedding.shape
|
83 |
+
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
84 |
+
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
85 |
+
|
86 |
+
# Prepare queries
|
87 |
+
queries = point_embedding
|
88 |
+
keys = image_embedding
|
89 |
+
|
90 |
+
# Apply transformer blocks and final layernorm
|
91 |
+
for layer in self.layers:
|
92 |
+
queries, keys = layer(
|
93 |
+
queries=queries,
|
94 |
+
keys=keys,
|
95 |
+
query_pe=point_embedding,
|
96 |
+
key_pe=image_pe,
|
97 |
+
)
|
98 |
+
|
99 |
+
# Apply the final attenion layer from the points to the image
|
100 |
+
q = queries + point_embedding
|
101 |
+
k = keys + image_pe
|
102 |
+
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
103 |
+
queries = queries + attn_out
|
104 |
+
queries = self.norm_final_attn(queries)
|
105 |
+
|
106 |
+
return queries, keys
|
107 |
+
|
108 |
+
|
109 |
+
class TwoWayAttentionBlock(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
embedding_dim: int,
|
113 |
+
num_heads: int,
|
114 |
+
mlp_dim: int = 2048,
|
115 |
+
activation: Type[nn.Module] = nn.ReLU,
|
116 |
+
attention_downsample_rate: int = 2,
|
117 |
+
skip_first_layer_pe: bool = False,
|
118 |
+
) -> None:
|
119 |
+
"""
|
120 |
+
A transformer block with four layers: (1) self-attention of sparse
|
121 |
+
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
122 |
+
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
123 |
+
inputs.
|
124 |
+
|
125 |
+
Arguments:
|
126 |
+
embedding_dim (int): the channel dimension of the embeddings
|
127 |
+
num_heads (int): the number of heads in the attention layers
|
128 |
+
mlp_dim (int): the hidden dimension of the mlp block
|
129 |
+
activation (nn.Module): the activation of the mlp block
|
130 |
+
skip_first_layer_pe (bool): skip the PE on the first layer
|
131 |
+
"""
|
132 |
+
super().__init__()
|
133 |
+
self.self_attn = Attention(embedding_dim, num_heads)
|
134 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
135 |
+
|
136 |
+
self.cross_attn_token_to_image = Attention(
|
137 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
138 |
+
)
|
139 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
140 |
+
|
141 |
+
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
142 |
+
self.norm3 = nn.LayerNorm(embedding_dim)
|
143 |
+
|
144 |
+
self.norm4 = nn.LayerNorm(embedding_dim)
|
145 |
+
self.cross_attn_image_to_token = Attention(
|
146 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
147 |
+
)
|
148 |
+
|
149 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
150 |
+
|
151 |
+
def forward(
|
152 |
+
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
153 |
+
) -> Tuple[Tensor, Tensor]:
|
154 |
+
# Self attention block
|
155 |
+
if self.skip_first_layer_pe:
|
156 |
+
queries = self.self_attn(q=queries, k=queries, v=queries)
|
157 |
+
else:
|
158 |
+
q = queries + query_pe
|
159 |
+
attn_out = self.self_attn(q=q, k=q, v=queries)
|
160 |
+
queries = queries + attn_out
|
161 |
+
queries = self.norm1(queries)
|
162 |
+
|
163 |
+
# Cross attention block, tokens attending to image embedding
|
164 |
+
q = queries + query_pe
|
165 |
+
k = keys + key_pe
|
166 |
+
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
167 |
+
queries = queries + attn_out
|
168 |
+
queries = self.norm2(queries)
|
169 |
+
|
170 |
+
# MLP block
|
171 |
+
mlp_out = self.mlp(queries)
|
172 |
+
queries = queries + mlp_out
|
173 |
+
queries = self.norm3(queries)
|
174 |
+
|
175 |
+
# Cross attention block, image embedding attending to tokens
|
176 |
+
q = queries + query_pe
|
177 |
+
k = keys + key_pe
|
178 |
+
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
179 |
+
keys = keys + attn_out
|
180 |
+
keys = self.norm4(keys)
|
181 |
+
|
182 |
+
return queries, keys
|
183 |
+
|
184 |
+
|
185 |
+
class Attention(nn.Module):
|
186 |
+
"""
|
187 |
+
An attention layer that allows for downscaling the size of the embedding
|
188 |
+
after projection to queries, keys, and values.
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
embedding_dim: int,
|
194 |
+
num_heads: int,
|
195 |
+
downsample_rate: int = 1,
|
196 |
+
) -> None:
|
197 |
+
super().__init__()
|
198 |
+
self.embedding_dim = embedding_dim
|
199 |
+
self.internal_dim = embedding_dim // downsample_rate
|
200 |
+
self.num_heads = num_heads
|
201 |
+
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
202 |
+
|
203 |
+
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
204 |
+
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
205 |
+
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
206 |
+
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
207 |
+
|
208 |
+
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
209 |
+
b, n, c = x.shape
|
210 |
+
x = x.reshape(b, n, num_heads, c // num_heads)
|
211 |
+
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
212 |
+
|
213 |
+
def _recombine_heads(self, x: Tensor) -> Tensor:
|
214 |
+
b, n_heads, n_tokens, c_per_head = x.shape
|
215 |
+
x = x.transpose(1, 2)
|
216 |
+
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
217 |
+
|
218 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
219 |
+
# Input projections
|
220 |
+
q = self.q_proj(q)
|
221 |
+
k = self.k_proj(k)
|
222 |
+
v = self.v_proj(v)
|
223 |
+
|
224 |
+
# Separate into heads
|
225 |
+
q = self._separate_heads(q, self.num_heads)
|
226 |
+
k = self._separate_heads(k, self.num_heads)
|
227 |
+
v = self._separate_heads(v, self.num_heads)
|
228 |
+
|
229 |
+
# Attention
|
230 |
+
_, _, _, c_per_head = q.shape
|
231 |
+
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
232 |
+
attn = attn / math.sqrt(c_per_head)
|
233 |
+
attn = torch.softmax(attn, dim=-1)
|
234 |
+
|
235 |
+
# Get output
|
236 |
+
out = attn @ v
|
237 |
+
out = self._recombine_heads(out)
|
238 |
+
out = self.out_proj(out)
|
239 |
+
|
240 |
+
return out
|
iopaint/plugins/segment_anything/utils/transforms.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
|
11 |
+
|
12 |
+
from copy import deepcopy
|
13 |
+
from typing import Tuple
|
14 |
+
|
15 |
+
|
16 |
+
class ResizeLongestSide:
|
17 |
+
"""
|
18 |
+
Resizes images to longest side 'target_length', as well as provides
|
19 |
+
methods for resizing coordinates and boxes. Provides methods for
|
20 |
+
transforming both numpy array and batched torch tensors.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, target_length: int) -> None:
|
24 |
+
self.target_length = target_length
|
25 |
+
|
26 |
+
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
27 |
+
"""
|
28 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
29 |
+
"""
|
30 |
+
target_size = self.get_preprocess_shape(
|
31 |
+
image.shape[0], image.shape[1], self.target_length
|
32 |
+
)
|
33 |
+
return np.array(resize(to_pil_image(image), target_size))
|
34 |
+
|
35 |
+
def apply_coords(
|
36 |
+
self, coords: np.ndarray, original_size: Tuple[int, ...]
|
37 |
+
) -> np.ndarray:
|
38 |
+
"""
|
39 |
+
Expects a numpy array of length 2 in the final dimension. Requires the
|
40 |
+
original image size in (H, W) format.
|
41 |
+
"""
|
42 |
+
old_h, old_w = original_size
|
43 |
+
new_h, new_w = self.get_preprocess_shape(
|
44 |
+
original_size[0], original_size[1], self.target_length
|
45 |
+
)
|
46 |
+
coords = deepcopy(coords).astype(float)
|
47 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
48 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
49 |
+
return coords
|
50 |
+
|
51 |
+
def apply_boxes(
|
52 |
+
self, boxes: np.ndarray, original_size: Tuple[int, ...]
|
53 |
+
) -> np.ndarray:
|
54 |
+
"""
|
55 |
+
Expects a numpy array shape Bx4. Requires the original image size
|
56 |
+
in (H, W) format.
|
57 |
+
"""
|
58 |
+
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
59 |
+
return boxes.reshape(-1, 4)
|
60 |
+
|
61 |
+
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
62 |
+
"""
|
63 |
+
Expects batched images with shape BxCxHxW and float format. This
|
64 |
+
transformation may not exactly match apply_image. apply_image is
|
65 |
+
the transformation expected by the model.
|
66 |
+
"""
|
67 |
+
# Expects an image in BCHW format. May not exactly match apply_image.
|
68 |
+
target_size = self.get_preprocess_shape(
|
69 |
+
image.shape[0], image.shape[1], self.target_length
|
70 |
+
)
|
71 |
+
return F.interpolate(
|
72 |
+
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
73 |
+
)
|
74 |
+
|
75 |
+
def apply_coords_torch(
|
76 |
+
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
77 |
+
) -> torch.Tensor:
|
78 |
+
"""
|
79 |
+
Expects a torch tensor with length 2 in the last dimension. Requires the
|
80 |
+
original image size in (H, W) format.
|
81 |
+
"""
|
82 |
+
old_h, old_w = original_size
|
83 |
+
new_h, new_w = self.get_preprocess_shape(
|
84 |
+
original_size[0], original_size[1], self.target_length
|
85 |
+
)
|
86 |
+
coords = deepcopy(coords).to(torch.float)
|
87 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
88 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
89 |
+
return coords
|
90 |
+
|
91 |
+
def apply_boxes_torch(
|
92 |
+
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
93 |
+
) -> torch.Tensor:
|
94 |
+
"""
|
95 |
+
Expects a torch tensor with shape Bx4. Requires the original image
|
96 |
+
size in (H, W) format.
|
97 |
+
"""
|
98 |
+
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
99 |
+
return boxes.reshape(-1, 4)
|
100 |
+
|
101 |
+
@staticmethod
|
102 |
+
def get_preprocess_shape(
|
103 |
+
oldh: int, oldw: int, long_side_length: int
|
104 |
+
) -> Tuple[int, int]:
|
105 |
+
"""
|
106 |
+
Compute the output size given input size and target long side length.
|
107 |
+
"""
|
108 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
109 |
+
newh, neww = oldh * scale, oldw * scale
|
110 |
+
neww = int(neww + 0.5)
|
111 |
+
newh = int(newh + 0.5)
|
112 |
+
return (newh, neww)
|
iopaint/tests/test_sdxl.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from iopaint.tests.utils import check_device, current_dir
|
4 |
+
|
5 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
6 |
+
|
7 |
+
import pytest
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from iopaint.model_manager import ModelManager
|
11 |
+
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
12 |
+
from iopaint.tests.test_model import get_config, assert_equal
|
13 |
+
|
14 |
+
|
15 |
+
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
16 |
+
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
17 |
+
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
18 |
+
def test_sdxl(device, strategy, sampler):
|
19 |
+
sd_steps = check_device(device)
|
20 |
+
|
21 |
+
model = ModelManager(
|
22 |
+
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
23 |
+
device=torch.device(device),
|
24 |
+
disable_nsfw=True,
|
25 |
+
sd_cpu_textencoder=False,
|
26 |
+
)
|
27 |
+
cfg = get_config(
|
28 |
+
strategy=strategy,
|
29 |
+
prompt="face of a fox, sitting on a bench",
|
30 |
+
sd_steps=sd_steps,
|
31 |
+
sd_strength=1.0,
|
32 |
+
sd_guidance_scale=7.0,
|
33 |
+
)
|
34 |
+
cfg.sd_sampler = sampler
|
35 |
+
|
36 |
+
assert_equal(
|
37 |
+
model,
|
38 |
+
cfg,
|
39 |
+
f"sdxl_device_{device}.png",
|
40 |
+
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
41 |
+
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
42 |
+
fx=2,
|
43 |
+
fy=2,
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
48 |
+
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
49 |
+
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
50 |
+
def test_sdxl_cpu_text_encoder(device, strategy, sampler):
|
51 |
+
sd_steps = check_device(device)
|
52 |
+
|
53 |
+
model = ModelManager(
|
54 |
+
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
55 |
+
device=torch.device(device),
|
56 |
+
disable_nsfw=True,
|
57 |
+
sd_cpu_textencoder=True,
|
58 |
+
)
|
59 |
+
cfg = get_config(
|
60 |
+
strategy=strategy,
|
61 |
+
prompt="face of a fox, sitting on a bench",
|
62 |
+
sd_steps=sd_steps,
|
63 |
+
sd_strength=1.0,
|
64 |
+
sd_guidance_scale=7.0,
|
65 |
+
)
|
66 |
+
cfg.sd_sampler = sampler
|
67 |
+
|
68 |
+
assert_equal(
|
69 |
+
model,
|
70 |
+
cfg,
|
71 |
+
f"sdxl_device_{device}.png",
|
72 |
+
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
73 |
+
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
74 |
+
fx=2,
|
75 |
+
fy=2,
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
80 |
+
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
81 |
+
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
82 |
+
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
|
83 |
+
sd_steps = check_device(device)
|
84 |
+
|
85 |
+
model = ModelManager(
|
86 |
+
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
87 |
+
device=torch.device(device),
|
88 |
+
disable_nsfw=True,
|
89 |
+
sd_cpu_textencoder=False,
|
90 |
+
)
|
91 |
+
cfg = get_config(
|
92 |
+
strategy=strategy,
|
93 |
+
prompt="face of a fox, sitting on a bench",
|
94 |
+
sd_steps=sd_steps,
|
95 |
+
sd_strength=1.0,
|
96 |
+
sd_guidance_scale=2.0,
|
97 |
+
sd_lcm_lora=True,
|
98 |
+
)
|
99 |
+
cfg.sd_sampler = sampler
|
100 |
+
|
101 |
+
name = f"device_{device}_{sampler}"
|
102 |
+
|
103 |
+
assert_equal(
|
104 |
+
model,
|
105 |
+
cfg,
|
106 |
+
f"sdxl_{name}_lcm_lora.png",
|
107 |
+
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
108 |
+
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
109 |
+
fx=2,
|
110 |
+
fy=2,
|
111 |
+
)
|
112 |
+
|
113 |
+
cfg = get_config(
|
114 |
+
strategy=strategy,
|
115 |
+
prompt="face of a fox, sitting on a bench",
|
116 |
+
sd_steps=sd_steps,
|
117 |
+
sd_guidance_scale=7.5,
|
118 |
+
sd_freeu=True,
|
119 |
+
sd_freeu_config=FREEUConfig(),
|
120 |
+
)
|
121 |
+
|
122 |
+
assert_equal(
|
123 |
+
model,
|
124 |
+
cfg,
|
125 |
+
f"sdxl_{name}_freeu_device_{device}.png",
|
126 |
+
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
127 |
+
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
128 |
+
fx=2,
|
129 |
+
fy=2,
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
134 |
+
@pytest.mark.parametrize(
|
135 |
+
"rect",
|
136 |
+
[
|
137 |
+
[-128, -128, 1024, 1024],
|
138 |
+
],
|
139 |
+
)
|
140 |
+
def test_sdxl_outpainting(device, rect):
|
141 |
+
sd_steps = check_device(device)
|
142 |
+
|
143 |
+
model = ModelManager(
|
144 |
+
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
145 |
+
device=torch.device(device),
|
146 |
+
disable_nsfw=True,
|
147 |
+
sd_cpu_textencoder=False,
|
148 |
+
)
|
149 |
+
|
150 |
+
cfg = get_config(
|
151 |
+
strategy=HDStrategy.ORIGINAL,
|
152 |
+
prompt="a dog sitting on a bench in the park",
|
153 |
+
sd_steps=sd_steps,
|
154 |
+
use_extender=True,
|
155 |
+
extender_x=rect[0],
|
156 |
+
extender_y=rect[1],
|
157 |
+
extender_width=rect[2],
|
158 |
+
extender_height=rect[3],
|
159 |
+
sd_strength=1.0,
|
160 |
+
sd_guidance_scale=8.0,
|
161 |
+
sd_sampler=SDSampler.ddim,
|
162 |
+
)
|
163 |
+
|
164 |
+
assert_equal(
|
165 |
+
model,
|
166 |
+
cfg,
|
167 |
+
f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}_device_{device}.png",
|
168 |
+
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
169 |
+
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
170 |
+
fx=1.5,
|
171 |
+
fy=1.5,
|
172 |
+
)
|
iopaint/tests/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import cv2
|
3 |
+
import pytest
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from iopaint.helper import encode_pil_to_base64
|
7 |
+
from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
current_dir = Path(__file__).parent.absolute().resolve()
|
11 |
+
save_dir = current_dir / "result"
|
12 |
+
save_dir.mkdir(exist_ok=True, parents=True)
|
13 |
+
|
14 |
+
|
15 |
+
def check_device(device: str) -> int:
|
16 |
+
if device == "cuda" and not torch.cuda.is_available():
|
17 |
+
pytest.skip("CUDA is not available, skip test on cuda")
|
18 |
+
if device == "mps" and not torch.backends.mps.is_available():
|
19 |
+
pytest.skip("mps is not available, skip test on mps")
|
20 |
+
steps = 2 if device == "cpu" else 20
|
21 |
+
return steps
|
22 |
+
|
23 |
+
|
24 |
+
def assert_equal(
|
25 |
+
model,
|
26 |
+
config: InpaintRequest,
|
27 |
+
gt_name,
|
28 |
+
fx: float = 1,
|
29 |
+
fy: float = 1,
|
30 |
+
img_p=current_dir / "image.png",
|
31 |
+
mask_p=current_dir / "mask.png",
|
32 |
+
):
|
33 |
+
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
34 |
+
print(f"Input image shape: {img.shape}")
|
35 |
+
res = model(img, mask, config)
|
36 |
+
ok = cv2.imwrite(
|
37 |
+
str(save_dir / gt_name),
|
38 |
+
res,
|
39 |
+
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
40 |
+
)
|
41 |
+
assert ok, save_dir / gt_name
|
42 |
+
|
43 |
+
"""
|
44 |
+
Note that JPEG is lossy compression, so even if it is the highest quality 100,
|
45 |
+
when the saved images is reloaded, a difference occurs with the original pixel value.
|
46 |
+
If you want to save the original images as it is, save it as PNG or BMP.
|
47 |
+
"""
|
48 |
+
# gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
|
49 |
+
# assert np.array_equal(res, gt)
|
50 |
+
|
51 |
+
|
52 |
+
def get_data(
|
53 |
+
fx: float = 1,
|
54 |
+
fy: float = 1.0,
|
55 |
+
img_p=current_dir / "image.png",
|
56 |
+
mask_p=current_dir / "mask.png",
|
57 |
+
):
|
58 |
+
img = cv2.imread(str(img_p))
|
59 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
60 |
+
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
61 |
+
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
62 |
+
mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
|
63 |
+
return img, mask
|
64 |
+
|
65 |
+
|
66 |
+
def get_config(**kwargs):
|
67 |
+
data = dict(
|
68 |
+
sd_sampler=kwargs.get("sd_sampler", SDSampler.uni_pc),
|
69 |
+
ldm_steps=1,
|
70 |
+
ldm_sampler=LDMSampler.plms,
|
71 |
+
hd_strategy=kwargs.get("strategy", HDStrategy.ORIGINAL),
|
72 |
+
hd_strategy_crop_margin=32,
|
73 |
+
hd_strategy_crop_trigger_size=200,
|
74 |
+
hd_strategy_resize_limit=200,
|
75 |
+
)
|
76 |
+
data.update(**kwargs)
|
77 |
+
return InpaintRequest(image="", mask="", **data)
|
iopaint/web_config.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from iopaint.schema import (
|
6 |
+
Device,
|
7 |
+
InteractiveSegModel,
|
8 |
+
RemoveBGModel,
|
9 |
+
RealESRGANModel,
|
10 |
+
ApiConfig,
|
11 |
+
)
|
12 |
+
|
13 |
+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
14 |
+
|
15 |
+
from datetime import datetime
|
16 |
+
from json import JSONDecodeError
|
17 |
+
|
18 |
+
import gradio as gr
|
19 |
+
from iopaint.download import scan_models
|
20 |
+
from loguru import logger
|
21 |
+
|
22 |
+
from iopaint.const import *
|
23 |
+
|
24 |
+
|
25 |
+
_config_file: Path = None
|
26 |
+
|
27 |
+
|
28 |
+
default_configs = dict(
|
29 |
+
host="127.0.0.1",
|
30 |
+
port=8080,
|
31 |
+
inbrowser=True,
|
32 |
+
model=DEFAULT_MODEL,
|
33 |
+
model_dir=DEFAULT_MODEL_DIR,
|
34 |
+
no_half=False,
|
35 |
+
low_mem=False,
|
36 |
+
cpu_offload=False,
|
37 |
+
disable_nsfw_checker=False,
|
38 |
+
local_files_only=False,
|
39 |
+
cpu_textencoder=False,
|
40 |
+
device=Device.cuda,
|
41 |
+
input=None,
|
42 |
+
output_dir=None,
|
43 |
+
quality=95,
|
44 |
+
enable_interactive_seg=False,
|
45 |
+
interactive_seg_model=InteractiveSegModel.vit_b,
|
46 |
+
interactive_seg_device=Device.cpu,
|
47 |
+
enable_remove_bg=False,
|
48 |
+
remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
|
49 |
+
enable_anime_seg=False,
|
50 |
+
enable_realesrgan=False,
|
51 |
+
realesrgan_device=Device.cpu,
|
52 |
+
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
|
53 |
+
enable_gfpgan=False,
|
54 |
+
gfpgan_device=Device.cpu,
|
55 |
+
enable_restoreformer=False,
|
56 |
+
restoreformer_device=Device.cpu,
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
class WebConfig(ApiConfig):
|
61 |
+
model_dir: str = DEFAULT_MODEL_DIR
|
62 |
+
|
63 |
+
|
64 |
+
def load_config(p: Path) -> WebConfig:
|
65 |
+
if p.exists():
|
66 |
+
with open(p, "r", encoding="utf-8") as f:
|
67 |
+
try:
|
68 |
+
return WebConfig(**{**default_configs, **json.load(f)})
|
69 |
+
except JSONDecodeError:
|
70 |
+
print(f"Load config file failed, using default configs")
|
71 |
+
return WebConfig(**default_configs)
|
72 |
+
else:
|
73 |
+
return WebConfig(**default_configs)
|
74 |
+
|
75 |
+
|
76 |
+
def save_config(
|
77 |
+
host,
|
78 |
+
port,
|
79 |
+
model,
|
80 |
+
model_dir,
|
81 |
+
no_half,
|
82 |
+
low_mem,
|
83 |
+
cpu_offload,
|
84 |
+
disable_nsfw_checker,
|
85 |
+
local_files_only,
|
86 |
+
cpu_textencoder,
|
87 |
+
device,
|
88 |
+
input,
|
89 |
+
output_dir,
|
90 |
+
quality,
|
91 |
+
enable_interactive_seg,
|
92 |
+
interactive_seg_model,
|
93 |
+
interactive_seg_device,
|
94 |
+
enable_remove_bg,
|
95 |
+
remove_bg_model,
|
96 |
+
enable_anime_seg,
|
97 |
+
enable_realesrgan,
|
98 |
+
realesrgan_device,
|
99 |
+
realesrgan_model,
|
100 |
+
enable_gfpgan,
|
101 |
+
gfpgan_device,
|
102 |
+
enable_restoreformer,
|
103 |
+
restoreformer_device,
|
104 |
+
inbrowser,
|
105 |
+
):
|
106 |
+
config = WebConfig(**locals())
|
107 |
+
if str(config.input) == ".":
|
108 |
+
config.input = None
|
109 |
+
if str(config.output_dir) == ".":
|
110 |
+
config.output_dir = None
|
111 |
+
config.model = config.model.strip()
|
112 |
+
print(config.model_dump_json(indent=4))
|
113 |
+
if config.input and not os.path.exists(config.input):
|
114 |
+
return "[Error] Input file or directory does not exist"
|
115 |
+
|
116 |
+
current_time = datetime.now().strftime("%H:%M:%S")
|
117 |
+
msg = f"[{current_time}] Successful save config to: {str(_config_file.absolute())}"
|
118 |
+
logger.info(msg)
|
119 |
+
try:
|
120 |
+
with open(_config_file, "w", encoding="utf-8") as f:
|
121 |
+
f.write(config.model_dump_json(indent=4))
|
122 |
+
except Exception as e:
|
123 |
+
return f"Save configure file failed: {str(e)}"
|
124 |
+
return msg
|
125 |
+
|
126 |
+
|
127 |
+
def change_current_model(new_model):
|
128 |
+
return new_model
|
129 |
+
|
130 |
+
|
131 |
+
def main(config_file: Path):
|
132 |
+
global _config_file
|
133 |
+
_config_file = config_file
|
134 |
+
|
135 |
+
init_config = load_config(config_file)
|
136 |
+
downloaded_models = [it.name for it in scan_models()]
|
137 |
+
|
138 |
+
with gr.Blocks() as demo:
|
139 |
+
with gr.Row():
|
140 |
+
with gr.Column():
|
141 |
+
gr.Textbox(config_file, label="Config file", interactive=False)
|
142 |
+
with gr.Column():
|
143 |
+
save_btn = gr.Button(value="Save configurations")
|
144 |
+
message = gr.HTML()
|
145 |
+
|
146 |
+
with gr.Tabs():
|
147 |
+
with gr.Tab("Common"):
|
148 |
+
with gr.Row():
|
149 |
+
host = gr.Textbox(init_config.host, label="Host")
|
150 |
+
port = gr.Number(init_config.port, label="Port", precision=0)
|
151 |
+
inbrowser = gr.Checkbox(init_config.inbrowser, label=INBROWSER_HELP)
|
152 |
+
|
153 |
+
with gr.Column():
|
154 |
+
model = gr.Textbox(
|
155 |
+
init_config.model,
|
156 |
+
label="Current Model. This is the model that will be used when the service starts. "
|
157 |
+
"If the model has not been downloaded before, it will be automatically downloaded. "
|
158 |
+
"You can select a model from the dropdown box below or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
|
159 |
+
)
|
160 |
+
with gr.Row():
|
161 |
+
recommend_model = gr.Dropdown(
|
162 |
+
["lama", "mat", "migan"] + DIFFUSION_MODELS,
|
163 |
+
label="Recommended Models",
|
164 |
+
)
|
165 |
+
downloaded_model = gr.Dropdown(
|
166 |
+
downloaded_models, label="Downloaded Models"
|
167 |
+
)
|
168 |
+
|
169 |
+
device = gr.Radio(
|
170 |
+
Device.values(), label="Device", value=init_config.device
|
171 |
+
)
|
172 |
+
quality = gr.Slider(
|
173 |
+
value=95,
|
174 |
+
label=f"Image Quality ({QUALITY_HELP})",
|
175 |
+
minimum=75,
|
176 |
+
maximum=100,
|
177 |
+
step=1,
|
178 |
+
)
|
179 |
+
|
180 |
+
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
|
181 |
+
cpu_offload = gr.Checkbox(
|
182 |
+
init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
|
183 |
+
)
|
184 |
+
low_mem = gr.Checkbox(init_config.low_mem, label=f"{LOW_MEM_HELP}")
|
185 |
+
cpu_textencoder = gr.Checkbox(
|
186 |
+
init_config.cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
|
187 |
+
)
|
188 |
+
disable_nsfw_checker = gr.Checkbox(
|
189 |
+
init_config.disable_nsfw_checker, label=f"{DISABLE_NSFW_HELP}"
|
190 |
+
)
|
191 |
+
local_files_only = gr.Checkbox(
|
192 |
+
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
|
193 |
+
)
|
194 |
+
|
195 |
+
with gr.Column():
|
196 |
+
model_dir = gr.Textbox(
|
197 |
+
init_config.model_dir, label=f"{MODEL_DIR_HELP}"
|
198 |
+
)
|
199 |
+
input = gr.Textbox(
|
200 |
+
init_config.input,
|
201 |
+
label=f"Input file or directory. {INPUT_HELP}",
|
202 |
+
)
|
203 |
+
output_dir = gr.Textbox(
|
204 |
+
init_config.output_dir,
|
205 |
+
label=f"Output directory. {OUTPUT_DIR_HELP}",
|
206 |
+
)
|
207 |
+
|
208 |
+
with gr.Tab("Plugins"):
|
209 |
+
with gr.Row():
|
210 |
+
enable_interactive_seg = gr.Checkbox(
|
211 |
+
init_config.enable_interactive_seg, label=INTERACTIVE_SEG_HELP
|
212 |
+
)
|
213 |
+
interactive_seg_model = gr.Radio(
|
214 |
+
InteractiveSegModel.values(),
|
215 |
+
label=f"Segment Anything models. {INTERACTIVE_SEG_MODEL_HELP}",
|
216 |
+
value=init_config.interactive_seg_model,
|
217 |
+
)
|
218 |
+
interactive_seg_device = gr.Radio(
|
219 |
+
Device.values(),
|
220 |
+
label="Segment Anything Device",
|
221 |
+
value=init_config.interactive_seg_device,
|
222 |
+
)
|
223 |
+
with gr.Row():
|
224 |
+
enable_remove_bg = gr.Checkbox(
|
225 |
+
init_config.enable_remove_bg, label=REMOVE_BG_HELP
|
226 |
+
)
|
227 |
+
remove_bg_model = gr.Radio(
|
228 |
+
RemoveBGModel.values(),
|
229 |
+
label="Remove bg model",
|
230 |
+
value=init_config.remove_bg_model,
|
231 |
+
)
|
232 |
+
with gr.Row():
|
233 |
+
enable_anime_seg = gr.Checkbox(
|
234 |
+
init_config.enable_anime_seg, label=ANIMESEG_HELP
|
235 |
+
)
|
236 |
+
|
237 |
+
with gr.Row():
|
238 |
+
enable_realesrgan = gr.Checkbox(
|
239 |
+
init_config.enable_realesrgan, label=REALESRGAN_HELP
|
240 |
+
)
|
241 |
+
realesrgan_device = gr.Radio(
|
242 |
+
Device.values(),
|
243 |
+
label="RealESRGAN Device",
|
244 |
+
value=init_config.realesrgan_device,
|
245 |
+
)
|
246 |
+
realesrgan_model = gr.Radio(
|
247 |
+
RealESRGANModel.values(),
|
248 |
+
label="RealESRGAN model",
|
249 |
+
value=init_config.realesrgan_model,
|
250 |
+
)
|
251 |
+
with gr.Row():
|
252 |
+
enable_gfpgan = gr.Checkbox(
|
253 |
+
init_config.enable_gfpgan, label=GFPGAN_HELP
|
254 |
+
)
|
255 |
+
gfpgan_device = gr.Radio(
|
256 |
+
Device.values(),
|
257 |
+
label="GFPGAN Device",
|
258 |
+
value=init_config.gfpgan_device,
|
259 |
+
)
|
260 |
+
with gr.Row():
|
261 |
+
enable_restoreformer = gr.Checkbox(
|
262 |
+
init_config.enable_restoreformer, label=RESTOREFORMER_HELP
|
263 |
+
)
|
264 |
+
restoreformer_device = gr.Radio(
|
265 |
+
Device.values(),
|
266 |
+
label="RestoreFormer Device",
|
267 |
+
value=init_config.restoreformer_device,
|
268 |
+
)
|
269 |
+
|
270 |
+
downloaded_model.change(change_current_model, [downloaded_model], model)
|
271 |
+
recommend_model.change(change_current_model, [recommend_model], model)
|
272 |
+
|
273 |
+
save_btn.click(
|
274 |
+
save_config,
|
275 |
+
[
|
276 |
+
host,
|
277 |
+
port,
|
278 |
+
model,
|
279 |
+
model_dir,
|
280 |
+
no_half,
|
281 |
+
low_mem,
|
282 |
+
cpu_offload,
|
283 |
+
disable_nsfw_checker,
|
284 |
+
local_files_only,
|
285 |
+
cpu_textencoder,
|
286 |
+
device,
|
287 |
+
input,
|
288 |
+
output_dir,
|
289 |
+
quality,
|
290 |
+
enable_interactive_seg,
|
291 |
+
interactive_seg_model,
|
292 |
+
interactive_seg_device,
|
293 |
+
enable_remove_bg,
|
294 |
+
remove_bg_model,
|
295 |
+
enable_anime_seg,
|
296 |
+
enable_realesrgan,
|
297 |
+
realesrgan_device,
|
298 |
+
realesrgan_model,
|
299 |
+
enable_gfpgan,
|
300 |
+
gfpgan_device,
|
301 |
+
enable_restoreformer,
|
302 |
+
restoreformer_device,
|
303 |
+
inbrowser,
|
304 |
+
],
|
305 |
+
message,
|
306 |
+
)
|
307 |
+
demo.launch(inbrowser=True, show_api=False)
|
pretrained-model/version.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1
|
pretrained-model/version_diffusers_cache.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1
|
utils/tools.py
ADDED
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import yaml
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
def pil_loader(path):
|
11 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
12 |
+
with open(path, 'rb') as f:
|
13 |
+
img = Image.open(f)
|
14 |
+
return img.convert('RGB')
|
15 |
+
|
16 |
+
|
17 |
+
def default_loader(path):
|
18 |
+
return pil_loader(path)
|
19 |
+
|
20 |
+
|
21 |
+
def tensor_img_to_npimg(tensor_img):
|
22 |
+
"""
|
23 |
+
Turn a tensor image with shape CxHxW to a numpy array image with shape HxWxC
|
24 |
+
:param tensor_img:
|
25 |
+
:return: a numpy array image with shape HxWxC
|
26 |
+
"""
|
27 |
+
if not (torch.is_tensor(tensor_img) and tensor_img.ndimension() == 3):
|
28 |
+
raise NotImplementedError("Not supported tensor image. Only tensors with dimension CxHxW are supported.")
|
29 |
+
npimg = np.transpose(tensor_img.numpy(), (1, 2, 0))
|
30 |
+
npimg = npimg.squeeze()
|
31 |
+
assert isinstance(npimg, np.ndarray) and (npimg.ndim in {2, 3})
|
32 |
+
return npimg
|
33 |
+
|
34 |
+
|
35 |
+
# Change the values of tensor x from range [0, 1] to [-1, 1]
|
36 |
+
def normalize(x):
|
37 |
+
return x.mul_(2).add_(-1)
|
38 |
+
|
39 |
+
def same_padding(images, ksizes, strides, rates):
|
40 |
+
assert len(images.size()) == 4
|
41 |
+
batch_size, channel, rows, cols = images.size()
|
42 |
+
out_rows = (rows + strides[0] - 1) // strides[0]
|
43 |
+
out_cols = (cols + strides[1] - 1) // strides[1]
|
44 |
+
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
|
45 |
+
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
|
46 |
+
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
|
47 |
+
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
|
48 |
+
# Pad the input
|
49 |
+
padding_top = int(padding_rows / 2.)
|
50 |
+
padding_left = int(padding_cols / 2.)
|
51 |
+
padding_bottom = padding_rows - padding_top
|
52 |
+
padding_right = padding_cols - padding_left
|
53 |
+
paddings = (padding_left, padding_right, padding_top, padding_bottom)
|
54 |
+
images = torch.nn.ZeroPad2d(paddings)(images)
|
55 |
+
return images
|
56 |
+
|
57 |
+
|
58 |
+
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
|
59 |
+
"""
|
60 |
+
Extract patches from images and put them in the C output dimension.
|
61 |
+
:param padding:
|
62 |
+
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
|
63 |
+
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
|
64 |
+
each dimension of images
|
65 |
+
:param strides: [stride_rows, stride_cols]
|
66 |
+
:param rates: [dilation_rows, dilation_cols]
|
67 |
+
:return: A Tensor
|
68 |
+
"""
|
69 |
+
assert len(images.size()) == 4
|
70 |
+
assert padding in ['same', 'valid']
|
71 |
+
batch_size, channel, height, width = images.size()
|
72 |
+
|
73 |
+
if padding == 'same':
|
74 |
+
images = same_padding(images, ksizes, strides, rates)
|
75 |
+
elif padding == 'valid':
|
76 |
+
pass
|
77 |
+
else:
|
78 |
+
raise NotImplementedError('Unsupported padding type: {}.\
|
79 |
+
Only "same" or "valid" are supported.'.format(padding))
|
80 |
+
|
81 |
+
unfold = torch.nn.Unfold(kernel_size=ksizes,
|
82 |
+
dilation=rates,
|
83 |
+
padding=0,
|
84 |
+
stride=strides)
|
85 |
+
patches = unfold(images)
|
86 |
+
return patches # [N, C*k*k, L], L is the total number of such blocks
|
87 |
+
|
88 |
+
|
89 |
+
def random_bbox(config, batch_size):
|
90 |
+
"""Generate a random tlhw with configuration.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
config: Config should have configuration including img
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
tuple: (top, left, height, width)
|
97 |
+
|
98 |
+
"""
|
99 |
+
img_height, img_width, _ = config['image_shape']
|
100 |
+
h, w = config['mask_shape']
|
101 |
+
margin_height, margin_width = config['margin']
|
102 |
+
maxt = img_height - margin_height - h
|
103 |
+
maxl = img_width - margin_width - w
|
104 |
+
bbox_list = []
|
105 |
+
if config['mask_batch_same']:
|
106 |
+
t = np.random.randint(margin_height, maxt)
|
107 |
+
l = np.random.randint(margin_width, maxl)
|
108 |
+
bbox_list.append((t, l, h, w))
|
109 |
+
bbox_list = bbox_list * batch_size
|
110 |
+
else:
|
111 |
+
for i in range(batch_size):
|
112 |
+
t = np.random.randint(margin_height, maxt)
|
113 |
+
l = np.random.randint(margin_width, maxl)
|
114 |
+
bbox_list.append((t, l, h, w))
|
115 |
+
|
116 |
+
return torch.tensor(bbox_list, dtype=torch.int64)
|
117 |
+
|
118 |
+
|
119 |
+
def test_random_bbox():
|
120 |
+
image_shape = [256, 256, 3]
|
121 |
+
mask_shape = [128, 128]
|
122 |
+
margin = [0, 0]
|
123 |
+
bbox = random_bbox(image_shape)
|
124 |
+
return bbox
|
125 |
+
|
126 |
+
|
127 |
+
def bbox2mask(bboxes, height, width, max_delta_h, max_delta_w):
|
128 |
+
batch_size = bboxes.size(0)
|
129 |
+
mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32)
|
130 |
+
for i in range(batch_size):
|
131 |
+
bbox = bboxes[i]
|
132 |
+
delta_h = np.random.randint(max_delta_h // 2 + 1)
|
133 |
+
delta_w = np.random.randint(max_delta_w // 2 + 1)
|
134 |
+
mask[i, :, bbox[0] + delta_h:bbox[0] + bbox[2] - delta_h, bbox[1] + delta_w:bbox[1] + bbox[3] - delta_w] = 1.
|
135 |
+
return mask
|
136 |
+
|
137 |
+
|
138 |
+
def test_bbox2mask():
|
139 |
+
image_shape = [256, 256, 3]
|
140 |
+
mask_shape = [128, 128]
|
141 |
+
margin = [0, 0]
|
142 |
+
max_delta_shape = [32, 32]
|
143 |
+
bbox = random_bbox(image_shape)
|
144 |
+
mask = bbox2mask(bbox, image_shape[0], image_shape[1], max_delta_shape[0], max_delta_shape[1])
|
145 |
+
return mask
|
146 |
+
|
147 |
+
|
148 |
+
def local_patch(x, bbox_list):
|
149 |
+
assert len(x.size()) == 4
|
150 |
+
patches = []
|
151 |
+
for i, bbox in enumerate(bbox_list):
|
152 |
+
t, l, h, w = bbox
|
153 |
+
patches.append(x[i, :, t:t + h, l:l + w])
|
154 |
+
return torch.stack(patches, dim=0)
|
155 |
+
|
156 |
+
|
157 |
+
def mask_image(x, bboxes, config):
|
158 |
+
height, width, _ = config['image_shape']
|
159 |
+
max_delta_h, max_delta_w = config['max_delta_shape']
|
160 |
+
mask = bbox2mask(bboxes, height, width, max_delta_h, max_delta_w)
|
161 |
+
if x.is_cuda:
|
162 |
+
mask = mask.cuda()
|
163 |
+
|
164 |
+
if config['mask_type'] == 'hole':
|
165 |
+
result = x * (1. - mask)
|
166 |
+
elif config['mask_type'] == 'mosaic':
|
167 |
+
# TODO: Matching the mosaic patch size and the mask size
|
168 |
+
mosaic_unit_size = config['mosaic_unit_size']
|
169 |
+
downsampled_image = F.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest')
|
170 |
+
upsampled_image = F.interpolate(downsampled_image, size=(height, width), mode='nearest')
|
171 |
+
result = upsampled_image * mask + x * (1. - mask)
|
172 |
+
else:
|
173 |
+
raise NotImplementedError('Not implemented mask type.')
|
174 |
+
|
175 |
+
return result, mask
|
176 |
+
|
177 |
+
|
178 |
+
def spatial_discounting_mask(config):
|
179 |
+
"""Generate spatial discounting mask constant.
|
180 |
+
|
181 |
+
Spatial discounting mask is first introduced in publication:
|
182 |
+
Generative Image Inpainting with Contextual Attention, Yu et al.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
config: Config should have configuration including HEIGHT, WIDTH,
|
186 |
+
DISCOUNTED_MASK.
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
tf.Tensor: spatial discounting mask
|
190 |
+
|
191 |
+
"""
|
192 |
+
gamma = config['spatial_discounting_gamma']
|
193 |
+
height, width = config['mask_shape']
|
194 |
+
shape = [1, 1, height, width]
|
195 |
+
if config['discounted_mask']:
|
196 |
+
mask_values = np.ones((height, width))
|
197 |
+
for i in range(height):
|
198 |
+
for j in range(width):
|
199 |
+
mask_values[i, j] = max(
|
200 |
+
gamma ** min(i, height - i),
|
201 |
+
gamma ** min(j, width - j))
|
202 |
+
mask_values = np.expand_dims(mask_values, 0)
|
203 |
+
mask_values = np.expand_dims(mask_values, 0)
|
204 |
+
else:
|
205 |
+
mask_values = np.ones(shape)
|
206 |
+
spatial_discounting_mask_tensor = torch.tensor(mask_values, dtype=torch.float32)
|
207 |
+
if config['cuda']:
|
208 |
+
spatial_discounting_mask_tensor = spatial_discounting_mask_tensor.cuda()
|
209 |
+
return spatial_discounting_mask_tensor
|
210 |
+
|
211 |
+
|
212 |
+
def reduce_mean(x, axis=None, keepdim=False):
|
213 |
+
if not axis:
|
214 |
+
axis = range(len(x.shape))
|
215 |
+
for i in sorted(axis, reverse=True):
|
216 |
+
x = torch.mean(x, dim=i, keepdim=keepdim)
|
217 |
+
return x
|
218 |
+
|
219 |
+
|
220 |
+
def reduce_std(x, axis=None, keepdim=False):
|
221 |
+
if not axis:
|
222 |
+
axis = range(len(x.shape))
|
223 |
+
for i in sorted(axis, reverse=True):
|
224 |
+
x = torch.std(x, dim=i, keepdim=keepdim)
|
225 |
+
return x
|
226 |
+
|
227 |
+
|
228 |
+
def reduce_sum(x, axis=None, keepdim=False):
|
229 |
+
if not axis:
|
230 |
+
axis = range(len(x.shape))
|
231 |
+
for i in sorted(axis, reverse=True):
|
232 |
+
x = torch.sum(x, dim=i, keepdim=keepdim)
|
233 |
+
return x
|
234 |
+
|
235 |
+
|
236 |
+
def flow_to_image(flow):
|
237 |
+
"""Transfer flow map to image.
|
238 |
+
Part of code forked from flownet.
|
239 |
+
"""
|
240 |
+
out = []
|
241 |
+
maxu = -999.
|
242 |
+
maxv = -999.
|
243 |
+
minu = 999.
|
244 |
+
minv = 999.
|
245 |
+
maxrad = -1
|
246 |
+
for i in range(flow.shape[0]):
|
247 |
+
u = flow[i, :, :, 0]
|
248 |
+
v = flow[i, :, :, 1]
|
249 |
+
idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
|
250 |
+
u[idxunknow] = 0
|
251 |
+
v[idxunknow] = 0
|
252 |
+
maxu = max(maxu, np.max(u))
|
253 |
+
minu = min(minu, np.min(u))
|
254 |
+
maxv = max(maxv, np.max(v))
|
255 |
+
minv = min(minv, np.min(v))
|
256 |
+
rad = np.sqrt(u ** 2 + v ** 2)
|
257 |
+
maxrad = max(maxrad, np.max(rad))
|
258 |
+
u = u / (maxrad + np.finfo(float).eps)
|
259 |
+
v = v / (maxrad + np.finfo(float).eps)
|
260 |
+
img = compute_color(u, v)
|
261 |
+
out.append(img)
|
262 |
+
return np.float32(np.uint8(out))
|
263 |
+
|
264 |
+
|
265 |
+
def pt_flow_to_image(flow):
|
266 |
+
"""Transfer flow map to image.
|
267 |
+
Part of code forked from flownet.
|
268 |
+
"""
|
269 |
+
out = []
|
270 |
+
maxu = torch.tensor(-999)
|
271 |
+
maxv = torch.tensor(-999)
|
272 |
+
minu = torch.tensor(999)
|
273 |
+
minv = torch.tensor(999)
|
274 |
+
maxrad = torch.tensor(-1)
|
275 |
+
if torch.cuda.is_available():
|
276 |
+
maxu = maxu.cuda()
|
277 |
+
maxv = maxv.cuda()
|
278 |
+
minu = minu.cuda()
|
279 |
+
minv = minv.cuda()
|
280 |
+
maxrad = maxrad.cuda()
|
281 |
+
for i in range(flow.shape[0]):
|
282 |
+
u = flow[i, 0, :, :]
|
283 |
+
v = flow[i, 1, :, :]
|
284 |
+
idxunknow = (torch.abs(u) > 1e7) + (torch.abs(v) > 1e7)
|
285 |
+
u[idxunknow] = 0
|
286 |
+
v[idxunknow] = 0
|
287 |
+
maxu = torch.max(maxu, torch.max(u))
|
288 |
+
minu = torch.min(minu, torch.min(u))
|
289 |
+
maxv = torch.max(maxv, torch.max(v))
|
290 |
+
minv = torch.min(minv, torch.min(v))
|
291 |
+
rad = torch.sqrt((u ** 2 + v ** 2).float()).to(torch.int64)
|
292 |
+
maxrad = torch.max(maxrad, torch.max(rad))
|
293 |
+
u = u / (maxrad + torch.finfo(torch.float32).eps)
|
294 |
+
v = v / (maxrad + torch.finfo(torch.float32).eps)
|
295 |
+
# TODO: change the following to pytorch
|
296 |
+
img = pt_compute_color(u, v)
|
297 |
+
out.append(img)
|
298 |
+
|
299 |
+
return torch.stack(out, dim=0)
|
300 |
+
|
301 |
+
|
302 |
+
def highlight_flow(flow):
|
303 |
+
"""Convert flow into middlebury color code image.
|
304 |
+
"""
|
305 |
+
out = []
|
306 |
+
s = flow.shape
|
307 |
+
for i in range(flow.shape[0]):
|
308 |
+
img = np.ones((s[1], s[2], 3)) * 144.
|
309 |
+
u = flow[i, :, :, 0]
|
310 |
+
v = flow[i, :, :, 1]
|
311 |
+
for h in range(s[1]):
|
312 |
+
for w in range(s[1]):
|
313 |
+
ui = u[h, w]
|
314 |
+
vi = v[h, w]
|
315 |
+
img[ui, vi, :] = 255.
|
316 |
+
out.append(img)
|
317 |
+
return np.float32(np.uint8(out))
|
318 |
+
|
319 |
+
|
320 |
+
def pt_highlight_flow(flow):
|
321 |
+
"""Convert flow into middlebury color code image.
|
322 |
+
"""
|
323 |
+
out = []
|
324 |
+
s = flow.shape
|
325 |
+
for i in range(flow.shape[0]):
|
326 |
+
img = np.ones((s[1], s[2], 3)) * 144.
|
327 |
+
u = flow[i, :, :, 0]
|
328 |
+
v = flow[i, :, :, 1]
|
329 |
+
for h in range(s[1]):
|
330 |
+
for w in range(s[1]):
|
331 |
+
ui = u[h, w]
|
332 |
+
vi = v[h, w]
|
333 |
+
img[ui, vi, :] = 255.
|
334 |
+
out.append(img)
|
335 |
+
return np.float32(np.uint8(out))
|
336 |
+
|
337 |
+
|
338 |
+
def compute_color(u, v):
|
339 |
+
h, w = u.shape
|
340 |
+
img = np.zeros([h, w, 3])
|
341 |
+
nanIdx = np.isnan(u) | np.isnan(v)
|
342 |
+
u[nanIdx] = 0
|
343 |
+
v[nanIdx] = 0
|
344 |
+
# colorwheel = COLORWHEEL
|
345 |
+
colorwheel = make_color_wheel()
|
346 |
+
ncols = np.size(colorwheel, 0)
|
347 |
+
rad = np.sqrt(u ** 2 + v ** 2)
|
348 |
+
a = np.arctan2(-v, -u) / np.pi
|
349 |
+
fk = (a + 1) / 2 * (ncols - 1) + 1
|
350 |
+
k0 = np.floor(fk).astype(int)
|
351 |
+
k1 = k0 + 1
|
352 |
+
k1[k1 == ncols + 1] = 1
|
353 |
+
f = fk - k0
|
354 |
+
for i in range(np.size(colorwheel, 1)):
|
355 |
+
tmp = colorwheel[:, i]
|
356 |
+
col0 = tmp[k0 - 1] / 255
|
357 |
+
col1 = tmp[k1 - 1] / 255
|
358 |
+
col = (1 - f) * col0 + f * col1
|
359 |
+
idx = rad <= 1
|
360 |
+
col[idx] = 1 - rad[idx] * (1 - col[idx])
|
361 |
+
notidx = np.logical_not(idx)
|
362 |
+
col[notidx] *= 0.75
|
363 |
+
img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
|
364 |
+
return img
|
365 |
+
|
366 |
+
|
367 |
+
def pt_compute_color(u, v):
|
368 |
+
h, w = u.shape
|
369 |
+
img = torch.zeros([3, h, w])
|
370 |
+
if torch.cuda.is_available():
|
371 |
+
img = img.cuda()
|
372 |
+
nanIdx = (torch.isnan(u) + torch.isnan(v)) != 0
|
373 |
+
u[nanIdx] = 0.
|
374 |
+
v[nanIdx] = 0.
|
375 |
+
# colorwheel = COLORWHEEL
|
376 |
+
colorwheel = pt_make_color_wheel()
|
377 |
+
if torch.cuda.is_available():
|
378 |
+
colorwheel = colorwheel.cuda()
|
379 |
+
ncols = colorwheel.size()[0]
|
380 |
+
rad = torch.sqrt((u ** 2 + v ** 2).to(torch.float32))
|
381 |
+
a = torch.atan2(-v.to(torch.float32), -u.to(torch.float32)) / np.pi
|
382 |
+
fk = (a + 1) / 2 * (ncols - 1) + 1
|
383 |
+
k0 = torch.floor(fk).to(torch.int64)
|
384 |
+
k1 = k0 + 1
|
385 |
+
k1[k1 == ncols + 1] = 1
|
386 |
+
f = fk - k0.to(torch.float32)
|
387 |
+
for i in range(colorwheel.size()[1]):
|
388 |
+
tmp = colorwheel[:, i]
|
389 |
+
col0 = tmp[k0 - 1]
|
390 |
+
col1 = tmp[k1 - 1]
|
391 |
+
col = (1 - f) * col0 + f * col1
|
392 |
+
idx = rad <= 1. / 255.
|
393 |
+
col[idx] = 1 - rad[idx] * (1 - col[idx])
|
394 |
+
notidx = (idx != 0)
|
395 |
+
col[notidx] *= 0.75
|
396 |
+
img[i, :, :] = col * (1 - nanIdx).to(torch.float32)
|
397 |
+
return img
|
398 |
+
|
399 |
+
|
400 |
+
def make_color_wheel():
|
401 |
+
RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
|
402 |
+
ncols = RY + YG + GC + CB + BM + MR
|
403 |
+
colorwheel = np.zeros([ncols, 3])
|
404 |
+
col = 0
|
405 |
+
# RY
|
406 |
+
colorwheel[0:RY, 0] = 255
|
407 |
+
colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
|
408 |
+
col += RY
|
409 |
+
# YG
|
410 |
+
colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
|
411 |
+
colorwheel[col:col + YG, 1] = 255
|
412 |
+
col += YG
|
413 |
+
# GC
|
414 |
+
colorwheel[col:col + GC, 1] = 255
|
415 |
+
colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
|
416 |
+
col += GC
|
417 |
+
# CB
|
418 |
+
colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
|
419 |
+
colorwheel[col:col + CB, 2] = 255
|
420 |
+
col += CB
|
421 |
+
# BM
|
422 |
+
colorwheel[col:col + BM, 2] = 255
|
423 |
+
colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
|
424 |
+
col += + BM
|
425 |
+
# MR
|
426 |
+
colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
|
427 |
+
colorwheel[col:col + MR, 0] = 255
|
428 |
+
return colorwheel
|
429 |
+
|
430 |
+
|
431 |
+
def pt_make_color_wheel():
|
432 |
+
RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
|
433 |
+
ncols = RY + YG + GC + CB + BM + MR
|
434 |
+
colorwheel = torch.zeros([ncols, 3])
|
435 |
+
col = 0
|
436 |
+
# RY
|
437 |
+
colorwheel[0:RY, 0] = 1.
|
438 |
+
colorwheel[0:RY, 1] = torch.arange(0, RY, dtype=torch.float32) / RY
|
439 |
+
col += RY
|
440 |
+
# YG
|
441 |
+
colorwheel[col:col + YG, 0] = 1. - (torch.arange(0, YG, dtype=torch.float32) / YG)
|
442 |
+
colorwheel[col:col + YG, 1] = 1.
|
443 |
+
col += YG
|
444 |
+
# GC
|
445 |
+
colorwheel[col:col + GC, 1] = 1.
|
446 |
+
colorwheel[col:col + GC, 2] = torch.arange(0, GC, dtype=torch.float32) / GC
|
447 |
+
col += GC
|
448 |
+
# CB
|
449 |
+
colorwheel[col:col + CB, 1] = 1. - (torch.arange(0, CB, dtype=torch.float32) / CB)
|
450 |
+
colorwheel[col:col + CB, 2] = 1.
|
451 |
+
col += CB
|
452 |
+
# BM
|
453 |
+
colorwheel[col:col + BM, 2] = 1.
|
454 |
+
colorwheel[col:col + BM, 0] = torch.arange(0, BM, dtype=torch.float32) / BM
|
455 |
+
col += BM
|
456 |
+
# MR
|
457 |
+
colorwheel[col:col + MR, 2] = 1. - (torch.arange(0, MR, dtype=torch.float32) / MR)
|
458 |
+
colorwheel[col:col + MR, 0] = 1.
|
459 |
+
return colorwheel
|
460 |
+
|
461 |
+
|
462 |
+
def is_image_file(filename):
|
463 |
+
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
|
464 |
+
filename_lower = filename.lower()
|
465 |
+
return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS)
|
466 |
+
|
467 |
+
|
468 |
+
def deprocess(img):
|
469 |
+
img = img.add_(1).div_(2)
|
470 |
+
return img
|
471 |
+
|
472 |
+
|
473 |
+
# get configs
|
474 |
+
def get_config(config):
|
475 |
+
with open(config, 'r') as stream:
|
476 |
+
return yaml.load(stream,Loader=yaml.Loader)
|
477 |
+
|
478 |
+
|
479 |
+
# Get model list for resume
|
480 |
+
def get_model_list(dirname, key, iteration=0):
|
481 |
+
if os.path.exists(dirname) is False:
|
482 |
+
return None
|
483 |
+
gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
|
484 |
+
os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
|
485 |
+
if gen_models is None:
|
486 |
+
return None
|
487 |
+
gen_models.sort()
|
488 |
+
if iteration == 0:
|
489 |
+
last_model_name = gen_models[-1]
|
490 |
+
else:
|
491 |
+
for model_name in gen_models:
|
492 |
+
if '{:0>8d}'.format(iteration) in model_name:
|
493 |
+
return model_name
|
494 |
+
raise ValueError('Not found models with this iteration')
|
495 |
+
return last_model_name
|
496 |
+
|
497 |
+
|
498 |
+
if __name__ == '__main__':
|
499 |
+
test_random_bbox()
|
500 |
+
mask = test_bbox2mask()
|
501 |
+
print(mask.shape)
|
502 |
+
import matplotlib.pyplot as plt
|
503 |
+
|
504 |
+
plt.imshow(mask, cmap='gray')
|
505 |
+
plt.show()
|