""" """ import torch import matplotlib.pyplot as plt import numpy as np import torch.nn.functional as nnf import torchvision import einops import matplotlib.pyplot as plt import scipy.stats as st from PIL import Image, ImageFont, ImageDraw plt.rcParams["figure.figsize"] = [ float(v) * 1.5 for v in plt.rcParams["figure.figsize"] ] class CrossAttnPainter: def __init__(self, bundle, pipe, root="/tmp"): self.dim = 64 self.folder = def plot_frames(self): folder = "/tmp" from PIL import Image for i, f in enumerate(video_frames): img = Image.fromarray(f) filepath = os.path.join(folder, "recons.{:04d}.jpg".format(i)) img.save(filepath) def plot_spatial_attn(self): arr = ( pipe.unet.up_blocks[1] .attentions[0] .transformer_blocks[0] .attn2.processor.cross_attention_map ) heads = pipe.unet.up_blocks[1].attentions[0].transformer_blocks[0].attn2.heads arr = torch.transpose(arr, 1, 3) arr = nnf.interpolate(arr, size=(64, 64), mode='bicubic', align_corners=False) arr = torch.transpose(arr, 1, 3) arr = arr.cpu().numpy() arr = arr.reshape(24, heads, 64, 64, 77) arr = arr.mean(axis=1) n = arr.shape[0] for i in range(n): filename = "/tmp/spatialca.{:04d}.jpg".format(i) plt.clf() plt.imshow(arr[i, :, :, 2], cmap="jet") plt.gca().set_axis_off() plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0) plt.margins(0,0) plt.gca().xaxis.set_major_locator(plt.NullLocator()) plt.gca().yaxis.set_major_locator(plt.NullLocator()) plt.savefig(filename, bbox_inches = 'tight',pad_inches = 0) print(filename) def plot_temporal_attn(self): # arr = pipe.unet.mid_block.temp_attentions[0].transformer_blocks[0].attn2.processor.cross_attention_map import matplotlib.pyplot as plt import torch.nn.functional as nnf arr = ( pipe.unet.up_blocks[2] .temp_attentions[1] .transformer_blocks[0] .attn2.processor.cross_attention_map ) #arr = pipe.unet.transformer_in.transformer_blocks[0].attn2.processor.cross_attention_map arr = torch.transpose(arr, 0, 2).transpose(1, 3) arr = nnf.interpolate(arr, size=(64, 64), mode="bicubic", align_corners=False) arr = torch.transpose(arr, 0, 2).transpose(1, 3) arr = arr.cpu().numpy() n = arr.shape[-1] for i in range(n-2): filename = "/tmp/tempcaiip2.{:04d}.jpg".format(i) plt.clf() plt.imshow(arr[..., i+2, i], cmap="jet") plt.gca().set_axis_off() plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) plt.margins(0, 0) plt.gca().xaxis.set_major_locator(plt.NullLocator()) plt.gca().yaxis.set_major_locator(plt.NullLocator()) plt.savefig(filename, bbox_inches="tight", pad_inches=0) print(filename) def plot_latent_noise(latents, mode): for i in range(latents.shape[0]): tensor = latents[i].cpu() min_val = torch.min(tensor) max_val = torch.max(tensor) scale = 255 * (max_val - min_val) tensor = scale * (tensor - min_val) tensor = tensor.type(torch.int8) tensor = einops.rearrange(tensor, "c w h -> w h c") if mode == "RGB": tensor = tensor[...,:3] mode_ = "RGB" elif mode == "RGBA": mode_ = "RGBA" pass elif mode == "GRAY": tensor = tensor[...,0] mode_ = "L" x = tensor.numpy() img = Image.fromarray(x, mode_) img = img.resize((256, 256), resample=Image.NEAREST ) filepath = f"/tmp/out.{i:04d}.jpg" img.save(filepath) tensor = latents[i].cpu() x = tensor.flatten().numpy() x /= x.max() plt.hist(x, density=True, bins=20, range=[-1, 1]) mn, mx = plt.xlim() plt.xlim(mn, mx) kde_xs = np.linspace(mn, mx, 300) kde = st.gaussian_kde(x) plt.plot(kde_xs, kde.pdf(kde_xs), label="PDF") filepath = f"/tmp/hist.{i:04d}.jpg" plt.savefig(filepath) plt.clf() print(i) def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False, n_trailing=2): splitted_prompt = prompt.split(" ") n = len(splitted_prompt) start = 0 arrs = [] if plot_with_trailings: for j in range(n_trailing): arr = [] for i in range(start, start + n): cross_attn_sliced = cross_attn[..., i + 1] arr.append(cross_attn_sliced.T) start += n arr = np.hstack(arr) arrs.append(arr) arrs = np.vstack(arrs).T else: arr = [] for i in range(start, start + n): cross_attn_sliced = cross_attn[..., i + 1] arr.append(cross_attn_sliced) arrs = np.vstack(arr) plt.imshow(arrs, cmap="jet", vmin=0.0, vmax=.5) plt.title(prompt) if filepath: plt.savefig(filepath) else: plt.show() def draw_dd_metadata(img, bbox, text="", target_res=1024): img = img.resize((target_res, target_res)) image_editable = ImageDraw.Draw(img) for region in [bbox]: x0 = region[0] * target_res y0 = region[2] * target_res x1 = region[1] * target_res y1 = region[3] * target_res image_editable.rectangle(xy=[x0, y0, x1, y1], outline=(255, 0, 0, 255), width=5) if text: font = ImageFont.truetype("./assets/JetBrainsMono-Bold.ttf", size=13) image_editable.multiline_text( (15, 15), text, (255, 255, 255, 0), font=font, stroke_width=2, stroke_fill=(0, 0, 0, 255), spacing=0, ) return img if __name__ == "__main__": latents = torch.load("assets/experiments/a-cat-sitting-on-a-car_230615-144611/latents.pt") plot_latent_noise(latents, "GRAY")