import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import json
from fuzzywuzzy import fuzz

# Load a pre-trained BERT QA model and tokenizer from the Hugging Face model hub
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on CPU (device=-1 selects CPU in the transformers pipeline API)
device = -1

# Initialize the QA pipeline on that device
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)

# Load the knowledge base from a JSON file
with open('knowledge_base.json', 'r') as f:
    knowledge_base = json.load(f)

# Load the expanded QA dataset
with open('expanded_qa_dataset.json', 'r') as f:
    expanded_qa_dataset = json.load(f)
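
# Assumed shape of the two JSON files, inferred from the field accesses below
# (illustrative sketch only; the real files may carry extra fields):
#
# knowledge_base.json:
#   [{"title": "...",
#     "content": [{"text": "..."},
#                 {"steps": [{"details": "..."}]},     # "details" may also be a list of strings
#                 {"faq": [{"question": "...", "answer": "..."}]}]}]
#
# expanded_qa_dataset.json:
#   [{"question": "...", "answer": "..."}]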

# Load a Sentence-BERT model for semantic search
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Create one embedding per knowledge base entry (title + text + steps + FAQ).
# Every entry gets an embedding so indices stay aligned with knowledge_base.
def create_knowledge_base_embeddings(knowledge_base):
    embeddings = []
    for entry in knowledge_base:
        content = entry.get('title', '') + ' ' + ' '.join(
            [c.get('text', '') for c in entry.get('content', [])] +
            [
                ' '.join(step['details']) if isinstance(step['details'], list) else step['details']
                for c in entry.get('content', []) if 'steps' in c
                for step in c['steps']
            ] +
            [
                faq['question'] + ' ' + faq['answer']
                for c in entry.get('content', []) if 'faq' in c
                for faq in c['faq']
            ]
        )
        embeddings.append(embedding_model.encode(content, convert_to_tensor=True))
    return embeddings

# Create the knowledge base embeddings once at startup
knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)
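
# Note: get_dynamic_context_semantic() below re-stacks this list on every query;
# for a larger knowledge base it may be worth stacking once here instead, e.g.
#   knowledge_base_embeddings_tensor = torch.stack(knowledge_base_embeddings)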

# Retrieve the best-matching context for a question using semantic similarity
def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.5):
    # Embed the question
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)

    # Cosine similarity between the question and every knowledge base entry
    cosine_scores = util.pytorch_cos_sim(question_embedding, torch.stack(knowledge_base_embeddings))

    # Index of the highest score (most similar entry)
    best_match_idx = torch.argmax(cosine_scores).item()
    best_match_score = cosine_scores[0, best_match_idx].item()

    if best_match_score > threshold:
        best_match_entry = knowledge_base[best_match_idx]

        # Prioritize FAQ answers when an FAQ section exists
        for content_item in best_match_entry.get('content', []):
            if 'faq' in content_item:
                for faq in content_item['faq']:
                    if fuzz.token_sort_ratio(faq['question'].lower(), question.lower()) > 80:
                        return faq['answer']

        # Otherwise return step-by-step instructions, flattening list-valued details
        for content_item in best_match_entry.get('content', []):
            if 'steps' in content_item:
                step_details = [
                    ' '.join(step['details']) if isinstance(step['details'], list) else step['details']
                    for step in content_item['steps']
                ]
                return "\n".join(step_details)

        # Fall back to plain text content
        for content_item in best_match_entry.get('content', []):
            if 'text' in content_item:
                return content_item['text']

    # Fallback message ("Sorry, I couldn't find a suitable answer to your question.")
    return "Lo siento, no encontré una respuesta adecuada a tu pregunta."

# Return the first expanded-dataset answer whose question fuzzy-matches above the threshold
def get_answer_from_expanded_qa(question, expanded_qa_dataset, threshold=80):
    for item in expanded_qa_dataset:
        if fuzz.token_sort_ratio(item['question'].lower(), question.lower()) > threshold:
            return item['answer']
    return None
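
# token_sort_ratio is word-order insensitive and returns an integer 0-100, e.g.
#   fuzz.token_sort_ratio("password reset", "reset password")  # -> 100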

# Answer function for the Gradio app
def answer_question(question):
    # First, check the expanded QA dataset for a direct match
    direct_answer = get_answer_from_expanded_qa(question, expanded_qa_dataset)
    if direct_answer:
        return direct_answer

    # Otherwise, fall back to semantic search over the knowledge base
    context = get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.45)
    return context
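
# Note: qa_pipeline is loaded above but answer_question() returns the retrieved
# context verbatim. A minimal sketch of also running extractive QA over that
# context (an optional variant, not wired into the interface below):
def answer_question_extractive(question):
    context = get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.45)
    result = qa_pipeline(question=question, context=context)
    return result["answer"]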

# Gradio interface setup
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title="OCN Customer Support Chatbot",
    description="Ask questions and get answers from the OCN knowledge base."
)

# Launch the Gradio interface (share=True also creates a temporary public URL)
interface.launch(share=True)