#############################
# Imports
#############################

# Python modules
from typing import List

# Remote modules
import numpy as np
import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BartConfig,
    DisjunctiveConstraint,
)

# Local modules
from kgs_binding.relation_mapper_builder import RelationsMapperBuilder
from kgs_binding.kg_qa_binding_utils import load_kg_handler
from kgs_binding.kg_base_wrapper import KGBaseHandler
from kgs_binding.swow_handler import SwowHandler
from kgs_binding.conceptnet_handler import ConceptNetHandler
from data.relation_utils import clean_relations
from model_utils import create_layers_head_mask
from custom_tokenizer import BartCustomTokenizerFast
from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration
from utils import get_device, get_jump_chunks, KGType, Model_Type


class Inference:
    def __init__(self, model_path: str, max_length=32):
        self.device = get_device()
        self.tokenizer = self.prepare_tokenizer()
        self.model = self.prepare_model(model_path)
        self.max_length = max_length

    def prepare_tokenizer(self):
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        return tokenizer

    def prepare_model(self, model_path):
        config = BartConfig.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(self.device)
        model.eval()
        return model

    def pre_process_context(self, context):
        context = context.lower()
        context_tokenized = self.tokenizer(
            context,
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors="pt",
        )
        return context_tokenized

    def generate_based_on_context(self, context):
        model_input = self.pre_process_context(context)
        generated_answers_encoded = self.model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            min_length=1,
            max_length=self.max_length,
            do_sample=True,
            early_stopping=True,
            num_beams=4,
            temperature=1.0,
            top_k=None,
            top_p=None,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_attentions=True,
            output_scores=True,
        )
        response = self.tokenizer.batch_decode(
            generated_answers_encoded['sequences'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def prepare_context_for_visualization(self, context):
        examples = []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        # encoder attentions come as a per-layer tuple of (batch, heads, src, src) tensors
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        print(encoder_outputs.size())
        # put the batch dimension first: (batch, layers, heads, src, tgt)
        encoder_attentions = encoder_outputs.transpose(0, 1)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            useful_indices = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indices]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
            print(d['words'])
        return response, examples
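
# Usage sketch for the plain BART wrapper above (the checkpoint name is a
# placeholder, not part of the original module; any BART-style seq2seq
# checkpoint directory saved with save_pretrained should also work):
#
#   inference = Inference(model_path='facebook/bart-large', max_length=32)
#   response, encoder_attentions, model_input = inference.generate_based_on_context(
#       'a man is walking his dog')
#   print(response)
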
class RelationsInference:
    def __init__(self, model_path: str, kg_type: KGType,
                 model_type: Model_Type, max_length=32):
        self.device = get_device()
        kg_handler: KGBaseHandler = load_kg_handler(kg_type)
        self.kg_handler = kg_handler
        relation_names = kg_handler.get_relation_types()
        self.tokenizer = self.prepare_tokenizer(relation_names, model_type)
        self.simple_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.model, self.config = self.prepare_model(relation_names, model_path, model_type)
        self.relation_mapper_builder = RelationsMapperBuilder(knowledge=kg_handler)
        self.max_length = max_length

    def prepare_tokenizer(self, relation_names: List[str], model_type: Model_Type):
        tokenizer = BartCustomTokenizerFast.from_pretrained('facebook/bart-large')
        tokenizer.set_known_relation_names(relation_names)
        tokenizer.set_operation_mode(
            there_is_difference_between_relations=model_type.there_is_difference_between_relations()
        )
        return tokenizer

    def prepare_model(self, relation_names: List[str], model_path, model_type: Model_Type):
        config = BartCustomConfig.from_pretrained(model_path, revision='master')
        print('config.heads_mask:', config.heads_mask)
        if config.num_relation_kinds is None:
            config.num_relation_kinds = len(relation_names)
        if config.is_simple_mask_commonsense is None:
            config.is_simple_mask_commonsense = model_type.is_simple_mask_commonsense()
        if config.heads_mask is None:
            config.heads_mask = create_layers_head_mask(config)
        model = BartCustomForConditionalGeneration.from_pretrained(
            model_path, config=config, revision='master'
        ).to(self.device)
        model.eval()
        return model, config

    def pre_process_context(self, context):
        context = context.lower()
        # search the context for commonsense relations
        commonsense_relations = self.relation_mapper_builder.get_relations_mapping_complex(
            context=[context], clear_common_wds=True
        )
        # clean the relations; keep the mapping for the single input context
        commonsense_relation = clean_relations(commonsense_relations)[0]
        print(commonsense_relation)
        # the custom tokenizer converts these relations into relation matrices
        context_tokenized = self.tokenizer(
            context,
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors="pt",
            return_offsets_mapping=True,
            input_commonsense_relations=commonsense_relation,
        )
        return context_tokenized

    def get_relations_information(self, phrase_generated):
        all_concepts = self.relation_mapper_builder.get_kg_concepts_from_context(
            [phrase_generated], clear_common_wds=True
        )[0]
        words = phrase_generated.strip().split(' ')  # all words
        concepts_with_relations = self.relation_mapper_builder.get_concepts_from_context(
            phrase_generated, clear_common_wds=True
        )
        concepts_with_no_relations = list(set(all_concepts).difference(concepts_with_relations))
        print("====== RELATIONS SUMMARY ======")
        print('phrase_generated:', phrase_generated)
        print('words:', words)
        print('all_concepts:', all_concepts)
        print('concepts_with_relations:', concepts_with_relations)
        print('without_relations:', concepts_with_no_relations)
        print("\n== STATS:")
        print('n_words:', len(words))
        print('n_concepts:', len(all_concepts))
        print('n_concepts_with_relations:', len(concepts_with_relations))
        print('n_c_without_relations:', len(concepts_with_no_relations))
        print("====== ================= ======")
        return words, all_concepts, concepts_with_relations, concepts_with_no_relations

    def remove_subsets(self, l):
        # keep only token-id lists that are not strict subsets of another entry
        l2 = l[:]
        for m in l:
            for n in l:
                if set(m).issubset(set(n)) and m != n:
                    l2.remove(m)
                    break
        return l2
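    # Behaviour sketch for remove_subsets (illustrative example, not original
    # code): only maximal id-lists survive,
    #
    #   self.remove_subsets([[1], [1, 2], [3]])  # -> [[1, 2], [3]]
    #
    # so a constraint word whose ids are contained in another word's ids is dropped.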
    def generate_based_on_context(self, context, use_kg=False):
        model_input = self.pre_process_context(context)
        gen_kwargs = {}
        if "input_commonsense_relations" in model_input:
            gen_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(self.device)
        constraints = None
        if use_kg:
            constraints = []
            concepts_from_context = self.relation_mapper_builder.get_concepts_from_context(
                context=context, clear_common_wds=True
            )
            useful_concepts = [
                self.relation_mapper_builder.knowledge.get_related_concepts(concept)
                for concept in concepts_from_context
            ]
            if not useful_concepts:
                # fall back to the raw KG handler
                useful_concepts = [
                    self.kg_handler.get_related_concepts(concept)
                    for concept in concepts_from_context
                ]
            # no-op formatting; prefix spaces are handled by add_prefix_space below
            useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts]
            if concepts_from_context:
                for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
                    print('neighbour:', neighbour_concepts[:20])
                    # drop neighbours already contained in the input concept
                    flexible_words = [word for word in neighbour_concepts if word not in context_concept]
                    flexible_words_ids: List[List[int]] = self.simple_tokenizer(
                        flexible_words, add_prefix_space=True, add_special_tokens=False
                    ).input_ids
                    flexible_words_ids = self.remove_subsets(flexible_words_ids)
                    flexible_words_ids = flexible_words_ids[:10]  # cap the size of each disjunction
                    print('flexible_words_ids:', flexible_words_ids[:3])
                    constraint = DisjunctiveConstraint(flexible_words_ids)
                    constraints.append(constraint)
            else:
                constraints = None
        generated_answers_encoded = self.model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            constraints=constraints,
            min_length=1,
            max_length=self.max_length,
            do_sample=False,
            early_stopping=True,
            num_beams=8,
            temperature=1.0,
            top_k=None,
            top_p=None,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_attentions=True,
            output_scores=True,
            **gen_kwargs,
        )
        response = self.tokenizer.batch_decode(
            generated_answers_encoded['sequences'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def get_related_concepts_list(self, knowledge, list_concepts):
        other_concepts = []
        for concept in list_concepts:
            other_near_concepts = knowledge.get_related_concepts(concept)
            other_concepts.extend(other_near_concepts)
        return other_concepts
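    # What DisjunctiveConstraint provides (sketch, mirroring the call above):
    # beam search must generate at least one of the given token-id sequences
    # somewhere in the output.
    #
    #   ids = self.simple_tokenizer(['dog', 'cat'], add_prefix_space=True,
    #                               add_special_tokens=False).input_ids
    #   constraint = DisjunctiveConstraint(ids)
    #   # generate(..., constraints=[constraint]) must now mention 'dog' or 'cat'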
    def generate_contrained_based_on_context(self, contexts, use_kg=True, max_concepts=1):
        model_inputs = [self.pre_process_context(context) for context in contexts]
        if use_kg:
            concepts_from_contexts = [
                self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True)
                for context in contexts
            ]
            # left empty so the kg_handler fallback below is used
            neighbours_contexts = []
            if not neighbours_contexts:
                neighbours_contexts = [
                    self.get_related_concepts_list(self.kg_handler, context)
                    for context in concepts_from_contexts
                ]
            all_constraints = []
            for context_neighbours in neighbours_contexts:
                # context_neighbours is a collection of concepts;
                # split it into at most max_concepts sub-collections
                context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3]
                n_size_chunks = len(context_neighbours) // max_concepts
                n_size_chunks = n_size_chunks if n_size_chunks > 0 else 1
                sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chunks))
                constraints = []
                for sub_concepts in sub_concepts_collection[:max_concepts]:
                    flexible_words_ids: List[List[int]] = self.tokenizer(
                        sub_concepts, add_special_tokens=False
                    ).input_ids
                    # keep only the first sub-word id of each concept
                    flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids]
                    # deduplicate the id lists
                    disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
                    if not any(disjunctive_set):
                        continue
                    constraint = DisjunctiveConstraint(disjunctive_set)
                    constraints.append(constraint)
                if not constraints:
                    constraints = None
                all_constraints.append(constraints)
        else:
            all_constraints = None
        if not all_constraints:
            all_constraints = None
        generated_answers_encoded = []
        encoder_attentions_list = []
        # guard against all_constraints being None (e.g. use_kg=False)
        for i, constraints in enumerate(all_constraints or [None] * len(model_inputs)):
            gen_kwargs = {}
            inputs = model_inputs[i]
            if "input_commonsense_relations" in inputs:
                gen_kwargs["relation_inputs"] = inputs.get("input_commonsense_relations").to(self.device)
            gen = self.model.generate(
                input_ids=inputs["input_ids"].to(self.device),
                attention_mask=inputs["attention_mask"].to(self.device),
                constraints=constraints,  # this example's constraints
                min_length=1,
                max_length=self.max_length,
                do_sample=False,
                early_stopping=True,
                num_beams=8,
                temperature=1.0,
                top_k=None,
                top_p=None,
                no_repeat_ngram_size=2,
                num_return_sequences=1,
                return_dict_in_generate=True,
                output_attentions=True,
                output_scores=True,
                **gen_kwargs,
            )
            generated_answers_encoded.append(gen['sequences'][0].detach().cpu())
            encoder_attentions_list.append(gen['encoder_attentions'][0].detach().cpu())
        text_results = self.tokenizer.batch_decode(
            generated_answers_encoded,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return text_results, encoder_attentions_list, model_inputs
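    # Usage sketch for constrained generation (the checkpoint path and enum
    # members below are hypothetical placeholders):
    #
    #   inf = RelationsInference('path/to/checkpoint', kg_type=KGType.CONCEPTNET,
    #                            model_type=Model_Type.RELATIONS, max_length=32)
    #   texts, attns, inputs = inf.generate_contrained_based_on_context(
    #       ['a man is walking his dog'], use_kg=True, max_concepts=2)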
    def prepare_context_for_visualization(self, context):
        examples, relations = [], []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        input_commonsense_relations = model_input.get("input_commonsense_relations")
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        print(encoder_outputs.size())
        # put the batch dimension first: (batch, layers, heads, src, tgt)
        encoder_attentions = encoder_outputs.transpose(0, 1)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            useful_indices = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indices]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
            relations.append(input_commonsense_relations[i])
            print(d['words'])
        return response, examples, relations
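
# Output sketch for prepare_context_for_visualization (shapes follow from the
# code above): examples[i]['attentions'] is a numpy array of shape
# (n_layers, n_heads, src, tgt) and examples[i]['words'] holds the
# corresponding non-pad tokens.
#
#   response, examples, relations = inf.prepare_context_for_visualization(
#       'a man is walking his dog')
#   print(examples[0]['words'], examples[0]['attentions'].shape)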