# OCN_CSChatbot / app.py
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import json
# Load the lightweight BERT-based QA model optimized for CPU
model_name = "distilbert-base-uncased-distilled-squad" # Efficient for CPU
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Initialize pipeline for CPU usage
device = -1 # Force CPU
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)
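
# Illustrative use of the extractive QA pipeline (not called by the retrieval
# code below, which relies on Sentence-BERT similarity instead), e.g.:
#   qa_pipeline(question="How do I reset my password?", context=some_passage)
# returns a dict with 'answer', 'score', 'start' and 'end'.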
# Load Sentence-BERT for semantic search
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Load knowledge base and expanded QA dataset
with open('knowledge_base.json', 'r') as f:
    knowledge_base = json.load(f)

with open('expanded_qa_dataset.json', 'r') as f:
    expanded_qa_dataset = json.load(f)
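
# Assumed data shapes, inferred from the lookups below:
#   expanded_qa_dataset: [{"question": str, "answer": str}, ...]
#   knowledge_base: [{"title": str,
#                     "content": [{"text": str}
#                                 | {"steps": [{"details": [str, ...]}, ...]}
#                                 | {"faq": [{"question": str, "answer": str}, ...]}]}, ...]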
# Function to create embeddings for the expanded QA dataset
def create_qa_dataset_embeddings(expanded_qa_dataset):
    qa_embeddings = []
    questions = []
    for item in expanded_qa_dataset:
        questions.append(item['question'])
        qa_embeddings.append(embedding_model.encode(item['question'], convert_to_tensor=True))
    return qa_embeddings, questions
# Create QA dataset embeddings
qa_embeddings, qa_questions = create_qa_dataset_embeddings(expanded_qa_dataset)
# Function to create embeddings for the knowledge base content
def create_knowledge_base_embeddings(knowledge_base):
    # Embed every entry (even one without a 'title') so the list indices
    # stay aligned with knowledge_base for the argmax lookup below.
    embeddings = []
    for entry in knowledge_base:
        content = entry.get('title', '') + ' ' + ' '.join(
            [c.get('text', '') for c in entry.get('content', [])] +
            [' '.join(step['details']) for c in entry.get('content', []) if 'steps' in c for step in c['steps']] +
            [faq['question'] + ' ' + faq['answer'] for c in entry.get('content', []) if 'faq' in c for faq in c['faq']]
        )
        embeddings.append(embedding_model.encode(content, convert_to_tensor=True))
    return embeddings
# Create knowledge base embeddings
knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)
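
# Note: the search functions below rebuild the embedding matrix with torch.stack
# on every query. For a larger knowledge base it may be worth stacking once
# here instead, e.g. (sketch only):
#   kb_embedding_matrix = torch.stack(knowledge_base_embeddings)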
# Semantic search on expanded QA dataset
def search_expanded_qa(question):
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(question_embedding, torch.stack(qa_embeddings))
    best_match_idx = torch.argmax(cosine_scores).item()
    best_match_score = cosine_scores[0, best_match_idx].item()
    return expanded_qa_dataset[best_match_idx]['answer'], best_match_score
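
# Both search_* functions return (answer_text, cosine_similarity_score); since the
# scores come from the same embedding model, they can be compared directly in
# answer_question below.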
# Semantic search on knowledge base
def search_knowledge_base(question):
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(question_embedding, torch.stack(knowledge_base_embeddings))
    best_match_idx = torch.argmax(cosine_scores).item()
    best_match_score = cosine_scores[0, best_match_idx].item()
    # Retrieve content from the best-matched knowledge base entry
    best_match_entry = knowledge_base[best_match_idx]
    for content_item in best_match_entry.get('content', []):
        if 'faq' in content_item:
            for faq in content_item['faq']:
                if faq['question'].lower() in question.lower():
                    return faq['answer'], best_match_score
        if 'steps' in content_item:
            # Each step's 'details' is a list of strings (see the embedding code above)
            step_details = [' '.join(step['details']) for step in content_item['steps']]
            return "\n".join(step_details), best_match_score
        if 'text' in content_item:
            return content_item['text'], best_match_score
    return "Lo siento, no encontré una respuesta adecuada para tu pregunta.", best_match_score
# Answer function: search both datasets and return the best match
def answer_question(question):
    # Search both sources and return the answer with the higher similarity score
    qa_answer, qa_score = search_expanded_qa(question)
    kb_answer, kb_score = search_knowledge_base(question)
    if qa_score >= kb_score:
        return qa_answer
    return kb_answer
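
# Quick sanity check (hypothetical question, for illustration only):
#   print(answer_question("How do I update my billing information?"))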
# Gradio interface
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title="OCN Customer Support Chatbot",
    description="Ask questions and get answers from the OCN knowledge base and expanded QA dataset."
)
# Launch the interface
interface.launch(share=True)