#############################
#   Imports
#############################

# Python modules
import math
import os

# Remote modules
import matplotlib.pyplot as plt
import numpy as np
import torch

# Local modules

#############################
#   Constants
#############################


class AttentionVisualizer:
    """Matplotlib helpers for inspecting transformer attention scores.

    The ``visualize_*`` methods draw heatmaps. The ``plot_attn_lines*`` methods
    draw each token twice (a left and a right column) and connect every token
    pair with a line whose opacity is the attention weight between them.
    """

    def __init__(self, device):
        # Kept for callers; none of the methods in this class read it.
        self.device = device

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    @staticmethod
    def _to_float_array(values):
        """Return a fresh float copy of ``values``.

        Because it is a copy, later in-place edits never touch the caller's data.
        """
        return np.array(AttentionVisualizer._to_numpy(values), dtype=float)

    @staticmethod
    def _normalize_rows(attn):
        """Scale each row of the float array ``attn`` so it sums to 1.

        Rows summing to 0 stay all-zero instead of becoming NaN, because
        matplotlib rejects NaN as a line ``alpha``.
        """
        row_sums = attn.sum(axis=-1, keepdims=True)
        return np.divide(attn, row_sums, out=np.zeros_like(attn), where=row_sums != 0)

    @staticmethod
    def _prepare_attention(attn, hide_sep=False):
        """Return a float copy of one head's attention matrix.

        Args:
            attn: 2-D array-like or torch tensor indexed as ``[from, to]``.
            hide_sep: if True, zero the first and last columns and renormalise
                each row to sum to 1. Presumably those columns are separator
                tokens such as [CLS]/[SEP]; confirm against the tokenizer.
        """
        attn = AttentionVisualizer._to_float_array(attn)
        if hide_sep:
            attn[:, 0] = 0
            attn[:, -1] = 0
            attn = AttentionVisualizer._normalize_rows(attn)
        return attn

    @staticmethod
    def _vertical_offset(examples, idx, word_height):
        """Return the vertical shift for example ``idx``.

        Only the first example is shifted, so that its column of words is
        centred against the second example's column. With fewer than two
        examples, nothing is shifted.
        """
        if idx != 0 or len(examples) < 2:
            return 0
        return (len(examples[0]["words"]) - len(examples[1]["words"])) * word_height / 2

    # ------------------------------------------------------------------
    # Heatmaps
    # ------------------------------------------------------------------

    def visualize_token2token_scores(self, all_tokens, scores_mat, useful_indeces,
                                     x_label_name='Head', apply_normalization=True):
        """Show one token-by-token heatmap per entry of ``scores_mat``.

        Args:
            all_tokens: token strings used to label both axes.
            scores_mat: iterable of square score matrices, one subplot each.
            useful_indeces: indices (or boolean mask) of the tokens to keep on
                both axes.
            x_label_name: prefix of each subplot's x label, followed by the
                1-based subplot number.
            apply_normalization: if True, plot the element-wise magnitude
                (absolute value) of the scores instead of the raw values.
        """
        fig = plt.figure(figsize=(20, 20))
        tokens = np.array(all_tokens)[useful_indeces]
        score_list = list(scores_mat)
        # 4 columns; grow beyond 4 rows only when more than 16 plots are needed.
        n_cols = 4
        n_rows = max(4, math.ceil(len(score_list) / n_cols))
        fontdict = {'fontsize': 10}
        for idx, scores in enumerate(score_list):
            scores_np = self._to_numpy(scores)
            if apply_normalization:
                # The L2 norm of a single element equals its absolute value.
                scores_np = np.abs(scores_np)
            scores_np = scores_np[useful_indeces, :][:, useful_indeces]

            ax = fig.add_subplot(n_rows, n_cols, idx + 1)
            im = ax.imshow(scores_np, cmap='viridis')
            ax.set_xticks(range(len(tokens)))
            ax.set_yticks(range(len(tokens)))
            ax.set_xticklabels(tokens, fontdict=fontdict, rotation=90)
            ax.set_yticklabels(tokens, fontdict=fontdict)
            ax.set_xlabel('{} {}'.format(x_label_name, idx + 1))
            fig.colorbar(im, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

    def visualize_matrix(self, scores_mat, label_name='heads_layers'):
        """Save a layer-by-head heatmap to ``figs/<label_name>.png``.

        Rows of ``scores_mat`` are labelled ``layer-1..`` and columns
        ``head-1..``. The output directory is created if it is missing. The
        figure is closed after saving, so repeated calls do not accumulate
        open figures.
        """
        scores_np = np.array(scores_mat)
        n_layers = len(scores_mat)
        n_heads = len(scores_mat[0])

        fig, ax = plt.subplots()
        im = ax.imshow(scores_np, cmap='viridis')
        fontdict = {'fontsize': 10}
        ax.set_xticks(range(n_heads))
        ax.set_yticks(range(n_layers))
        ax.set_xticklabels([f'head-{i}' for i in range(1, n_heads + 1)],
                           fontdict=fontdict, rotation=90)
        ax.set_yticklabels([f'layer-{i}' for i in range(1, n_layers + 1)],
                           fontdict=fontdict)
        ax.set_xlabel('{}'.format(label_name))
        fig.colorbar(im, fraction=0.046, pad=0.04)
        fig.tight_layout()

        out_path = f'figs/{label_name}.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, dpi=fig.dpi)
        plt.close(fig)

    def visualize_token2head_scores(self, all_tokens, scores_mat):
        """Show one heatmap per entry of ``scores_mat``.

        In each heatmap, columns are tokens and rows are numbered from 0. Each
        entry is plotted with ``matshow``, and subplot ``idx`` is labelled
        ``Layer idx+1``.
        """
        fig = plt.figure(figsize=(30, 50))
        score_list = list(scores_mat)
        # 3 columns; grow beyond 6 rows only when more than 18 plots are needed.
        n_cols = 3
        n_rows = max(6, math.ceil(len(score_list) / n_cols))
        fontdict = {'fontsize': 20}
        for idx, scores in enumerate(score_list):
            scores_np = self._to_numpy(scores)
            ax = fig.add_subplot(n_rows, n_cols, idx + 1)
            im = ax.matshow(scores_np, cmap='viridis')
            n_rows_in_plot = len(scores_np)
            ax.set_xticks(range(len(all_tokens)))
            ax.set_yticks(range(n_rows_in_plot))
            ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
            # One label per row tick; using the column count here would
            # mismatch the number of ticks.
            ax.set_yticklabels(range(n_rows_in_plot), fontdict=fontdict)
            ax.set_xlabel('Layer {}'.format(idx + 1))
            fig.colorbar(im, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # Line plots
    # ------------------------------------------------------------------

    def plot_attn_lines(self, data, heads):
        """Plot attention maps for the given example and attention heads.

        Draws onto the current pyplot figure, one panel per ``(layer, head)``
        pair in ``heads``, placed side by side. Each head's matrix
        ``data["attns"][layer][head]`` is row-normalised before plotting.
        ``data["tokens"]`` supplies the word labels. The first token is shown
        as ``"..."`` on a copy, so the caller's list is not modified. Lines
        from or to the first token are not drawn.
        """
        width = 3
        example_sep = 3
        word_height = 1
        pad = 0.1
        for ei, (layer, head) in enumerate(heads):
            yoffset = 1
            xoffset = ei * width * example_sep
            attn = self._normalize_rows(self._to_float_array(data["attns"][layer][head]))

            words = list(data["tokens"])
            words[0] = "..."
            n_words = len(words)

            for position, word in enumerate(words):
                y = yoffset - position * word_height
                plt.text(xoffset + 0, y, word, ha="right", va="center")
                plt.text(xoffset + width, y, word, ha="left", va="center")
            for i in range(1, n_words):
                for j in range(1, n_words):
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * i, yoffset - word_height * j],
                             color="blue", linewidth=1, alpha=float(attn[i, j]))

    def plot_attn_lines_concepts(self, title, examples, layer, head, color_words,
                                 color_from=True, width=3, example_sep=3,
                                 word_height=1, pad=0.1, hide_sep=False):
        """Plot one head's attention lines for each example in a new 4x4 figure.

        ``examples`` is a sequence of dicts shaped like
        ``{'words': tokens, 'attentions': [layer][head] -> matrix}``.

        Args:
            color_words: words highlighted in red.
            color_from: if True, highlight the left (from) column and colour
                lines by their source word; otherwise highlight the right (to)
                column and colour lines by their target word.
            hide_sep: zero the first/last attention columns and renormalise
                (see ``_prepare_attention``).
        """
        plt.figure(figsize=(4, 4))
        for ex_idx, example in enumerate(examples):
            yoffset = self._vertical_offset(examples, ex_idx, word_height)
            xoffset = ex_idx * width * example_sep
            attn = self._prepare_attention(example["attentions"][layer][head], hide_sep)

            words = example["words"]
            n_words = len(words)
            for position, word in enumerate(words):
                for x, from_word in [(xoffset, True), (xoffset + width, False)]:
                    color = "k"
                    if from_word == color_from and word in color_words:
                        color = "#cc0000"
                    plt.text(x, yoffset - (position * word_height), word,
                             ha="right" if from_word else "left", va="center",
                             color=color)
            for i in range(n_words):
                for j in range(n_words):
                    anchor_word = words[i if color_from else j]
                    color = "r" if anchor_word in color_words else "b"
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * i, yoffset - word_height * j],
                             color=color, linewidth=1, alpha=float(attn[i, j]))
        plt.axis("off")
        plt.title(title)
        plt.show()

    def plot_attn_lines_concepts_ids(self, title, examples, layer, head, relations_total,
                                     width=3, example_sep=3, word_height=1, pad=0.1,
                                     hide_sep=False):
        """Plot one head's attention lines coloured by a relation matrix.

        ``examples`` is a sequence of dicts shaped like
        ``{'words': tokens, 'attentions': [layer][head] -> matrix}``.
        ``relations_total[k]`` is a matrix for example ``k``, indexed as
        ``[from, to]``. Entries > 0 mark a relation.

        Colouring:
            - A word in the left column is red if it has any relation in its
              row; a word in the right column is green if it has any relation
              in its column. All other words are black.
            - A line ``i -> j`` with a relation is red when ``i < j`` and green
              otherwise (including ``i == j``). All other lines are black.

        Returns:
            The newly created matplotlib figure. It is not shown.
        """
        # Clears whatever figure is current (e.g. one returned by a previous
        # call) before drawing on a new one.
        plt.clf()
        fig = plt.figure(figsize=(10, 5))
        for ex_idx, example in enumerate(examples):
            yoffset = self._vertical_offset(examples, ex_idx, word_height)
            xoffset = ex_idx * width * example_sep
            attn = self._prepare_attention(example["attentions"][layer][head], hide_sep)

            words = example["words"]
            n_words = len(words)
            rel = self._to_float_array(relations_total[ex_idx])
            related = rel[:n_words, :n_words] > 0
            has_outgoing = related.any(axis=1)
            has_incoming = related.any(axis=0)

            for position, word in enumerate(words):
                y = yoffset - (position * word_height)
                plt.text(xoffset, y, word, ha="right", va="center",
                         color="r" if has_outgoing[position] else "k")
                plt.text(xoffset + width, y, word, ha="left", va="center",
                         color="g" if has_incoming[position] else "k")
            for i in range(n_words):
                for j in range(n_words):
                    if rel[i, j] > 0:
                        color = "r" if i < j else "g"
                    else:
                        color = "k"
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * i, yoffset - word_height * j],
                             color=color, linewidth=1, alpha=float(attn[i, j]))
        plt.axis("off")
        plt.title(title)
        return fig