eggarsway's picture
init
85456ff
raw
history blame
No virus
6.38 kB
"""
"""
import os

import einops
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as nnf
import torchvision
from PIL import Image, ImageFont, ImageDraw
# Make every figure created by this module 1.5x matplotlib's configured
# default size (applied once, at import time).
plt.rcParams["figure.figsize"] = [
    1.5 * float(side) for side in plt.rcParams["figure.figsize"]
]
class CrossAttnPainter:
    """Write decoded frames and cross-attention heat maps to disk as JPEGs.

    The attention maps are read from ``pipe.unet``; the attention processors
    must already expose them as ``cross_attention_map`` (presumably filled in
    by an earlier forward pass -- that code is not part of this class).
    """

    def __init__(self, bundle, pipe, root="/tmp"):
        """Remember the pipeline and the output location.

        Args:
            bundle: Stored unchanged on the instance; no method here reads it.
            pipe: Pipeline object whose ``unet`` holds the attention layers.
            root: Directory that receives every file written by this class.
        """
        self.bundle = bundle
        self.pipe = pipe
        # Side length (in pixels) attention maps are resampled to for plotting.
        self.dim = 64
        self.folder = root

    def _output_path(self, filename):
        """Return ``filename`` inside the output folder, creating the folder."""
        os.makedirs(self.folder, exist_ok=True)
        return os.path.join(self.folder, filename)

    def _save_heatmap(self, heatmap, filename):
        """Render a 2-D array with the "jet" colormap and save it borderless."""
        filepath = self._output_path(filename)
        plt.clf()
        plt.imshow(heatmap, cmap="jet")
        axes = plt.gca()
        # Strip axes, ticks and margins so the file contains only the map.
        axes.set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        axes.xaxis.set_major_locator(plt.NullLocator())
        axes.yaxis.set_major_locator(plt.NullLocator())
        plt.savefig(filepath, bbox_inches="tight", pad_inches=0)
        print(filepath)

    def plot_frames(self, video_frames=None):
        """Save every frame as ``recons.NNNN.jpg`` in the output folder.

        Args:
            video_frames: Iterable of arrays accepted by ``Image.fromarray``.

        Raises:
            ValueError: If no frames are passed (the method previously relied
                on an undefined global of the same name).
        """
        if video_frames is None:
            raise ValueError(
                "plot_frames() needs the frames to save; pass video_frames=..."
            )
        for i, frame in enumerate(video_frames):
            Image.fromarray(frame).save(self._output_path(f"recons.{i:04d}.jpg"))

    def plot_spatial_attn(self, token_index=2):
        """Save one head-averaged spatial cross-attention map per batch entry.

        Reads the map stored on ``up_blocks[1].attentions[0]`` and writes
        ``spatialca.NNNN.jpg`` files.

        Args:
            token_index: Index along the last (token) axis of the map to plot.
        """
        attn = self.pipe.unet.up_blocks[1].attentions[0].transformer_blocks[0].attn2
        arr = attn.processor.cross_attention_map
        heads = attn.heads
        # interpolate() resizes the trailing two axes, so swap axes 1 and 3
        # around the call and restore the layout afterwards.
        arr = torch.transpose(arr, 1, 3)
        arr = nnf.interpolate(
            arr, size=(self.dim, self.dim), mode="bicubic", align_corners=False
        )
        arr = torch.transpose(arr, 1, 3)
        arr = arr.detach().cpu().numpy()
        # The leading axis is assumed to be (frames * heads) -- TODO confirm.
        # Frame and token counts are inferred instead of fixed at 24 and 77.
        arr = arr.reshape(-1, heads, self.dim, self.dim, arr.shape[-1])
        arr = arr.mean(axis=1)
        for i, frame_map in enumerate(arr):
            self._save_heatmap(frame_map[:, :, token_index], f"spatialca.{i:04d}.jpg")

    def plot_temporal_attn(self):
        """Save temporal attention maps as ``tempcaiip2.NNNN.jpg`` files.

        Reads the map stored on ``up_blocks[2].temp_attentions[1]``. File ``i``
        shows the slice ``[..., i + 2, i]`` of the resampled map, i.e. the two
        trailing indices are always two apart.
        """
        # NOTE(review): earlier revisions read the same attribute from
        # mid_block.temp_attentions[0] or transformer_in instead.
        arr = (
            self.pipe.unet.up_blocks[2]
            .temp_attentions[1]
            .transformer_blocks[0]
            .attn2.processor.cross_attention_map
        )
        # Move the first two axes to the end so interpolate() resizes them,
        # then move them back to the front.
        arr = torch.transpose(arr, 0, 2).transpose(1, 3)
        arr = nnf.interpolate(
            arr, size=(self.dim, self.dim), mode="bicubic", align_corners=False
        )
        arr = torch.transpose(arr, 0, 2).transpose(1, 3)
        arr = arr.detach().cpu().numpy()
        n = arr.shape[-1]
        for i in range(n - 2):
            self._save_heatmap(arr[..., i + 2, i], f"tempcaiip2.{i:04d}.jpg")
def plot_latent_noise(latents, mode, out_dir="/tmp"):
    """Save each latent as a min-max normalised image plus a value histogram.

    For every index ``i`` along the first axis of ``latents`` two files are
    written to ``out_dir``: ``out.{i:04d}.jpg`` (the latent rescaled to 0..255
    and enlarged to 256x256 with nearest-neighbour sampling) and
    ``hist.{i:04d}.jpg`` (a 20-bin density histogram of the values divided by
    their maximum, with a Gaussian KDE curve on top).

    Args:
        latents: Tensor indexed along its first axis; every item is treated as
            a channels-first 3-D tensor.
        mode: ``"RGB"`` (first three channels), ``"RGBA"`` (all channels, so
            four are expected) or ``"GRAY"`` (first channel only).
        out_dir: Directory that receives the files.

    Raises:
        ValueError: If ``mode`` is not one of the three supported names.
    """
    pil_modes = {"RGB": "RGB", "RGBA": "RGBA", "GRAY": "L"}
    if mode not in pil_modes:
        raise ValueError(
            f"unsupported mode {mode!r}; expected one of {sorted(pil_modes)}"
        )
    pil_mode = pil_modes[mode]

    for i in range(latents.shape[0]):
        latent = latents[i].detach().cpu().float()

        # Min-max normalise to 0..255. The scale is 255 / span; multiplying by
        # the span, as before, overshoots the 8-bit range.
        low = torch.min(latent)
        span = torch.max(latent) - low
        if span > 0:
            pixels = (latent - low) * (255.0 / span)
        else:
            # Constant latent: avoid dividing by zero, render it as black.
            pixels = torch.zeros_like(latent)
        # uint8 is the 8-bit image dtype; int8 wraps for values above 127.
        pixels = pixels.clamp(0, 255).to(torch.uint8)
        pixels = pixels.permute(1, 2, 0)  # channels-first -> channels-last
        if mode == "RGB":
            pixels = pixels[..., :3]
        elif mode == "GRAY":
            pixels = pixels[..., 0]

        img = Image.fromarray(pixels.contiguous().numpy(), pil_mode)
        img = img.resize((256, 256), resample=Image.NEAREST)
        # NOTE(review): JPEG has no alpha channel, so saving "RGBA" output as
        # .jpg is expected to fail inside Pillow -- confirm and pick a format.
        img.save(os.path.join(out_dir, f"out.{i:04d}.jpg"))

        values = latent.flatten().numpy()
        peak = values.max()
        if peak != 0:
            # Out-of-place on purpose: ``values`` may share memory with the
            # caller's tensor, and ``/=`` would overwrite it.
            values = values / peak
        plt.hist(values, density=True, bins=20, range=[-1, 1])
        mn, mx = plt.xlim()
        plt.xlim(mn, mx)  # pin the limits so the KDE curve cannot rescale them
        if np.ptp(values) > 0:
            # A KDE needs some spread in the data; skip it for constant input.
            kde_xs = np.linspace(mn, mx, 300)
            kde = st.gaussian_kde(values)
            plt.plot(kde_xs, kde.pdf(kde_xs), label="PDF")
        plt.savefig(os.path.join(out_dir, f"hist.{i:04d}.jpg"))
        plt.clf()
        print(i)
def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False, n_trailing=2):
    """Tile per-token slices of ``cross_attn`` into one heat map.

    The prompt is split on single spaces to get a word count ``n``; slices are
    taken along the last axis of ``cross_attn`` starting at index 1 (index 0 is
    skipped).

    Args:
        cross_attn: Array sliced as ``cross_attn[..., k]``; each slice must
            work with ``np.hstack`` / ``np.vstack``.
        prompt: Text used for the word count and as the figure title.
        filepath: Where to save the figure; when empty it is shown instead.
        plot_with_trailings: If true, lay out ``n_trailing`` consecutive groups
            of ``n`` slices (indices ``1..n``, ``n+1..2n``, ...); otherwise
            stack only the first ``n`` slices vertically.
        n_trailing: Number of groups used when ``plot_with_trailings`` is set.
    """
    word_count = len(prompt.split(" "))

    if plot_with_trailings:
        group_rows = []
        for group in range(n_trailing):
            first = group * word_count
            tiles = [
                cross_attn[..., token + 1].T
                for token in range(first, first + word_count)
            ]
            group_rows.append(np.hstack(tiles))
        grid = np.vstack(group_rows).T
    else:
        grid = np.vstack(
            [cross_attn[..., token + 1] for token in range(word_count)]
        )

    plt.imshow(grid, cmap="jet", vmin=0.0, vmax=.5)
    plt.title(prompt)
    if not filepath:
        plt.show()
    else:
        plt.savefig(filepath)
def draw_dd_metadata(img, bbox, text="", target_res=1024):
    """Return a resized copy of ``img`` with a red box and optional caption.

    Args:
        img: PIL image; it is resized to ``target_res`` x ``target_res`` and
            the drawing happens on the resized copy.
        bbox: Indexable with four entries ordered ``(x0, x1, y0, y1)``; each
            entry is multiplied by ``target_res`` to get pixel coordinates.
        text: Optional (multi-line) caption drawn near the top-left corner.
        target_res: Side length of the returned square image, in pixels.

    Returns:
        The resized, annotated image.
    """
    canvas = img.resize((target_res, target_res))
    pen = ImageDraw.Draw(canvas)

    left, right, top, bottom = (bbox[k] * target_res for k in range(4))
    pen.rectangle(xy=[left, top, right, bottom], outline=(255, 0, 0, 255), width=5)

    if not text:
        return canvas

    caption_font = ImageFont.truetype("./assets/JetBrainsMono-Bold.ttf", size=13)
    # White glyphs with a black stroke for contrast against any background.
    pen.multiline_text(
        (15, 15),
        text,
        (255, 255, 255, 0),
        font=caption_font,
        stroke_width=2,
        stroke_fill=(0, 0, 0, 255),
        spacing=0,
    )
    return canvas
if __name__ == "__main__":
    # Ad-hoc run: load the latents saved by one experiment and dump them as
    # grayscale images plus histograms.
    latents = torch.load(
        "assets/experiments/a-cat-sitting-on-a-car_230615-144611/latents.pt"
    )
    plot_latent_noise(latents, mode="GRAY")