Upload app.py

app.py CHANGED

@@ -3,6 +3,10 @@ from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
 from sentence_transformers import SentenceTransformer, util
 import gradio as gr
 import json
+import logging
+
+# Setup logging
+logging.basicConfig(filename='chatbot_logs.log', level=logging.INFO)
 
 # Load pre-trained BERT QA model and tokenizer from Hugging Face model hub
 model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
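
A note on the logging setup added above: with only filename and level arguments, logging.basicConfig uses the root logger's default record format, %(levelname)s:%(name)s:%(message)s, which carries no timestamps. If the chatbot logs are meant to be reviewed later, a small variation (a sketch, not part of this commit) adds them:

import logging

# %(asctime)s is filled in by the logging module,
# e.g. "2024-05-01 12:00:00,123 INFO Question: ..."
logging.basicConfig(
    filename='chatbot_logs.log',
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
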
@@ -50,31 +54,8 @@ def create_knowledge_base_embeddings(knowledge_base):
 # Create knowledge base embeddings
 knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)
 
-# Function to create embeddings for the expanded QA dataset
-def create_expanded_qa_embeddings(expanded_qa_dataset):
-    qa_embeddings = []
-    for item in expanded_qa_dataset:
-        qa_embeddings.append({
-            "question": item["question"],
-            "answer": item["answer"],
-            "embedding": embedding_model.encode(item["question"], convert_to_tensor=True)
-        })
-    return qa_embeddings
-
-# Create expanded QA dataset embeddings
-expanded_qa_embeddings = create_expanded_qa_embeddings(expanded_qa_dataset)
-
-# Dynamic threshold adjustment based on query length
-def adjust_threshold_based_on_query_length(question_length):
-    if question_length <= 5:  # Short question, use higher threshold
-        return 0.7
-    elif 5 < question_length <= 10:  # Medium-length question
-        return 0.6
-    else:  # Longer question, use lower threshold
-        return 0.5
-
-# Function to retrieve the best context using semantic similarity (Knowledge Base)
-def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings):
+# Function to retrieve the best context using semantic similarity with dynamic thresholds
+def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.55):
     # Create embedding for the question
     question_embedding = embedding_model.encode(question, convert_to_tensor=True)
 
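
The hunks above call embedding_model and create_knowledge_base_embeddings without showing their definitions, which live in an unchanged part of app.py. For orientation, here is a minimal sketch consistent with how they are used; the model name and the 'title' field are assumptions, while the convert_to_tensor usage and the (1, N) score shape are implied by the diff. Also worth noting: the new comment still says "dynamic thresholds", but this change removes the dynamic adjustment function in favor of a fixed threshold=0.55 default.

from sentence_transformers import SentenceTransformer, util

# Assumed model; any sentence-transformers checkpoint fits the calls in the diff.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

def create_knowledge_base_embeddings(knowledge_base):
    # Encode one text per entry; 'title' is a guess at which field is embedded.
    texts = [entry['title'] for entry in knowledge_base]
    return embedding_model.encode(texts, convert_to_tensor=True)

# Later, util.pytorch_cos_sim(question_embedding, knowledge_base_embeddings)
# yields a (1, N) tensor, matching cosine_scores[0, best_match_idx] below.
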
@@ -85,64 +66,79 @@ def get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embedd
     best_match_idx = torch.argmax(cosine_scores).item()
     best_match_score = cosine_scores[0, best_match_idx].item()
 
-    dynamic_threshold = adjust_threshold_based_on_query_length(len(question.split()))
-
-    if best_match_score > dynamic_threshold:  # Use dynamic threshold
-        best_match_entry = knowledge_base[best_match_idx]
-
-        # Check if FAQ section exists and prioritize FAQ answers
-        for content_item in best_match_entry['content']:
-            if 'faq' in content_item:
-                for faq in content_item['faq']:
-                    if faq['question'].lower() in question.lower():
-                        return faq['answer']
-
-        # If no FAQ is found, check for steps
-        for content_item in best_match_entry['content']:
-            if 'steps' in content_item:
-                step_details = [step['details'] for step in content_item['steps']]
-                return "\n".join(step_details)
-
-        # Fallback to regular text
-        for content_item in best_match_entry['content']:
-            if 'text' in content_item:
-                return content_item['text']
-
-    return "Lo siento, no encontré una respuesta adecuada para tu pregunta."
+    logging.info(f"Question: {question} - Best match score: {best_match_score}")
+
+    # Log if the similarity score is too low
+    if best_match_score < threshold:
+        logging.warning(f"Low similarity score ({best_match_score}) for question: {question}")
+        return "Lo siento, no encontré una respuesta adecuada para tu pregunta."
+
+    best_match_entry = knowledge_base[best_match_idx]
+
+    # Check if FAQ section exists and prioritize FAQ answers
+    for content_item in best_match_entry['content']:
+        if 'faq' in content_item:
+            for faq in content_item['faq']:
+                if faq['question'].lower() in question.lower():
+                    return faq['answer']
+
+    # If no FAQ is found, check for steps
+    for content_item in best_match_entry['content']:
+        if 'steps' in content_item:
+            step_details = [step['details'] for step in content_item['steps']]
+            return "\n".join(step_details)
+
+    # Fallback to regular text
+    for content_item in best_match_entry['content']:
+        if 'text' in content_item:
+            return content_item['text']
+
+    return "Lo siento, no encontré una respuesta adecuada a tu pregunta."
 
-    qa_cosine_scores = [util.pytorch_cos_sim(question_embedding, item["embedding"]) for item in expanded_qa_embeddings]
+# Check expanded QA dataset first for a direct answer
+def get_answer_from_expanded_qa(question, expanded_qa_dataset):
+    for item in expanded_qa_dataset:
+        if item['question'].lower() in question.lower():
+            logging.info(f"Direct match found in expanded QA dataset for question: {question}")
+            return item['answer']
     return None
 
+# Collect user feedback for improving the model (Placeholder for future enhancement)
+def collect_user_feedback(question, user_answer, correct_answer, feedback):
+    # Placeholder: Save feedback to a file or database
+    with open('user_feedback.log', 'a') as feedback_log:
+        feedback_log.write(f"Question: {question}\n")
+        feedback_log.write(f"User Answer: {user_answer}\n")
+        feedback_log.write(f"Correct Answer: {correct_answer}\n")
+        feedback_log.write(f"Feedback: {feedback}\n\n")
+    logging.info(f"Feedback collected for question: {question}")
+
+# Answer function for the Gradio app
 def answer_question(question):
-    # Check
-    direct_answer =
+    # Check if the question matches any entry in the expanded QA dataset
+    direct_answer = get_answer_from_expanded_qa(question, expanded_qa_dataset)
     if direct_answer:
         return direct_answer
 
     # If no direct answer found, use the knowledge base with semantic search
-    context = get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings)
+    context = get_dynamic_context_semantic(question, knowledge_base, knowledge_base_embeddings, threshold=0.55)
    return context
 
 # Gradio interface setup
+def feedback_interface(question, user_answer, correct_answer, feedback):
+    collect_user_feedback(question, user_answer, correct_answer, feedback)
+    return "Thank you for your feedback!"
+
+# Gradio interface setup for feedback collection
+feedback_gr = gr.Interface(
+    fn=feedback_interface,
+    inputs=["text", "text", "text", "text"],
+    outputs="text",
+    title="Feedback Collection",
+    description="Submit feedback on the chatbot responses."
+)
+
+# Main interface
 interface = gr.Interface(
     fn=answer_question,
     inputs="text",
@@ -153,3 +149,6 @@ interface = gr.Interface(
 
 # Launch the Gradio interface
 interface.launch()
+
+# Launch the feedback interface separately
+feedback_gr.launch(share=True)
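
For reference, the access patterns in get_dynamic_context_semantic imply a knowledge-base shape roughly like the sketch below. Only the keys ('content', 'faq', 'question', 'answer', 'steps', 'details', 'text') come from the code; all values are invented placeholders, and 'title' is the assumed embedded field from the earlier sketch.

knowledge_base = [
    {
        # 'title' is the assumed field embedded by create_knowledge_base_embeddings
        "title": "Creación de cuenta",
        "content": [
            {"faq": [
                {"question": "¿Cómo creo una cuenta?", "answer": "Ve a Registro y completa el formulario."}
            ]},
            {"steps": [
                {"details": "Paso 1: Abre la página de inicio."},
                {"details": "Paso 2: Haz clic en 'Crear cuenta'."}
            ]},
            {"text": "Texto general que se devuelve si no hay FAQ ni pasos."}
        ]
    }
]
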
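One caveat on get_answer_from_expanded_qa: item['question'].lower() in question.lower() is a literal substring test, so a dataset entry only fires when its exact wording appears inside the user's message. A runnable illustration with hypothetical data:

expanded_qa_dataset = [
    {"question": "restablecer mi contraseña", "answer": "Usa el enlace 'Olvidé mi contraseña'."}
]

def get_answer_from_expanded_qa(question, expanded_qa_dataset):
    # Same matching rule as the diff, minus the logging call
    for item in expanded_qa_dataset:
        if item['question'].lower() in question.lower():
            return item['answer']
    return None

print(get_answer_from_expanded_qa("¿Cómo puedo restablecer mi contraseña?", expanded_qa_dataset))
# -> "Usa el enlace 'Olvidé mi contraseña'."
print(get_answer_from_expanded_qa("No puedo entrar a mi cuenta", expanded_qa_dataset))
# -> None: rephrasings fall through to the semantic knowledge-base search
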
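Because user_feedback.log stores each record as four free-form lines, parsing it back later is brittle. A small variation (a sketch, not what this commit does) keeps the same collect_user_feedback signature but writes one JSON object per line, which json.loads can read back directly:

import json
import logging

def collect_user_feedback(question, user_answer, correct_answer, feedback):
    # One JSON object per line ("JSON Lines"); ensure_ascii=False keeps
    # Spanish text readable in the file.
    record = {
        "question": question,
        "user_answer": user_answer,
        "correct_answer": correct_answer,
        "feedback": feedback,
    }
    with open('user_feedback.jsonl', 'a', encoding='utf-8') as feedback_log:
        feedback_log.write(json.dumps(record, ensure_ascii=False) + "\n")
    logging.info(f"Feedback collected for question: {question}")
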
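A behavioral note on the two launch calls at the end of the diff: in a script, interface.launch() blocks until that server is shut down, so feedback_gr.launch(share=True) is only reached afterwards, and a Space serves a single app in any case. If both UIs should be available at once, one option (a sketch; the tab names are invented) is to serve them as tabs of one app:

import gradio as gr

# Assumes `interface` and `feedback_gr` are the two gr.Interface objects
# defined above; a single launch() then serves both as tabs.
app = gr.TabbedInterface([interface, feedback_gr], tab_names=["Chatbot", "Feedback"])
app.launch()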