"""Gradio demo: visualize the spaCy dependency parse of transcript sentences.

A transcript from the ``gigant/tib_transcripts`` dataset is chosen with a
slider and one of its sentences from a dropdown. The sentence is parsed with
spaCy, converted to a ``jraph.GraphsTuple``, then to a ``networkx.DiGraph``,
and drawn with matplotlib.
"""
import gradio as gr
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt  # noqa: F401  (kept from the original imports)
import networkx as nx
import spacy
from datasets import load_dataset
from matplotlib.figure import Figure

dataset = load_dataset("gigant/tib_transcripts")
nlp = spacy.load("en_core_web_sm")


def _split_sentences(transcript):
    """Split *transcript* on '.' and drop empty / whitespace-only fragments.

    A naive ``split(".")`` leaves a trailing empty string after the final
    period and leading spaces on each fragment. Empty fragments would yield
    empty parses and therefore empty graphs.
    """
    return [part.strip() for part in transcript.split(".") if part.strip()]


def dependency_parser(sentences):
    """Parse each string in *sentences* with spaCy; return the list of Docs."""
    # nlp.pipe yields the same Docs as calling nlp() per sentence, in order.
    return list(nlp.pipe(sentences))


def construct_dependency_graph(docs):
    """Turn spaCy-parsed docs into plain dependency-graph dicts.

    Args:
        docs: list of outputs of the spaCy dependency parser.

    Returns:
        One dict per doc with keys ``"nodes"`` (token texts, in token order),
        ``"senders"`` and ``"receivers"``. There is one edge per
        head -> child dependency arc. Endpoints are ``token.i`` indices within
        that doc.
    """
    graphs = []
    for doc in docs:
        nodes = [token.text for token in doc]
        senders = []
        receivers = []
        for token in doc:
            for child in token.children:
                senders.append(token.i)
                receivers.append(child.i)
        graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers})
    return graphs


def to_jraph(graph):
    """Convert one dependency-graph dict into a single-graph ``jraph.GraphsTuple``.

    Every node gets the integer feature 0. Edges carry no features, and the
    graph has no globals.
    """
    nodes = graph["nodes"]
    s = graph["senders"]
    r = graph["receivers"]

    # Placeholder integer feature (0) for every token node.
    node_features = jnp.zeros(len(nodes), dtype=jnp.int32)

    # Edges point from a dependency head (sender) to its dependent (receiver).
    # Use an explicit int dtype: an empty edge list would otherwise become a
    # float32 array, which is not a valid index array.
    senders = jnp.asarray(s, dtype=jnp.int32)
    receivers = jnp.asarray(r, dtype=jnp.int32)

    # Per-graph node/edge counts let jraph batch several graphs in one
    # GraphsTuple. Here the tuple holds exactly one graph.
    n_node = jnp.array([len(nodes)])
    n_edge = jnp.array([len(s)])

    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        n_node=n_node,
        n_edge=n_edge,
        globals=None,
    )


def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.DiGraph:
    """Convert the first graph in *jraph_graph* into a ``networkx.DiGraph``.

    Only ``n_node[0]`` / ``n_edge[0]`` are read, so a batched GraphsTuple is
    truncated to its first graph. Node and edge features, when present, are
    stored under the ``node_feature`` / ``edge_feature`` attributes.
    """
    nodes = jraph_graph.nodes
    edges = jraph_graph.edges
    senders = jraph_graph.senders
    receivers = jraph_graph.receivers
    num_nodes = int(jraph_graph.n_node[0])
    num_edges = int(jraph_graph.n_edge[0])

    nx_graph = nx.DiGraph()
    for n in range(num_nodes):
        if nodes is None:
            nx_graph.add_node(n)
        else:
            nx_graph.add_node(n, node_feature=nodes[n])
    for e in range(num_edges):
        src, dst = int(senders[e]), int(receivers[e])
        if edges is None:
            nx_graph.add_edge(src, dst)
        else:
            nx_graph.add_edge(src, dst, edge_feature=edges[e])
    return nx_graph


def plot_graph_sentence(sentence):
    """Return a matplotlib Figure of *sentence*'s dependency graph.

    Returns None (which clears the Gradio plot) when no non-blank sentence
    is selected.
    """
    if not sentence or not sentence.strip():
        return None
    docs = dependency_parser([sentence])
    graphs = construct_dependency_graph(docs)
    g = to_jraph(graphs[0])
    nx_graph = convert_jraph_to_networkx_graph(g)
    # A fixed seed gives the same sentence the same layout every time.
    pos = nx.spring_layout(nx_graph, seed=42)
    # A standalone Figure is not registered with pyplot, so figures do not
    # pile up across requests. A plt.figure() here is never closed.
    fig = Figure(figsize=(6, 6))
    ax = fig.add_axes((0, 0, 1, 1))  # full-figure axes, matching nx.draw's default
    nx.draw(
        nx_graph,
        pos=pos,
        ax=ax,
        labels=dict(enumerate(graphs[0]["nodes"])),
        with_labels=True,
        node_size=500,
        font_color="black",
        node_color="yellow",
    )
    return fig


def get_list_sentences(id):
    """Return a Gradio update that sets the dropdown to transcript *id*'s sentences."""
    # The slider may deliver its value as a float; dataset rows need an int index.
    transcript = dataset["train"][int(id)]["transcript"]
    return gr.update(choices=_split_sentences(transcript))


with gr.Blocks() as demo:
    transcript_id = gr.Slider(
        minimum=0,
        maximum=len(dataset["train"]) - 1,
        step=1,
        value=0,
    )
    sentence = gr.Dropdown(
        choices=_split_sentences(dataset["train"][0]["transcript"]),
        interactive=True,
    )
    plot = gr.Plot()
    transcript_id.change(get_list_sentences, transcript_id, sentence)
    sentence.change(plot_graph_sentence, sentence, plot)

if __name__ == "__main__":
    demo.launch()