import json

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

# Load pre-trained BERT QA model and tokenizer from Hugging Face model hub
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Force CPU usage by setting device to -1
device = -1

# Initialize the QA pipeline with the correct device
qa_pipeline = pipeline(
    "question-answering",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

# Load the knowledge base from JSON file.
# The encoding is explicit so non-ASCII text does not depend on the platform
# default encoding.
with open('knowledge_base.json', 'r', encoding='utf-8') as f:
    knowledge_base = json.load(f)

# Load the expanded QA dataset
with open('expanded_qa_dataset.json', 'r', encoding='utf-8') as f:
    expanded_qa_dataset = json.load(f)

# Load Sentence-BERT model for semantic search
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Reply returned whenever no knowledge base entry yields an answer
_NO_ANSWER_MESSAGE = "Lo siento, no encontré una respuesta adecuada para tu pregunta."


def _details_to_text(details):
    """Return step details as one string, joining list-valued details with spaces."""
    if isinstance(details, list):
        return ' '.join(details)
    return details


def _entry_to_text(entry):
    """Flatten one knowledge base entry into the text that gets embedded.

    The text is the title followed by, in this order: every section's
    ``text``, every step's ``details``, and every FAQ's question and answer.
    A missing ``title`` or ``content`` is treated as empty.
    """
    sections = entry.get('content', [])
    parts = [section.get('text', '') for section in sections]
    parts.extend(
        _details_to_text(step['details'])
        for section in sections if 'steps' in section
        for step in section['steps']
    )
    parts.extend(
        faq['question'] + ' ' + faq['answer']
        for section in sections if 'faq' in section
        for faq in section['faq']
    )
    return entry.get('title', '') + ' ' + ' '.join(parts)


# Function to create embeddings for the knowledge base content
def create_knowledge_base_embeddings(knowledge_base):
    """Encode every knowledge base entry into an embedding tensor.

    Exactly one embedding is produced per entry, in order, so that index
    ``i`` of the result always describes ``knowledge_base[i]``. Entries
    without a ``title`` are embedded from their content alone rather than
    skipped; skipping them would shift the index of every later entry and
    make the semantic search return the wrong entry.

    Args:
        knowledge_base: Sequence of entry dicts. Each may carry a ``title``
            and a ``content`` list whose items may hold ``text``, ``steps``
            (each with ``details`` as a string or list of strings) and
            ``faq`` (each with ``question`` and ``answer``).

    Returns:
        A list of tensors with the same length and order as
        ``knowledge_base``.
    """
    return [
        embedding_model.encode(_entry_to_text(entry), convert_to_tensor=True)
        for entry in knowledge_base
    ]


# Create knowledge base embeddings
knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)


# Function to retrieve the best context using semantic similarity
def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.5):
    """Return the best answer text for ``question`` from the knowledge base.

    The entry whose embedding has the highest cosine similarity to the
    question is selected. If its score exceeds ``threshold``, the answer is
    taken from that entry with this priority: a FAQ whose question text
    occurs inside the user's question (case-insensitive), then the steps of
    the first section that has steps, then the text of the first section
    that has text.

    Args:
        question: The user's question.
        knowledge_base: Sequence of entry dicts; must be index-aligned with
            ``knowledge_base_embeddings``.
        knowledge_base_embeddings: Sequence of embedding tensors, one per
            knowledge base entry.
        threshold: Similarity the best match must exceed to be used.

    Returns:
        The answer string, or a fixed apology message when the knowledge
        base is empty, the best score does not exceed ``threshold``, or the
        matched entry has no usable content.
    """
    # torch.stack raises on an empty sequence, so bail out early.
    if len(knowledge_base_embeddings) == 0:
        return _NO_ANSWER_MESSAGE

    # Create embedding for the question
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)

    # Calculate cosine similarity between the question and knowledge base entries
    cosine_scores = util.pytorch_cos_sim(
        question_embedding, torch.stack(list(knowledge_base_embeddings))
    )

    # Get the index of the highest score (most similar context)
    best_match_idx = torch.argmax(cosine_scores).item()
    best_match_score = cosine_scores[0, best_match_idx].item()
    if best_match_score <= threshold:
        return _NO_ANSWER_MESSAGE

    # Tolerate entries without 'content', as the embedding builder does.
    sections = knowledge_base[best_match_idx].get('content', [])

    # Check if FAQ section exists and prioritize FAQ answers
    question_lower = question.lower()
    for section in sections:
        for faq in section.get('faq', []):
            if faq['question'].lower() in question_lower:
                return faq['answer']

    # If no FAQ is found, check for steps. Details may be a string or a list
    # of strings; each step is flattened to a single line before joining.
    for section in sections:
        if 'steps' in section:
            return "\n".join(
                _details_to_text(step['details']) for step in section['steps']
            )

    # Fallback to regular text
    for section in sections:
        if 'text' in section:
            return section['text']

    return _NO_ANSWER_MESSAGE
# Check expanded QA dataset first for a direct answer
def get_answer_from_expanded_qa(question, expanded_qa_dataset):
    """Return the stored answer whose question matches ``question``.

    Matching ignores case and leading/trailing whitespace, so input such as
    ``"  How do I reset?\n"`` still hits the entry ``"how do i reset?"``.

    Args:
        question: The user's question.
        expanded_qa_dataset: Iterable of dicts with ``question`` and
            ``answer`` keys.

    Returns:
        The ``answer`` of the first matching item, or ``None`` if no item
        matches.
    """
    # Normalize once instead of on every loop iteration.
    wanted = question.strip().lower()
    for item in expanded_qa_dataset:
        if item['question'].strip().lower() == wanted:
            return item['answer']
    return None


# Answer function for the Gradio app
def answer_question(question):
    """Answer ``question`` for the Gradio interface.

    A direct match in the expanded QA dataset wins. Otherwise (including
    when the matched answer is empty) the knowledge base is searched
    semantically with a similarity threshold of 0.55.

    Args:
        question: The text entered by the user.

    Returns:
        The answer text to display.
    """
    # Check if the question matches any entry in the expanded QA dataset
    direct_answer = get_answer_from_expanded_qa(question, expanded_qa_dataset)
    if direct_answer:
        return direct_answer

    # If no direct answer found, use the knowledge base with semantic search
    return get_dynamic_context_semantic(
        question, knowledge_base, knowledge_base_embeddings, threshold=0.55
    )


# Gradio interface setup
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title="OCN Customer Support Chatbot",
    description="Ask questions and get answers from the OCN knowledge base.",
)

# Launch the Gradio interface
interface.launch()