# medicalchatbot/app.py
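"""Streamlit medical chatbot.

Uses BioBERT (extractive question answering over a fixed context passage) for a
medical answer and DialoGPT for a free-form conversational reply, keeping the
full exchange in Streamlit session state.
"""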
import streamlit as st
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM
import torch
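# Note: Streamlit re-runs this script on every interaction, so the models below are
# reloaded each time; on recent Streamlit versions, wrapping the loading code in
# functions decorated with @st.cache_resource would avoid the repeated reloads.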
# Load BioBERT for medical question answering.
# Caveat: this is the base pre-trained checkpoint, so the question-answering head is
# randomly initialized; a BioBERT checkpoint fine-tuned on a QA dataset (e.g. SQuAD)
# would produce more meaningful answer spans.
medical_model_name = "dmis-lab/biobert-base-cased-v1.1"
medical_tokenizer = AutoTokenizer.from_pretrained(medical_model_name)
medical_model = AutoModelForQuestionAnswering.from_pretrained(medical_model_name)
# Load DialoGPT for conversation
conversation_model_name = "microsoft/DialoGPT-small"
conversation_tokenizer = AutoTokenizer.from_pretrained(conversation_model_name)
conversation_model = AutoModelForCausalLM.from_pretrained(conversation_model_name)
# Streamlit app layout
st.title("Medical Chatbot")
st.write("Ask your medical-related question below:")
# Conversation history tracker
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = []
# Text input
user_input = st.text_input("You:")
if st.button("Send"):
    if user_input:
        # Append user input to the conversation history
        st.session_state.conversation_history.append(f"User: {user_input}")

        # BioBERT for Medical Q&A over a fixed, hard-coded context passage
        context = """
        A headache can have many causes, ranging from stress, dehydration, or fatigue to more severe conditions like migraines, infections, or neurological problems. Common remedies include over-the-counter pain relievers, hydration, and rest.
        If headaches are persistent or severe, it may indicate an underlying condition such as tension headaches, cluster headaches, or even infections like sinusitis. If the headache is accompanied by other symptoms such as nausea, vision changes, or confusion, it is recommended to seek medical attention.
        """
        inputs = medical_tokenizer.encode_plus(
            user_input, context, add_special_tokens=True, return_tensors="pt"
        )
        input_ids = inputs["input_ids"].tolist()[0]

        # Perform Question Answering using BioBERT
        with torch.no_grad():
            outputs = medical_model(**inputs)
        answer_start_scores = outputs.start_logits
        answer_end_scores = outputs.end_logits
        answer_start = torch.argmax(answer_start_scores)
        answer_end = torch.argmax(answer_end_scores) + 1
        medical_answer = medical_tokenizer.convert_tokens_to_string(
            medical_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
        )

        # Check if a valid medical answer was found
        if not medical_answer.strip():  # If the answer is empty or only whitespace
            medical_answer = "I'm not sure about that. You may want to consult a medical professional."

        # Append medical response to the conversation history
        st.session_state.conversation_history.append(f"Bot (Medical): {medical_answer}")
        # Generate conversational response with DialoGPT, conditioned on the full history
        conversation_bot_input_ids = torch.cat(
            [conversation_tokenizer.encode(convo + conversation_tokenizer.eos_token, return_tensors="pt")
             for convo in st.session_state.conversation_history],
            dim=-1,
        )
        chat_history_ids = conversation_model.generate(
            conversation_bot_input_ids, max_length=1000, pad_token_id=conversation_tokenizer.eos_token_id
        )
        conversation_response = conversation_tokenizer.decode(
            chat_history_ids[:, conversation_bot_input_ids.shape[-1]:][0], skip_special_tokens=True
        )

        # Append conversational response to conversation history
        st.session_state.conversation_history.append(f"Bot (Conversational): {conversation_response}")

        # Display the entire conversation history
        for message in st.session_state.conversation_history:
            st.write(message)
    else:
        st.write("Please enter a medical question.")