|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM |
|
import torch |
|
|
|
|
|
# Extractive question-answering model used to pull an answer span out of a
# fixed reference passage.
# NOTE(review): "biobert-base-cased-v1.1" looks like a base (not SQuAD
# fine-tuned) checkpoint; if so, AutoModelForQuestionAnswering initialises the
# QA head randomly and the extracted spans are essentially arbitrary — verify,
# and consider a QA-fine-tuned biomedical checkpoint instead.
medical_model_name = "dmis-lab/biobert-base-cased-v1.1"


@st.cache_resource
def _load_medical_model(model_name):
    """Load and cache the QA tokenizer/model pair for *model_name*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps a single loaded instance across those reruns
    instead of rebuilding the model from disk each time.

    Returns:
        ``(tokenizer, model)`` as produced by ``AutoTokenizer`` and
        ``AutoModelForQuestionAnswering``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model


medical_tokenizer, medical_model = _load_medical_model(medical_model_name)
|
|
|
|
|
# Open-domain dialogue model used to produce a free-form conversational reply.
conversation_model_name = "microsoft/DialoGPT-small"


@st.cache_resource
def _load_conversation_model(model_name):
    """Load and cache the causal-LM tokenizer/model pair for *model_name*.

    Cached with ``st.cache_resource`` so the model is built once per server
    process rather than on every Streamlit rerun of this script.

    Returns:
        ``(tokenizer, model)`` as produced by ``AutoTokenizer`` and
        ``AutoModelForCausalLM``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model


conversation_tokenizer, conversation_model = _load_conversation_model(conversation_model_name)
|
|
|
|
|
# Static page header: title plus a one-line instruction for the user.
PAGE_TITLE = "Medical Chatbot"
PAGE_PROMPT = "Ask your medical-related question below:"

st.title(PAGE_TITLE)
st.write(PAGE_PROMPT)
|
|
|
|
|
# Create the chat transcript only when this session does not have one yet, so
# the list persists across the script reruns Streamlit performs.
_HISTORY_KEY = "conversation_history"
has_history = _HISTORY_KEY in st.session_state
if not has_history:
    st.session_state[_HISTORY_KEY] = []
|
|
|
|
|
# Fixed reference passage that the extractive QA model searches for an answer.
# NOTE(review): every question is answered from this single headache-related
# passage, so unrelated questions can only receive a span taken from this text.
_MEDICAL_CONTEXT = """
A headache can have many causes, ranging from stress, dehydration, or fatigue to more severe conditions like migraines, infections, or neurological problems. Common remedies include over-the-counter pain relievers, hydration, and rest.

If headaches are persistent or severe, it may indicate an underlying condition such as tension headaches, cluster headaches, or even infections like sinusitis. If the headache is accompanied by other symptoms such as nausea, vision changes, or confusion, it is recommended to seek medical attention.
"""

_NO_ANSWER_MESSAGE = "I'm not sure about that. You may want to consult a medical professional."

# Maximum number of tokens DialoGPT may generate for a single reply; the same
# amount is reserved in the position window when trimming the prompt.
_REPLY_TOKEN_BUDGET = 200


def _answer_from_context(question, context=_MEDICAL_CONTEXT):
    """Return the model's best answer span for *question* within *context*.

    Falls back to ``_NO_ANSWER_MESSAGE`` when the predicted span is empty,
    inverted (end before start), or consists only of special tokens.
    """
    # Truncate the question (first sequence) instead of crashing when the
    # pair exceeds the encoder's position limit.
    inputs = medical_tokenizer(
        question,
        context,
        add_special_tokens=True,
        truncation="only_first",
        max_length=medical_model.config.max_position_embeddings,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = medical_model(**inputs)

    start = int(torch.argmax(outputs.start_logits))
    end = int(torch.argmax(outputs.end_logits)) + 1
    span_ids = inputs["input_ids"][0][start:end].tolist()

    # skip_special_tokens turns a span that lands on [CLS]/[SEP] into "", so
    # the user sees the fallback message instead of a literal "[CLS]".
    answer = medical_tokenizer.decode(span_ids, skip_special_tokens=True)
    return answer if answer.strip() else _NO_ANSWER_MESSAGE


def _conversational_reply(history):
    """Generate a DialoGPT reply conditioned on the transcript *history*.

    Each entry is encoded followed by the EOS token (the turn separator this
    script uses); the oldest tokens are dropped so that prompt plus reply fit
    inside the model's position window.

    NOTE(review): the prompt includes the "User:"/"Bot (...):" labels and the
    extractive medical answers; DialoGPT may respond better to raw user/bot
    turns only — confirm the intended prompt format.
    """
    eos = conversation_tokenizer.eos_token
    prompt_ids = torch.cat(
        [conversation_tokenizer.encode(turn + eos, return_tensors="pt") for turn in history],
        dim=-1,
    )
    # Keep only the most recent tokens: an unbounded prompt would exceed the
    # position window and break generation.
    max_prompt_tokens = conversation_model.config.max_position_embeddings - _REPLY_TOKEN_BUDGET
    prompt_ids = prompt_ids[:, -max_prompt_tokens:]

    output_ids = conversation_model.generate(
        prompt_ids,
        attention_mask=torch.ones_like(prompt_ids),
        max_new_tokens=_REPLY_TOKEN_BUDGET,
        pad_token_id=conversation_tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return conversation_tokenizer.decode(output_ids[0, prompt_ids.shape[-1]:], skip_special_tokens=True)


user_input = st.text_input("You:")

if st.button("Send"):
    if user_input.strip():
        history = st.session_state.conversation_history
        history.append(f"User: {user_input}")
        history.append(f"Bot (Medical): {_answer_from_context(user_input)}")
        # The conversational model sees the transcript including the medical
        # answer just appended above.
        history.append(f"Bot (Conversational): {_conversational_reply(history)}")
    else:
        st.write("Please enter a medical question.")

# Render the full transcript on every rerun, not only right after "Send".
for message in st.session_state.conversation_history:
    st.write(message)
|
|
|
|
|
|
|
|