"""Gradio app that visualizes GUS-Net per-token entity attention scores with Plotly.

Loads precomputed vocabulary attention scores and tokens, keeps tokens that are
dictionary English words or "##"-prefixed subword pieces, and plots the
highest-attention tokens for the selected entity classes (Generalization /
Unfairness / Stereotype) in 1, 2 or 3 dimensions, overlaid with the top
O-label tokens for comparison.
"""
import json

import gradio as gr
import nltk
import numpy as np
import plotly.graph_objects as go
from nltk.corpus import words

# Entity classes offered in the UI, in their default x / y / z axis order.
ENTITY_DIMENSIONS = ["Generalization", "Unfairness", "Stereotype"]
# How many highest-attention tokens to keep per selected dimension.
TOP_K_PER_DIMENSION = 100
# How many highest-scoring O-label tokens to overlay for comparison.
TOP_K_O_TOKENS = 500
PLOT_TITLE = "GUS-Net Entity Attentions Visualization"

# load files w embeddings, attention scores, and tokens
vocab_embeddings = np.load('vocab_embeddings.npy')
with open('vocab_attention_scores.json', 'r', encoding='utf-8') as f:
    vocab_attention_scores = json.load(f)
with open('vocab_tokens.json', 'r', encoding='utf-8') as f:
    vocab_tokens = json.load(f)


def _label_scores(label):
    """Return the score stored under *label* for every vocabulary entry, as an array.

    NOTE(review): assumes vocab_attention_scores is a list of dicts (one per
    token, aligned with vocab_tokens) keyed by BIO label — confirm against the
    file that produced it.
    """
    return np.array([score[label] for score in vocab_attention_scores])


# attention scores to numpy arrs
b_gen_attention = _label_scores('B-GEN')
i_gen_attention = _label_scores('I-GEN')
b_unfair_attention = _label_scores('B-UNFAIR')
i_unfair_attention = _label_scores('I-UNFAIR')
b_stereo_attention = _label_scores('B-STEREO')
i_stereo_attention = _label_scores('I-STEREO')
o_attention = _label_scores('O')  # Use actual O scores

# remove non-dict english words, but keep subwords ("##..." pieces)
try:
    nltk.data.find('corpora/words')
except LookupError:
    # The NLTK "words" corpus is not installed locally; fetch it once.
    nltk.download('words', quiet=True)
english_words = set(words.words())

filtered_indices = [
    i for i, token in enumerate(vocab_tokens)
    if token in english_words or token.startswith("##")
]
filtered_tokens = [vocab_tokens[i] for i in filtered_indices]

b_gen_attention_filtered = b_gen_attention[filtered_indices]
i_gen_attention_filtered = i_gen_attention[filtered_indices]
b_unfair_attention_filtered = b_unfair_attention[filtered_indices]
i_unfair_attention_filtered = i_unfair_attention[filtered_indices]
b_stereo_attention_filtered = b_stereo_attention[filtered_indices]
i_stereo_attention_filtered = i_stereo_attention[filtered_indices]
o_attention_filtered = o_attention[filtered_indices]

# Combined (B- + I-) attention per entity class, aligned with filtered_tokens.
# Computed once and shared by the plot and the top-token listing.
entity_attention = {
    'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
    'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
    'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
}

# plot top O tokens for comparison (indices into the filtered arrays)
top_500_o_indices = np.argsort(o_attention_filtered)[-TOP_K_O_TOKENS:]
top_500_o_tokens = [filtered_tokens[i] for i in top_500_o_indices]
o_attention_filtered_top_500 = o_attention_filtered[top_500_o_indices]


def create_hover_text(tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val):
    """Build one Plotly hover label per token, with lines separated by ``<br>``.

    Every score argument is indexed in lockstep with *tokens*: entry i of each
    array must describe ``tokens[i]``.

    Returns:
        A list of strings, one per token.
    """
    hover_text = []
    for i, token in enumerate(tokens):
        hover_text.append(
            f"Token: {token}<br>"
            f"B-GEN: {b_gen[i]:.3f}, I-GEN: {i_gen[i]:.3f}<br>"
            f"B-UNFAIR: {b_unfair[i]:.3f}, I-UNFAIR: {i_unfair[i]:.3f}<br>"
            f"B-STEREO: {b_stereo[i]:.3f}, I-STEREO: {i_stereo[i]:.3f}<br>"
            f"O: {o_val[i]:.3f}"
        )
    return hover_text


def _top_k_union_indices(data_arrays, k=TOP_K_PER_DIMENSION):
    """Return the sorted union of the top-*k* indices of every non-None array.

    Returns an empty integer array when no array is given.
    """
    per_array = [np.argsort(data)[-k:] for data in data_arrays if data is not None]
    if not per_array:
        return np.array([], dtype=int)
    return np.unique(np.concatenate(per_array))


def select_top_100(*data_arrays):
    """Keep the union of each non-None array's top-100 entries.

    Args:
        *data_arrays: Score arrays aligned with filtered_tokens, or None.

    Returns:
        A tuple of each input sliced to the combined indices (None stays None),
        followed by the list of tokens at those indices.
    """
    combined_indices = _top_k_union_indices(data_arrays, k=100)
    filtered_data = [data[combined_indices] if data is not None else None for data in data_arrays]
    tokens_filtered = [filtered_tokens[i] for i in combined_indices]
    return (*filtered_data, tokens_filtered)


def _axis_kwargs(columns):
    """Map 1-3 coordinate arrays to Plotly ``x``/``y``/``z`` keyword arguments.

    A single column is laid out along the x-axis with every y set to 0.
    """
    if len(columns) == 1:
        columns = [columns[0], np.zeros_like(columns[0])]
    return dict(zip("xyz", columns))


def create_plot(selected_dimensions):
    """Build a 1-, 2- or 3-D scatter plot of token attention for the chosen entities.

    The first, second and third selected dimensions map to the x, y and z axes;
    further dimensions are ignored. The classified tokens are the union of each
    selected dimension's top TOP_K_PER_DIMENSION tokens, coloured (Viridis) by
    their x-axis value. The top TOP_K_O_TOKENS O-label tokens are overlaid in
    grey at their positions on the same axes. With no recognised dimension
    selected, an empty titled figure is returned instead of raising.

    Args:
        selected_dimensions: Entity names from ENTITY_DIMENSIONS, in axis order.

    Returns:
        A plotly ``go.Figure``.
    """
    dims = [d for d in (selected_dimensions or []) if d in entity_attention][:3]
    fig = go.Figure()
    layout = dict(title=PLOT_TITLE, margin=dict(l=0, r=0, b=0, t=40))
    if not dims:
        # Nothing to plot (e.g. every checkbox cleared); avoid indexing an empty selection.
        fig.update_layout(**layout)
        return fig

    axes = [entity_attention[d] for d in dims]

    # Classified tokens: union of the top-k indices of every selected dimension.
    top_idx = _top_k_union_indices(axes)
    classified_axes = [a[top_idx] for a in axes]
    classified_tokens = [filtered_tokens[i] for i in top_idx]
    # Slice the score arrays with the same indices as the tokens so each hover
    # label shows that token's own scores.
    classified_hover_text = create_hover_text(
        classified_tokens,
        b_gen_attention_filtered[top_idx],
        i_gen_attention_filtered[top_idx],
        b_unfair_attention_filtered[top_idx],
        i_unfair_attention_filtered[top_idx],
        b_stereo_attention_filtered[top_idx],
        i_stereo_attention_filtered[top_idx],
        o_attention_filtered[top_idx],
    )

    # O tokens are positioned with the same entity dimensions as the classified ones.
    o_axes = [a[top_500_o_indices] for a in axes]
    o_hover_text = create_hover_text(
        top_500_o_tokens,
        b_gen_attention_filtered[top_500_o_indices],
        i_gen_attention_filtered[top_500_o_indices],
        b_unfair_attention_filtered[top_500_o_indices],
        i_unfair_attention_filtered[top_500_o_indices],
        b_stereo_attention_filtered[top_500_o_indices],
        i_stereo_attention_filtered[top_500_o_indices],
        o_attention_filtered_top_500,
    )

    scatter = go.Scatter3d if len(dims) == 3 else go.Scatter
    fig.add_trace(scatter(
        **_axis_kwargs(classified_axes),
        mode='markers',
        marker=dict(
            size=6,
            color=classified_axes[0],  # color based on the x-axis data
            colorscale='Viridis',
            opacity=0.85,
        ),
        text=classified_hover_text,
        hoverinfo='text',
        name='Classified Tokens',
    ))
    fig.add_trace(scatter(
        **_axis_kwargs(o_axes),
        mode='markers',
        marker=dict(size=6, color='grey', opacity=0.5),
        text=o_hover_text,
        hoverinfo='text',
        name='O Tokens',
    ))

    axis_titles = [f"{d} Attention" for d in dims]
    if len(dims) == 3:
        layout['scene'] = dict(
            xaxis=dict(title=axis_titles[0]),
            yaxis=dict(title=axis_titles[1]),
            zaxis=dict(title=axis_titles[2]),
        )
    else:
        layout['xaxis_title'] = axis_titles[0]
        if len(dims) == 2:
            layout['yaxis_title'] = axis_titles[1]
    fig.update_layout(**layout)
    return fig


def get_top_tokens_for_entities(selected_dimensions, top_n=10):
    """Return the *top_n* highest-attention tokens for each selected entity class.

    Unknown dimension names are skipped.

    Returns:
        A dict mapping each dimension to a list of ``(token, score)`` pairs,
        ordered from highest to lowest combined (B- + I-) attention.
    """
    top_tokens_info = {}
    for dimension in selected_dimensions or []:
        attention_scores = entity_attention.get(dimension)
        if attention_scores is None:
            continue
        top_indices = np.argsort(attention_scores)[::-1][:top_n]
        top_tokens_info[dimension] = [(filtered_tokens[i], attention_scores[i]) for i in top_indices]
    return top_tokens_info


def update_gradio(selected_dimensions):
    """Gradio callback: return the plot and a text listing of top tokens per entity."""
    fig = create_plot(selected_dimensions)
    top_tokens_info = get_top_tokens_for_entities(selected_dimensions)
    parts = []
    for entity, tokens_info in top_tokens_info.items():
        parts.append(f"\nTop tokens for {entity}:\n")
        parts.extend(f"Token: {token}, Attention Score: {score:.3f}\n" for token, score in tokens_info)
    return fig, "".join(parts)


def render_gradio_interface():
    """Build the Gradio Blocks UI: dimension checkboxes, the plot, and a top-token box."""
    with gr.Blocks() as demo:
        with gr.Column():
            dimensions_input = gr.CheckboxGroup(
                choices=list(ENTITY_DIMENSIONS),
                label="Select Dimensions to Plot",
                value=list(ENTITY_DIMENSIONS),  # defaults to 3D
            )
            plot_output = gr.Plot(label="Token Attention Visualization")
            top_tokens_output = gr.Textbox(label="Top Tokens for Each Entity Class", lines=10)

        dimensions_input.change(
            fn=update_gradio,
            inputs=[dimensions_input],
            outputs=[plot_output, top_tokens_output],
        )
        # Render the default 3-D view as soon as the page loads.
        demo.load(
            fn=lambda: update_gradio(list(ENTITY_DIMENSIONS)),
            inputs=None,
            outputs=[plot_output, top_tokens_output],
        )
    return demo


interface = render_gradio_interface()

if __name__ == "__main__":
    interface.launch()