File size: 4,051 Bytes
70303d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba0e651
70303d6
 
 
 
 
 
 
 
 
ba0e651
 
70303d6
456234e
 
70303d6
 
 
 
456234e
 
 
 
 
70303d6
 
 
 
ba0e651
70303d6
 
ba0e651
70303d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba0e651
 
70303d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy
from spacy import displacy


# Hex node colors keyed by spaCy NER label, used when rendering the pyvis
# graph. NOTE(review): appears to mirror displacy's default palette — labels
# absent from this map (e.g. spaCy's "FAC") have no color entry.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}

def generate_knowledge_graph(texts: List[str], filename: str):
    """Build and render an interactive knowledge graph from raw texts.

    Named entities are detected with spaCy and used to color the nodes;
    relation triplets are extracted per text by ``generate_partial_graph``
    (REBEL) and become the directed edges of the pyvis network.

    Args:
        texts: input documents, one string per document.
        filename: path of the HTML file written by ``net.show``.

    Returns:
        The list of unique node names (lowercased triplet heads and tails).
    """
    nlp = spacy.load("en_core_web_sm")
    # Entities are detected on the lowercased text so they can be matched
    # against the lowercased triplet heads/tails below.
    doc = nlp("\n".join(texts).lower())
    NERs = [ent.text for ent in doc.ents]
    NER_types = [ent.label_ for ent in doc.ents]
    for nr, nrt in zip(NERs, NER_types):
        print(nr, nrt)

    triplets = []
    for triplet in texts:
        triplets.extend(generate_partial_graph(triplet))
    print(generate_partial_graph.cache_info())
    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]

    nodes = list(set(heads + tails))
    net = Network(directed=True, width="700px", height="700px")

    for n in nodes:
        if n in NERs:
            NER_type = NER_types[NERs.index(n)]
            # BUG FIX: the guard must check the color map, not NER_types.
            # The old test "NER_type in NER_types" was always true (the
            # label was just taken from that list), so unmapped labels such
            # as spaCy's "FAC" raised KeyError on the color lookup.
            if NER_type in DEFAULT_LABEL_COLORS:
                color = DEFAULT_LABEL_COLORS[NER_type]
                net.add_node(n, title=NER_type, shape="circle", color=color)
            else:
                net.add_node(n, shape="circle")
        else:
            net.add_node(n, shape="circle")

    # De-duplicate edges on (tail, head, lowercased relation type).
    unique_triplets = set()
    stringify_trip = lambda x: x["tail"] + x["head"] + x["type"].lower()
    for triplet in triplets:
        if stringify_trip(triplet) not in unique_triplets:
            net.add_edge(
                triplet["head"].lower(),
                triplet["tail"].lower(),
                title=triplet["type"],
                label=triplet["type"],
            )
            unique_triplets.add(stringify_trip(triplet))

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes


@lru_cache
def generate_partial_graph(text: str):
    """Extract relation triplets from *text* with the REBEL seq2seq model.

    Results are memoized via lru_cache (keyed on the exact input string),
    so a repeated document skips the expensive model invocation.
    """
    # Debug trace: cache-miss marker (first 20 chars + hash of the input).
    print(text[0:20], hash(text))
    # NOTE(review): a fresh pipeline is constructed on every cache miss —
    # consider hoisting it to module level if this is a hot path.
    triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')
    # assumes this transformers version nests ids under
    # ["generated_token_ids"]["output_ids"] — TODO confirm against the
    # installed transformers' pipeline return format.
    a = triplet_extractor(text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
    extracted_text = triplet_extractor.tokenizer.batch_decode(a)
    extracted_triplets = extract_triplets(extracted_text[0])
    return extracted_triplets


def extract_triplets(text):
    """Parse REBEL's linearised output into relation triplets.

    The model emits triplets with special markers: ``<triplet>`` starts a
    new head entity, ``<subj>`` starts the tail entity, and ``<obj>`` starts
    the relation label. ``<s>``/``<pad>``/``</s>`` tokens are stripped.

    Args:
        text: decoded model output string.

    Returns:
        A list of dicts, each with keys 'head', 'type' and 'tail'.
    """
    triplets = []
    # BUG FIX: the original tuple assignment bound `relation` twice
    # ("relation, subject, relation, object_ = '', '', '', ''"); only three
    # accumulators are needed.
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'  # which accumulator the next plain token is appended to
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            # Flush the previous (complete) triplet before starting a new head.
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # A second <subj> under the same head means the previous triplet
            # is complete; flush it before collecting the next tail.
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the final triplet if all three parts were collected.
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})

    return triplets