|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionVisualizer:
    """Matplotlib helpers for plotting attention scores.

    The heat-map methods (``visualize_*``) draw one ``imshow``/``matshow``
    panel per score matrix; the ``plot_attn_lines*`` methods draw tokens in
    two columns joined by lines whose opacity is the attention weight.
    All drawing goes through the module-level ``matplotlib.pyplot`` state.
    """

    def __init__(self, device):
        """Store ``device`` on the instance.

        None of the plotting methods in this class read ``self.device``.
        """
        self.device = device

    # ------------------------------------------------------------------
    # Pure helpers (no matplotlib involved)
    # ------------------------------------------------------------------

    @staticmethod
    def _grid_rows(n_plots, n_cols, min_rows):
        """Return how many subplot rows are needed to fit ``n_plots`` panels."""
        needed = -(-n_plots // n_cols)  # ceiling division
        return max(min_rows, needed)

    @staticmethod
    def _row_normalize(attn):
        """Return a copy of ``attn`` with every row scaled to sum to 1.

        Rows whose sum is zero are left as zeros instead of becoming NaN,
        because a NaN alpha is rejected by matplotlib.
        """
        attn = np.array(attn)
        totals = attn.sum(axis=-1, keepdims=True)
        return attn / np.where(totals == 0, 1, totals)

    @staticmethod
    def _drop_separator_columns(attn):
        """Zero the first and last columns of a copy of ``attn``, then
        re-normalize each row (see ``_row_normalize``)."""
        attn = np.array(attn)  # copy: the caller's data must stay untouched
        attn[:, 0] = 0
        attn[:, -1] = 0
        return AttentionVisualizer._row_normalize(attn)

    @staticmethod
    def _first_example_yoffset(examples, idx, word_height):
        """Return the vertical shift that centres example 0 against example 1.

        Only the first example is shifted, and only when a second example
        exists to centre against; every other case yields 0.
        """
        if idx == 0 and len(examples) > 1:
            diff = len(examples[0]["words"]) - len(examples[1]["words"])
            return diff * word_height / 2
        return 0

    # ------------------------------------------------------------------
    # Heat maps
    # ------------------------------------------------------------------

    def visualize_token2token_scores(self, all_tokens,
                                     scores_mat,
                                     useful_indeces,
                                     x_label_name='Head',
                                     apply_normalization=True):
        """Show one token-by-token heat map per entry of ``scores_mat``.

        Args:
            all_tokens: sequence of token labels; it is indexed with
                ``useful_indeces`` to obtain the tick labels.
            scores_mat: iterable of 2-D score matrices, one per panel.
            useful_indeces: index (list of positions or mask) selecting the
                rows, columns and tokens that are displayed.
            x_label_name: prefix of each panel's x label; the 1-based panel
                number is appended.
            apply_normalization: when true, each matrix is replaced by the
                2-norm over an added trailing axis of length 1, i.e. the
                element-wise magnitude of the scores.

        The figure is displayed with ``plt.show()``; nothing is returned.
        """
        panels = list(scores_mat)
        n_cols = 4
        # The original layout was a fixed 4x4 grid, which made add_subplot
        # fail for more than 16 panels. Keep 4x4 as the minimum and add rows
        # (and proportional height) only when more panels are supplied.
        n_rows = self._grid_rows(len(panels), n_cols, 4)
        fig = plt.figure(figsize=(20, 5 * n_rows))

        shown_tokens = np.array(all_tokens)[useful_indeces]
        tick_positions = range(len(shown_tokens))
        fontdict = {'fontsize': 10}

        for idx, scores in enumerate(panels):
            if apply_normalization:
                # as_tensor accepts the ndarray input the original required
                # and additionally tolerates an existing tensor.
                scores = torch.as_tensor(scores)
                shape = scores.shape
                scores = scores.reshape((shape[0], shape[1], 1))
                scores = torch.linalg.norm(scores, dim=2)
            scores_np = np.array(scores)
            # Two-step selection: rows first, then columns.
            scores_np = scores_np[useful_indeces, :]
            scores_np = scores_np[:, useful_indeces]

            ax = fig.add_subplot(n_rows, n_cols, idx + 1)
            im = ax.imshow(scores_np, cmap='viridis')

            ax.set_xticks(tick_positions)
            ax.set_yticks(tick_positions)
            ax.set_xticklabels(shown_tokens, fontdict=fontdict, rotation=90)
            ax.set_yticklabels(shown_tokens, fontdict=fontdict)
            ax.set_xlabel('{} {}'.format(x_label_name, idx + 1))

            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        plt.tight_layout()
        plt.show()

    def visualize_matrix(self,
                         scores_mat,
                         label_name='heads_layers'):
        """Save a heat map of ``scores_mat`` to ``figs/<label_name>.png``.

        Args:
            scores_mat: 2-D sequence of scores. Columns are labelled
                ``head-1..`` and rows ``layer-1..``.
            label_name: x-axis label and stem of the output file name.

        The figure uses matplotlib's default size (the size the saved image
        always had). It is left open after saving and nothing is returned.
        """
        import os

        scores_np = np.array(scores_mat)
        # The original first created an unused 20x20 figure that was never
        # closed; only the figure from plt.subplots() was ever drawn on.
        fig, ax = plt.subplots()
        im = ax.imshow(scores_np, cmap='viridis')

        fontdict = {'fontsize': 10}
        n_layers = len(scores_mat)
        n_heads = len(scores_mat[0])

        ax.set_xticks(range(n_heads))
        ax.set_yticks(range(n_layers))

        x_labels = [f'head-{i}' for i in range(1, n_heads + 1)]
        y_labels = [f'layer-{i}' for i in range(1, n_layers + 1)]

        ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(y_labels, fontdict=fontdict)
        ax.set_xlabel('{}'.format(label_name))

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.tight_layout()

        # savefig does not create missing directories.
        os.makedirs('figs', exist_ok=True)
        plt.savefig(f'figs/{label_name}.png', dpi=fig.dpi)

    def visualize_token2head_scores(self, all_tokens, scores_mat):
        """Show one heat map per entry of ``scores_mat``, labelled by layer.

        Args:
            all_tokens: token labels for the x axis (one per column).
            scores_mat: iterable of 2-D score matrices; panel ``k`` gets the
                x label ``'Layer k'``. Rows are numbered from 0 on the
                y axis (presumably one row per head - confirm with caller).

        The figure is displayed with ``plt.show()``; nothing is returned.
        """
        panels = list(scores_mat)
        n_cols = 3
        # Fixed 6x3 grid originally; grow rows/height only beyond 18 panels.
        n_rows = self._grid_rows(len(panels), n_cols, 6)
        fig = plt.figure(figsize=(30, 50 * n_rows / 6))

        fontdict = {'fontsize': 20}

        for idx, scores in enumerate(panels):
            scores_np = np.array(scores)
            ax = fig.add_subplot(n_rows, n_cols, idx + 1)
            im = ax.matshow(scores_np, cmap='viridis')

            row_ticks = range(len(scores))
            ax.set_xticks(range(len(all_tokens)))
            ax.set_yticks(row_ticks)

            ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
            # One label per y tick. The original labelled the rows with
            # range(len(scores[0])) (the column count), which matplotlib
            # rejects whenever the matrix is not square.
            ax.set_yticklabels(row_ticks, fontdict=fontdict)
            ax.set_xlabel('Layer {}'.format(idx + 1))

            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # Line plots
    # ------------------------------------------------------------------

    def plot_attn_lines(self, data, heads):
        """Plots attention maps for the given example and attention heads.

        Args:
            data: mapping with ``"attns"`` (indexed as ``[layer][head]`` to
                get a 2-D attention matrix) and ``"tokens"`` (token list).
            heads: iterable of ``(layer, head)`` pairs; each pair is drawn
                as its own two-column diagram, laid out left to right.

        Draws on the current pyplot figure. The first token is displayed as
        ``"..."`` and no lines are drawn from or to it.
        """
        width = 3
        example_sep = 3
        word_height = 1
        pad = 0.1

        for ei, (layer, head) in enumerate(heads):
            yoffset = 1
            xoffset = ei * width * example_sep

            attn = self._row_normalize(data["attns"][layer][head])
            # Work on a copy: the original overwrote data["tokens"][0] in
            # the caller's own list.
            words = list(data["tokens"])
            words[0] = "..."
            n_words = len(words)

            for position, word in enumerate(words):
                y = yoffset - position * word_height
                plt.text(xoffset + 0, y, word, ha="right", va="center")
                plt.text(xoffset + width, y, word, ha="left", va="center")

            x_ends = [xoffset + pad, xoffset + width - pad]
            for src in range(1, n_words):
                for dst in range(1, n_words):
                    plt.plot(x_ends,
                             [yoffset - word_height * src,
                              yoffset - word_height * dst],
                             color="blue", linewidth=1, alpha=attn[src, dst])

    def plot_attn_lines_concepts(self, title, examples, layer, head, color_words,
                                 color_from=True, width=3, example_sep=3,
                                 word_height=1, pad=0.1, hide_sep=False):
        """Plot one head's attention for each example, highlighting words.

        Args:
            title: figure title.
            examples: sequence of mappings with ``"words"`` and
                ``"attentions"`` (indexed as ``[layer][head]``).
            layer, head: indices of the attention matrix to draw.
            color_words: words to highlight. Matching tokens are drawn in
                red on the side chosen by ``color_from``, and lines touching
                them on that side are red instead of blue.
            color_from: highlight the left ("from") column when true, the
                right ("to") column otherwise.
            width, example_sep, word_height, pad: layout constants.
            hide_sep: zero the first and last attention columns and
                re-normalize the rows before drawing.

        Creates a new 4x4 figure and displays it with ``plt.show()``.
        """
        plt.figure(figsize=(4, 4))
        for ex_idx, example in enumerate(examples):
            # Guarded: the original indexed examples[1] unconditionally and
            # failed with a single example.
            yoffset = self._first_example_yoffset(examples, ex_idx, word_height)
            xoffset = ex_idx * width * example_sep

            attn = example["attentions"][layer][head]
            if hide_sep:
                attn = self._drop_separator_columns(attn)
            elif isinstance(attn, (list, tuple)):
                # Nested Python sequences do not support attn[i, j].
                attn = np.asarray(attn)

            words = example["words"]
            n_words = len(words)

            for position, word in enumerate(words):
                for x, from_word in [(xoffset, True), (xoffset + width, False)]:
                    highlighted = from_word == color_from and word in color_words
                    plt.text(x, yoffset - (position * word_height), word,
                             ha="right" if from_word else "left", va="center",
                             color="#cc0000" if highlighted else "k")

            x_ends = [xoffset + pad, xoffset + width - pad]
            for src in range(n_words):
                for dst in range(n_words):
                    anchor = words[src if color_from else dst]
                    plt.plot(x_ends,
                             [yoffset - word_height * src,
                              yoffset - word_height * dst],
                             color="r" if anchor in color_words else "b",
                             linewidth=1, alpha=attn[src, dst])

        plt.axis("off")
        plt.title(title)
        plt.show()

    def plot_attn_lines_concepts_ids(self, title, examples, layer, head,
                                     relations_total, width=3, example_sep=3,
                                     word_height=1, pad=0.1, hide_sep=False):
        """Plot one head's attention per example, coloured by relations.

        Args:
            title: figure title.
            examples: sequence of mappings with ``"words"`` and
                ``"attentions"`` (indexed as ``[layer][head]``).
            layer, head: indices of the attention matrix to draw.
            relations_total: per-example 2-D relation matrices, indexed as
                ``relations_total[idx][i, j]``; entries must support
                ``.item()``. A value > 0 marks a relation between word ``i``
                (left column) and word ``j`` (right column).
            width, example_sep, word_height, pad: layout constants.
            hide_sep: zero the first and last attention columns and
                re-normalize the rows before drawing.

        Colours: a left-column word with any positive entry in its row is
        red; a right-column word with any positive entry in its column is
        green. A line for a positive entry is red when ``i < j`` and green
        when ``i >= j``; all other lines are black.

        Returns:
            The newly created 10x5 figure (not shown).
        """
        # No plt.clf() here: it cleared whichever figure was current, which
        # could be one returned by an earlier call, and it created a stray
        # empty figure when none existed.
        fig = plt.figure(figsize=(10, 5))

        for idx, example in enumerate(examples):
            # The original subtracted len(examples[0]["words"]) from itself,
            # so the offset was always 0; centre against the second example
            # as the sibling plot_attn_lines_concepts does.
            yoffset = self._first_example_yoffset(examples, idx, word_height)
            xoffset = idx * width * example_sep

            attn = example["attentions"][layer][head]
            if hide_sep:
                attn = self._drop_separator_columns(attn)
            elif isinstance(attn, (list, tuple)):
                # Nested Python sequences do not support attn[i, j].
                attn = np.asarray(attn)

            words = example["words"]
            n_words = len(words)
            example_rel = relations_total[idx]

            for position, word in enumerate(words):
                for x, from_word in [(xoffset, True), (xoffset + width, False)]:
                    if from_word:
                        related = any(example_rel[position, k] > 0
                                      for k in range(n_words))
                        color = "r" if related else "k"
                    else:
                        related = any(example_rel[k, position] > 0
                                      for k in range(n_words))
                        color = "g" if related else "k"

                    plt.text(x, yoffset - (position * word_height), word,
                             ha="right" if from_word else "left", va="center",
                             color=color)

            x_ends = [xoffset + pad, xoffset + width - pad]
            for src in range(n_words):
                for dst in range(n_words):
                    if example_rel[src, dst].item() > 0:
                        # On the diagonal both original tests matched and
                        # the later one ("g") won.
                        color = "g" if src >= dst else "r"
                    else:
                        color = "k"
                    plt.plot(x_ends,
                             [yoffset - word_height * src,
                              yoffset - word_height * dst],
                             color=color, linewidth=1, alpha=attn[src, dst])

        plt.axis("off")
        plt.title(title)

        return fig
|
|