# OCN_CSChatbot / app.py
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import json
from fuzzywuzzy import fuzz
# Load pre-trained BERT QA model and tokenizer from Hugging Face model hub
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Run on CPU: device=-1 tells the transformers pipeline not to use a GPU
device = -1

# Initialize the extractive QA pipeline on the chosen device
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)
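# Note: qa_pipeline is initialized but never called in the answer flow below,
# which relies on fuzzy matching and semantic search instead. If extractive QA
# were wired in, a minimal hedged sketch would look like:
#   result = qa_pipeline(question="Example question?", context="Some retrieved context.")
#   result["answer"]  # extracted span; result also carries "score", "start", "end"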
# Load the knowledge base from JSON file
with open('knowledge_base.json', 'r') as f:
    knowledge_base = json.load(f)
# Load the expanded QA dataset
with open('expanded_qa_dataset.json', 'r') as f:
    expanded_qa_dataset = json.load(f)
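# For reference, a minimal sketch of the shapes this script expects, inferred
# from the field accesses below (the actual JSON contents are not shown here;
# all field values are hypothetical):
#
# knowledge_base.json:
#   [
#     {
#       "title": "Example topic",
#       "content": [
#         {"text": "Free-form description ..."},
#         {"steps": [{"details": "One step"}, {"details": ["Or", "a", "list"]}]},
#         {"faq": [{"question": "Example question?", "answer": "Example answer."}]}
#       ]
#     }
#   ]
#
# expanded_qa_dataset.json:
#   [
#     {"question": "Example question?", "answer": "Example answer."}
#   ]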
# Load Sentence-BERT model for semantic search
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Function to create embeddings for the knowledge base content
def create_knowledge_base_embeddings(knowledge_base):
    """Encode one embedding per knowledge-base entry. The order must match the
    knowledge_base list, since retrieval below indexes it by embedding position."""
    embeddings = []
    for entry in knowledge_base:
        # Use get() so entries without a title still produce an embedding and
        # the embedding index stays aligned with the knowledge_base index
        content = entry.get('title', '') + ' ' + ' '.join(
            [c.get('text', '') for c in entry.get('content', [])] +
            [
                ' '.join(step['details']) if isinstance(step['details'], list) else step['details']
                for c in entry.get('content', []) if 'steps' in c
                for step in c['steps']
            ] +
            [
                faq['question'] + ' ' + faq['answer']
                for c in entry.get('content', []) if 'faq' in c
                for faq in c['faq']
            ]
        )
        embeddings.append(embedding_model.encode(content, convert_to_tensor=True))
    return embeddings
# Create knowledge base embeddings
knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)
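# A quick shape sanity check (hedged sketch; all-MiniLM-L6-v2 produces
# 384-dimensional embeddings, so the stacked tensor used during retrieval
# has shape (len(knowledge_base), 384)):
#   q = embedding_model.encode("example question", convert_to_tensor=True)   # (384,)
#   scores = util.pytorch_cos_sim(q, torch.stack(knowledge_base_embeddings)) # (1, N)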
# Function to retrieve the best context using semantic similarity
def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.5):
    # Create an embedding for the question
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)

    # Cosine similarity between the question and every knowledge-base entry
    cosine_scores = util.pytorch_cos_sim(question_embedding, torch.stack(knowledge_base_embeddings))

    # Index of the highest score (most similar entry)
    best_match_idx = torch.argmax(cosine_scores).item()
    best_match_score = cosine_scores[0, best_match_idx].item()

    if best_match_score > threshold:  # Only answer when similarity clears the threshold
        best_match_entry = knowledge_base[best_match_idx]

        # Prioritize FAQ answers when the matched entry has an FAQ section
        for content_item in best_match_entry['content']:
            if 'faq' in content_item:
                for faq in content_item['faq']:
                    if fuzz.token_sort_ratio(faq['question'].lower(), question.lower()) > 80:
                        return faq['answer']

        # If no FAQ matched, fall back to step-by-step instructions
        for content_item in best_match_entry['content']:
            if 'steps' in content_item:
                # 'details' may be a string or a list of strings (see the
                # embedding code above), so normalize before joining
                step_details = [
                    ' '.join(step['details']) if isinstance(step['details'], list) else step['details']
                    for step in content_item['steps']
                ]
                return "\n".join(step_details)

        # Finally, fall back to plain text content
        for content_item in best_match_entry['content']:
            if 'text' in content_item:
                return content_item['text']

    # Spanish fallback: "Sorry, I couldn't find a suitable answer to your question."
    return "Lo siento, no encontré una respuesta adecuada a tu pregunta."
# Use fuzzy matching to find the closest match in the expanded QA dataset
def get_answer_from_expanded_qa(question, expanded_qa_dataset, threshold=80):
    for item in expanded_qa_dataset:
        # Use fuzzy matching to find close matches
        if fuzz.token_sort_ratio(item['question'].lower(), question.lower()) > threshold:
            return item['answer']
    return None
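# Hedged example of how the matcher behaves (token_sort_ratio sorts tokens
# before comparing, so word order does not matter):
#   fuzz.token_sort_ratio("reset my password", "password reset my")  # -> 100, clears 80
#   fuzz.token_sort_ratio("reset my password", "change my plan")     # low score, no match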
# Answer function for the Gradio app
def answer_question(question):
    # First, check whether the question matches an entry in the expanded QA dataset
    direct_answer = get_answer_from_expanded_qa(question, expanded_qa_dataset)
    if direct_answer:
        return direct_answer

    # Otherwise, fall back to semantic search over the knowledge base
    return get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.45)
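# Example of the two-stage flow (hypothetical questions):
#   answer_question("Example question?")    # fuzzy hit in expanded_qa_dataset -> its answer
#   answer_question("something unrelated")  # no fuzzy hit -> semantic search, possibly the fallback text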
# Gradio interface setup
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title="OCN Customer Support Chatbot",
    description="Ask questions and get answers from the OCN knowledge base.",
)
# Launch the Gradio interface. share=True creates a temporary public link when
# run locally; on Hugging Face Spaces the app is already served publicly.
interface.launch(share=True)