#############################
#   Imports
#############################

# Python modules
from typing import List

# Remote modules
import numpy as np
import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BartConfig,
    DisjunctiveConstraint,
)

# Local modules
from kgs_binding.relation_mapper_builder import RelationsMapperBuilder
from kgs_binding.kg_qa_binding_utils import load_kg_handler
from kgs_binding.kg_base_wrapper import KGBaseHandler
from kgs_binding.swow_handler import SwowHandler
from kgs_binding.conceptnet_handler import ConceptNetHandler
from data.relation_utils import clean_relations
from model_utils import create_layers_head_mask
from custom_tokenizer import BartCustomTokenizerFast
from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration
from utils import get_device, get_jump_chunks, KGType, Model_Type

class Inference:
    def __init__(self, model_path:str, max_length=32):
        self.device = get_device()
        self.tokenizer = self.prepare_tokenizer()
        self.model = self.prepare_model(model_path)
        self.max_length = max_length

    def prepare_tokenizer(self):
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        return tokenizer

    def prepare_model(self, model_path):
        config = BartConfig.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(self.device)
        model.eval()
        return model

    def pre_process_context(self, context):
        context = context.lower()
        context_tokenized = self.tokenizer(context, padding='max_length',
                                truncation='longest_first',  max_length=self.max_length,
                                return_tensors="pt",
                                )
        return context_tokenized

    def generate_based_on_context(self, context):
        model_input = self.pre_process_context(context)
        generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
                                                   attention_mask=model_input["attention_mask"].to(self.device),
                                                   min_length=1,
                                                   max_length=self.max_length,
                                                   do_sample=True,
                                                   early_stopping=True,
                                                   num_beams=4,
                                                   temperature=1.0,
                                                   top_k=None,
                                                   top_p=None,
                                                   # eos_token_id=tokenizer.eos_token_id,
                                                   no_repeat_ngram_size=2,
                                                   num_return_sequences=1,
                                                   return_dict_in_generate=True,
                                                   output_attentions=True,
                                                   output_scores=True)
        # print(f'Scores: {generated_answers_encoded}')
        response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
                                          clean_up_tokenization_spaces=True)
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def prepare_context_for_visualization(self, context):
        examples = []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        # Stack per-layer attentions: (n_layers, batch_size, n_heads, src, tgt)
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        # Reorder to batch-first; permute (not view) is needed to swap axes safely
        encoder_attentions = encoder_outputs.permute(1, 0, 2, 3, 4)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            # Drop padding positions and the BPE space marker for display
            useful_indices = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indices]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
        return response, examples
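
# A minimal usage sketch for `Inference`; the checkpoint path is a placeholder
# (any BART-style seq2seq checkpoint saved with `save_pretrained` should load):
#
#     inference = Inference(model_path='path/to/bart_checkpoint', max_length=32)
#     response, encoder_attentions, model_input = inference.generate_based_on_context('a dog chases a cat')
#     print(response[0])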

class RelationsInference:
    def __init__(self, model_path:str, kg_type: KGType, model_type:Model_Type, max_length=32):
        self.device = get_device()
        kg_handler: KGBaseHandler = load_kg_handler(kg_type)
        self.kg_handler = kg_handler
        relation_names = kg_handler.get_relation_types()
        self.tokenizer = self.prepare_tokenizer(relation_names, model_type)
        self.simple_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.model, self.config = self.prepare_model(relation_names, model_path, model_type)
        self.relation_mapper_builder = RelationsMapperBuilder(knowledge=kg_handler)
        self.max_length = max_length

    def prepare_tokenizer(self, relation_names: List[str], model_type:Model_Type):
        tokenizer = BartCustomTokenizerFast.from_pretrained('facebook/bart-large')
        tokenizer.set_known_relation_names(relation_names)
        tokenizer.set_operation_mode(there_is_difference_between_relations=model_type.there_is_difference_between_relations())
        return tokenizer

    def prepare_model(self, relation_names: List[str], model_path, model_type:Model_Type):
        config = BartCustomConfig.from_pretrained(model_path, revision='master')
        print('config.heads_mask:', config.heads_mask)
        if config.num_relation_kinds is None:
            config.num_relation_kinds = len(relation_names)
        if config.is_simple_mask_commonsense is None:
            config.is_simple_mask_commonsense = model_type.is_simple_mask_commonsense()
        if config.heads_mask is None:
            config.heads_mask = create_layers_head_mask(config)#, heads_mask_type, specific_heads)
        model = BartCustomForConditionalGeneration.from_pretrained(model_path, config=config, revision='master').to(self.device)
        model.eval()
        return model, config

    def pre_process_context(self, context):
        context = context.lower()
        # Search the context for commonsense relations
        commonsense_relations = self.relation_mapper_builder.get_relations_mapping_complex(context=[context], clear_common_wds=True)
        # Clean the relations (single context, so keep the first mapping)
        commonsense_relation = clean_relations(commonsense_relations)[0]
        # The custom tokenizer turns the relations into token-to-token matrices
        context_tokenized = self.tokenizer(context, padding='max_length',
                                truncation='longest_first', max_length=self.max_length,
                                return_tensors="pt", return_offsets_mapping=True,
                                input_commonsense_relations=commonsense_relation,
        )
        return context_tokenized

    def get_relations_information(self, phrase_generated):
        all_concepts = self.relation_mapper_builder.get_kg_concepts_from_context([phrase_generated], clear_common_wds=True)[0]
        words = phrase_generated.strip().split(' ') # all words
        concepts_with_relations = self.relation_mapper_builder.get_concepts_from_context(phrase_generated, clear_common_wds=True)
        concepts_with_no_relations = list(set(all_concepts).difference(concepts_with_relations))
        print("====== RELATIONS SUMMARY ======")
        print('phrase_generated:', phrase_generated)
        print('words:', words)
        print('all_concepts:', all_concepts)
        print('concepts_with_relations:', concepts_with_relations)
        print('without_relations:', concepts_with_no_relations)
        print("\n== STATS:")
        print('n_words:', len(words))
        print('n_concepts:', len(all_concepts))
        print('n_concepts_with_relations:', len(concepts_with_relations))
        print('n_c_without_relations:', len(concepts_with_no_relations))
        print("====== ================= ======")
        return words, all_concepts, concepts_with_relations, concepts_with_no_relations

    def remove_subsets(self, l):
        """Drop any list that is a strict subset of another list in `l`.

        e.g. [[1, 2], [1], [3]] -> [[1, 2], [3]]
        """
        l2 = l[:]
        for m in l:
            for n in l:
                if set(m).issubset(set(n)) and m != n:
                    l2.remove(m)
                    break
        return l2

    def generate_based_on_context(self, context, use_kg=False):
        model_input = self.pre_process_context(context)
        #print(model_input)
        gen_kwargs = {}
        if "input_commonsense_relations" in model_input:
            #print(model_input['input_commonsense_relations'].sum())
            gen_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(self.device)

        constraints = None
        if use_kg:
            constraints = []
            concepts_from_context = self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True)
            # Neighbours of each context concept in the knowledge graph
            # (relation_mapper_builder wraps the same handler as self.kg_handler)
            useful_concepts = [self.kg_handler.get_related_concepts(concept) for concept in concepts_from_context]
            if concepts_from_context:
                for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
                    print('neighbour:', neighbour_concepts[:20])
                    # Drop neighbours already contained in the input concept
                    flexible_words = [word for word in neighbour_concepts if word not in context_concept]
                    flexible_words_ids: List[List[int]] = self.simple_tokenizer(flexible_words, add_prefix_space=True, add_special_tokens=False).input_ids
                    # DisjunctiveConstraint rejects nested alternatives, so keep only maximal sequences
                    flexible_words_ids = self.remove_subsets(flexible_words_ids)
                    flexible_words_ids = flexible_words_ids[:10]
                    print('flexible_words_ids:', flexible_words_ids[:3])
                    if not flexible_words_ids:
                        continue
                    constraints.append(DisjunctiveConstraint(flexible_words_ids))
            if not constraints:
                constraints = None

        generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
                                                   attention_mask=model_input["attention_mask"].to(self.device),
                                                   constraints=constraints,
                                                   min_length=1,
                                                   max_length=self.max_length,
                                                   do_sample=False,
                                                   early_stopping=True,
                                                   num_beams=8,
                                                   temperature=1.0,
                                                   top_k=None,
                                                   top_p=None,
                                                   # eos_token_id=tokenizer.eos_token_id,
                                                   no_repeat_ngram_size=2,
                                                   num_return_sequences=1,
                                                   return_dict_in_generate=True,
                                                   output_attentions=True,
                                                   output_scores=True,
                                                   **gen_kwargs,
                                                   )
        # print(f'Scores: {generated_answers_encoded}')
        response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
                                          clean_up_tokenization_spaces=True)
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def get_related_concepts_list(self, knowledge, list_concepts):
        """Collect the KG neighbours of every concept in `list_concepts`."""
        other_concepts = []
        for concept in list_concepts:
            other_near_concepts = knowledge.get_related_concepts(concept)
            other_concepts.extend(other_near_concepts)
        return other_concepts


    def generate_constrained_based_on_context(self, contexts, use_kg=True, max_concepts=1):
        model_inputs = [self.pre_process_context(context) for context in contexts]
        all_constraints = None
        if use_kg:
            concepts_from_contexts = [self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True) for context in contexts]
            neighbours_contexts = [self.get_related_concepts_list(self.kg_handler, context) for context in concepts_from_contexts]
            all_constraints = []
            for context_neighbours in neighbours_contexts:
                # context_neighbours is a collection of concepts;
                # split it into sub-collections, one disjunctive constraint each
                context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3]
                n_size_chunks = len(context_neighbours) // max_concepts
                n_size_chunks = n_size_chunks if n_size_chunks > 0 else 1
                sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chunks))
                constraints = []
                for sub_concepts in sub_concepts_collection[:max_concepts]:
                    flexible_words_ids: List[List[int]] = self.tokenizer(sub_concepts,
                                                                    add_special_tokens=False).input_ids
                    # Keep only the first token of each concept to avoid nested subsets
                    flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids]
                    # Deduplicate the disjunctive alternatives
                    disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
                    if not any(disjunctive_set):
                        continue
                    constraints.append(DisjunctiveConstraint(disjunctive_set))
                all_constraints.append(constraints if constraints else None)
        if not all_constraints:
            # No KG constraints: fall back to unconstrained generation, one entry per context
            all_constraints = [None] * len(model_inputs)

        generated_answers_encoded = []
        encoder_attentions_list = []
        for i, constraints in enumerate(all_constraints):
            gen_kwargs = {}
            inputs = model_inputs[i]
            if "input_commonsense_relations" in inputs:
                gen_kwargs["relation_inputs"] = inputs.get("input_commonsense_relations").to(self.device)
            # Generate with this context's own constraint set (None means unconstrained)
            gen = self.model.generate(input_ids=inputs["input_ids"].to(self.device),
                               attention_mask=inputs["attention_mask"].to(self.device),
                               constraints=constraints,
                               min_length=1,
                               max_length=self.max_length,
                               do_sample=False,
                               early_stopping=True,
                               num_beams=8,
                               temperature=1.0,
                               top_k=None,
                               top_p=None,
                               # eos_token_id=tokenizer.eos_token_id,
                               no_repeat_ngram_size=2,
                               num_return_sequences=1,
                               return_dict_in_generate=True,
                               output_attentions=True,
                               output_scores=True,
                               **gen_kwargs,
            )
            # print('[gen]:', gen)
            # print(tokenizer.batch_decode(gen))
            generated_answers_encoded.append(gen['sequences'][0].detach().cpu())
            encoder_attentions_list.append(gen['encoder_attentions'][0].detach().cpu())
        # print(f'Scores: {generated_answers_encoded}')
        text_results = self.tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                          clean_up_tokenization_spaces=True)
        return text_results, encoder_attentions_list, model_inputs

    def prepare_context_for_visualization(self, context):
        examples, relations = [], []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        input_commonsense_relations = model_input.get("input_commonsense_relations")
        # Stack per-layer attentions: (n_layers, batch_size, n_heads, src, tgt)
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        # Reorder to batch-first; permute (not view) is needed to swap axes safely
        encoder_attentions = encoder_outputs.permute(1, 0, 2, 3, 4)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            # Drop padding positions and the BPE space marker for display
            useful_indices = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indices]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
            relations.append(input_commonsense_relations[i])
        return response, examples, relations
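

# A minimal, hedged usage sketch. The checkpoint path and the enum members used
# below (KGType.CONCEPTNET, Model_Type.RELATIONS) are assumptions about `utils`;
# substitute whatever your trained checkpoints and enums actually provide.
if __name__ == "__main__":
    inference = RelationsInference(
        model_path='path/to/relations_bart_checkpoint',  # hypothetical checkpoint
        kg_type=KGType.CONCEPTNET,        # assumed enum member
        model_type=Model_Type.RELATIONS,  # assumed enum member
        max_length=32,
    )
    response, _, _ = inference.generate_based_on_context('a dog chases a cat', use_kg=True)
    print(response[0])
    # Summarize which generated concepts are covered by KG relations
    inference.get_relations_information(response[0])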