"""Streamlit medical chatbot.

Each submitted question is answered twice: once by an extractive
question-answering model (BioBERT) run against a fixed reference passage,
and once by a conversational model (DialoGPT) conditioned on the whole
conversation so far. Both replies are appended to the per-session history,
which is then rendered in full.
"""

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
)

# Hugging Face checkpoints used by the app.
# NOTE(review): "dmis-lab/biobert-base-cased-v1.1" looks like a base encoder
# checkpoint rather than a QA fine-tuned one; if so, the question-answering
# head is freshly initialised and span predictions are unreliable. Confirm,
# and consider a SQuAD fine-tuned BioBERT checkpoint instead.
medical_model_name = "dmis-lab/biobert-base-cased-v1.1"
conversation_model_name = "microsoft/DialoGPT-small"

# Reference passage the extractive QA model selects its answer span from.
MEDICAL_CONTEXT = (
    "A headache can have many causes, ranging from stress, dehydration, or "
    "fatigue to more severe conditions like migraines, infections, or "
    "neurological problems. Common remedies include over-the-counter pain "
    "relievers, hydration, and rest. If headaches are persistent or severe, "
    "it may indicate an underlying condition such as tension headaches, "
    "cluster headaches, or even infections like sinusitis. If the headache "
    "is accompanied by other symptoms such as nausea, vision changes, or "
    "confusion, it is recommended to seek medical attention."
)

# Shown when the QA model does not produce a usable answer span.
FALLBACK_MEDICAL_ANSWER = (
    "I'm not sure about that. You may want to consult a medical professional."
)

# Total token budget (prompt + reply) passed to ``generate`` as ``max_length``.
MAX_CONVERSATION_TOKENS = 1000
# Tokens kept free for the generated reply when the history has to be trimmed.
REPLY_TOKEN_RESERVE = 128


@st.cache_resource
def _load_medical_qa():
    """Load the BioBERT tokenizer and QA model once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained(medical_model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(medical_model_name)
    return tokenizer, model


@st.cache_resource
def _load_conversation_model():
    """Load the DialoGPT tokenizer and model once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained(conversation_model_name)
    model = AutoModelForCausalLM.from_pretrained(conversation_model_name)
    return tokenizer, model


# Streamlit re-executes this script on every interaction; the cached loaders
# keep the weights from being re-read from disk each time.
medical_tokenizer, medical_model = _load_medical_qa()
conversation_tokenizer, conversation_model = _load_conversation_model()


def _answer_medical_question(question, context):
    """Return the answer span the QA model extracts from *context*.

    Args:
        question: The user's question text.
        context: Passage the answer span is selected from.

    Returns:
        The decoded span, or ``FALLBACK_MEDICAL_ANSWER`` when the predicted
        span is empty, inverted (end before start), or made up solely of
        special tokens such as ``[CLS]``/``[SEP]``.
    """
    inputs = medical_tokenizer(
        question,
        context,
        add_special_tokens=True,
        # Keep the pair within the encoder's position limit so an overly
        # long question cannot crash the forward pass.
        truncation=True,
        max_length=medical_model.config.max_position_embeddings,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"][0].tolist()

    with torch.no_grad():
        outputs = medical_model(**inputs)

    span_start = int(torch.argmax(outputs.start_logits))
    span_end = int(torch.argmax(outputs.end_logits)) + 1  # exclusive bound

    special_tokens = set(medical_tokenizer.all_special_tokens)
    span_tokens = [
        token
        for token in medical_tokenizer.convert_ids_to_tokens(
            input_ids[span_start:span_end]
        )
        if token not in special_tokens
    ]
    answer = medical_tokenizer.convert_tokens_to_string(span_tokens).strip()
    return answer or FALLBACK_MEDICAL_ANSWER


def _generate_chat_reply(history):
    """Generate a DialoGPT reply conditioned on the conversation *history*.

    Args:
        history: Non-empty list of conversation lines, oldest first.

    Returns:
        The decoded text of the newly generated tokens only.
    """
    eos_token = conversation_tokenizer.eos_token
    prompt_ids = torch.cat(
        [
            conversation_tokenizer.encode(turn + eos_token, return_tensors="pt")
            for turn in history
        ],
        dim=-1,
    )
    # ``max_length`` counts the prompt as well, so keep only the most recent
    # tokens and leave room for a reply once the history grows long.
    prompt_ids = prompt_ids[:, -(MAX_CONVERSATION_TOKENS - REPLY_TOKEN_RESERVE):]

    generated_ids = conversation_model.generate(
        prompt_ids,
        # Pad and EOS share an id, so the mask cannot be inferred; the prompt
        # is unpadded and therefore every position is a real token.
        attention_mask=torch.ones_like(prompt_ids),
        max_length=MAX_CONVERSATION_TOKENS,
        pad_token_id=conversation_tokenizer.eos_token_id,
    )
    reply_ids = generated_ids[0, prompt_ids.shape[-1]:]
    return conversation_tokenizer.decode(reply_ids, skip_special_tokens=True)


# Streamlit app layout
st.title("Medical Chatbot")
st.write("Ask your medical-related question below:")

# Conversation history tracker (persists across reruns within a session)
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []

# Text input
user_input = st.text_input("You:")

if st.button("Send"):
    question = user_input.strip()
    if question:
        history = st.session_state.conversation_history
        history.append(f"User: {question}")

        # BioBERT for medical Q&A
        medical_answer = _answer_medical_question(question, MEDICAL_CONTEXT)
        history.append(f"Bot (Medical): {medical_answer}")

        # DialoGPT for the conversational reply
        conversation_response = _generate_chat_reply(history)
        history.append(f"Bot (Conversational): {conversation_response}")

        # Display the entire conversation history
        for message in history:
            st.write(message)
    else:
        st.write("Please enter a medical question.")