OCN_CSChatbot / app.py
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import json
import logging
# Setup logging
logging.basicConfig(filename='chatbot_logs.log', level=logging.INFO)
# Load pre-trained BERT QA model and tokenizer from Hugging Face model hub
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Force CPU usage: a device index of -1 tells the transformers pipeline not to use a GPU
device = -1
# Initialize the QA pipeline with the correct device
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)
# Load the knowledge base from JSON file
with open('knowledge_base.json', 'r') as f:
    knowledge_base = json.load(f)
# Load the expanded QA dataset
with open('expanded_qa_dataset.json', 'r') as f:
    expanded_qa_dataset = json.load(f)
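# Expected file layouts (inferred from how the data is read below; the JSON files
# themselves are not shown here, so treat these shapes as assumptions):
#
# knowledge_base.json: a list of entries such as
#   {"title": "...",
#    "content": [{"text": "..."},
#                {"steps": [{"details": "..." or ["...", "..."]}]},
#                {"faq": [{"question": "...", "answer": "..."}]}]}
#
# expanded_qa_dataset.json: a list of direct question/answer pairs such as
#   {"question": "...", "answer": "..."}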
# Load Sentence-BERT model for semantic search
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Function to create embeddings for the knowledge base content
def create_knowledge_base_embeddings(knowledge_base):
    embeddings = []
    for entry in knowledge_base:
        if 'title' in entry:
            content = entry['title'] + ' ' + ' '.join(
                [c.get('text', '') for c in entry.get('content', [])] +
                [
                    ' '.join(step['details']) if isinstance(step['details'], list) else step['details']
                    for c in entry.get('content', []) if 'steps' in c
                    for step in c['steps']
                ] +
                [
                    faq['question'] + ' ' + faq['answer']
                    for c in entry.get('content', []) if 'faq' in c
                    for faq in c['faq']
                ]
            )
            embeddings.append(embedding_model.encode(content, convert_to_tensor=True))
    return embeddings
# Create knowledge base embeddings
knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)
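# Note: create_knowledge_base_embeddings skips entries without a 'title', so the
# embedding list only lines up index-for-index with knowledge_base when every entry
# has one. The lookup below (knowledge_base[best_match_idx]) relies on that assumption.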
# Function to retrieve the best context using semantic similarity with dynamic thresholds
def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.55):
    # Create embedding for the question
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)

    # Calculate cosine similarity between the question and knowledge base entries
    cosine_scores = util.pytorch_cos_sim(question_embedding, torch.stack(knowledge_base_embeddings))

    # Get the index of the highest score (most similar context)
    best_match_idx = torch.argmax(cosine_scores).item()
    best_match_score = cosine_scores[0, best_match_idx].item()

    logging.info(f"Question: {question} - Best match score: {best_match_score}")

    # Log if the similarity score is too low
    if best_match_score < threshold:
        logging.warning(f"Low similarity score ({best_match_score}) for question: {question}")
        return "Lo siento, no encontré una respuesta adecuada para tu pregunta."

    best_match_entry = knowledge_base[best_match_idx]

    # Check if an FAQ section exists and prioritize FAQ answers
    for content_item in best_match_entry['content']:
        if 'faq' in content_item:
            for faq in content_item['faq']:
                if faq['question'].lower() in question.lower():
                    return faq['answer']

    # If no FAQ is found, check for steps
    for content_item in best_match_entry['content']:
        if 'steps' in content_item:
            # 'details' may be a string or a list of strings (see create_knowledge_base_embeddings),
            # so normalize lists before joining to avoid a TypeError
            step_details = [
                ' '.join(step['details']) if isinstance(step['details'], list) else step['details']
                for step in content_item['steps']
            ]
            return "\n".join(step_details)

    # Fallback to regular text
    for content_item in best_match_entry['content']:
        if 'text' in content_item:
            return content_item['text']

    return "Lo siento, no encontré una respuesta adecuada a tu pregunta."
# Check expanded QA dataset first for a direct answer
def get_answer_from_expanded_qa(question, expanded_qa_dataset):
    for item in expanded_qa_dataset:
        if item['question'].lower() in question.lower():
            logging.info(f"Direct match found in expanded QA dataset for question: {question}")
            return item['answer']
    return None
# Collect user feedback for improving the model (Placeholder for future enhancement)
def collect_user_feedback(question, user_answer, correct_answer, feedback):
    # Placeholder: save feedback to a file or database
    with open('user_feedback.log', 'a') as feedback_log:
        feedback_log.write(f"Question: {question}\n")
        feedback_log.write(f"User Answer: {user_answer}\n")
        feedback_log.write(f"Correct Answer: {correct_answer}\n")
        feedback_log.write(f"Feedback: {feedback}\n\n")
    logging.info(f"Feedback collected for question: {question}")
# Answer function for the Gradio app
def answer_question(question):
    # Check if the question matches any entry in the expanded QA dataset
    direct_answer = get_answer_from_expanded_qa(question, expanded_qa_dataset)
    if direct_answer:
        return direct_answer

    # If no direct answer found, use the knowledge base with semantic search
    context = get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.55)
    return context
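# Optional sketch (not wired into the Gradio interfaces below): the BERT qa_pipeline
# loaded above is never called in answer_question. One way to use it would be to
# extract a short answer span from the retrieved context instead of returning the
# whole passage. The 0.3 score cutoff here is an illustrative assumption, not a tuned value.
def answer_question_extractive(question):
    context = get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.55)
    # If semantic search already failed, pass its fallback message straight through
    if context.startswith("Lo siento"):
        return context
    result = qa_pipeline(question=question, context=context)
    # Fall back to the full retrieved passage when the extracted span looks unreliable
    return result['answer'] if result['score'] > 0.3 else context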
# Gradio interface setup
def feedback_interface(question, user_answer, correct_answer, feedback):
    collect_user_feedback(question, user_answer, correct_answer, feedback)
    return "Thank you for your feedback!"
# Gradio interface setup for feedback collection
feedback_gr = gr.Interface(
    fn=feedback_interface,
    inputs=["text", "text", "text", "text"],
    outputs="text",
    title="Feedback Collection",
    description="Submit feedback on the chatbot responses."
)
# Main interface
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title="OCN Customer Support Chatbot",
    description="Ask questions and get answers from the OCN knowledge base."
)
# Launch the Gradio interface
# Note: launch() blocks the main thread by default, so the feedback interface below
# only starts after this server is shut down. Use something like gr.TabbedInterface
# if both need to be available at the same time.
interface.launch()

# Launch the feedback interface separately
feedback_gr.launch(share=True)