# Source page metadata (scraped header, kept as comments so the file parses):
# ziyadsuper2017 — commit e240cad "Update app.py" (verified), 7.71 kB
import asyncio
import base64
import io
import os
import queue
import uuid
import wave
from io import BytesIO

import google.generativeai as genai
import PyPDF2
import streamlit as st
import streamlit.components.v1 as components
from gtts import gTTS
from PIL import Image
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
# --- Gemini API key ---
# The key is read from the environment rather than hard-coded: anything
# committed to source control is readable by everyone with repo access.
# NOTE(review): the key that was previously committed here should be
# revoked/rotated, since it is already exposed in history.
api_key = os.environ.get("GOOGLE_API_KEY", "")
if not api_key:
    st.error("GOOGLE_API_KEY is not set. Add it to the environment (or app secrets) and reload.")
    st.stop()
genai.configure(api_key=api_key)
# Sampling parameters used for every request: temperature 0.9 and at most
# 3000 output tokens per reply.
generation_config = genai.GenerationConfig(
    max_output_tokens=3000,
    temperature=0.9,
)
# Safety settings: each harm category below is given the "BLOCK_NONE"
# threshold (see the Gemini safety-settings docs for its exact meaning).
_UNFILTERED_HARM_CATEGORIES = (
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
)
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in _UNFILTERED_HARM_CATEGORIES
]
# --- Session state defaults ---
# Each key maps to a zero-argument factory, so a value (e.g. a new uuid) is
# only produced when that key is actually missing from session state.
_SESSION_DEFAULTS = {
    'chat_history': list,
    'file_uploader_key': lambda: str(uuid.uuid4()),
    'uploaded_files': list,
    'user_input': str,
    'audio_data': lambda: None,
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# --- Streamlit UI ---
st.title("Gemini Chatbot")
st.write("Interact with the powerful Gemini 1.5 models.")

# Model picker; send_message passes the chosen name as `model_name`.
_MODEL_CHOICES = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"]
selected_model = st.selectbox("Choose a Gemini 1.5 Model:", _MODEL_CHOICES)

# When ticked, send_message also voices the model's reply with gTTS.
enable_tts = st.checkbox("Enable Text-to-Speech")
# --- Helper Functions ---
def get_file_base64(file_content, mime_type):
    """Wrap raw bytes as an inline-data part.

    Args:
        file_content: Raw bytes of the file.
        mime_type: MIME type string stored next to the data.

    Returns:
        dict with keys ``"mime_type"`` and ``"data"`` (base64-encoded text).
    """
    encoded = base64.b64encode(file_content)
    return {"mime_type": mime_type, "data": encoded.decode("ascii")}
def clear_conversation():
    """Reset chat history, pending text, uploads and recorded audio.

    ``file_uploader_key`` is the key the file-uploader widget is created
    with, so giving it a new uuid discards the current uploads.
    """
    fresh_values = (
        ('chat_history', []),
        ('file_uploader_key', str(uuid.uuid4())),
        ('user_input', ''),
        ('uploaded_files', []),
        ('audio_data', None),
    )
    for state_name, state_value in fresh_values:
        st.session_state[state_name] = state_value
def display_chat_history():
    """Render every entry in ``st.session_state['chat_history']``.

    Each entry is ``{"role": str, "parts": [part, ...]}``. A part is either
    ``{"text": ...}`` or an inline-data part ``{"mime_type": ..., "data": <base64>}``
    as built by ``get_file_base64``. All parts of an entry are rendered in
    order. The previous version showed only ``parts[0]`` and raised
    IndexError on an empty parts list.
    """
    chat_container = st.empty()
    with chat_container.container():
        for entry in st.session_state['chat_history']:
            role = entry["role"]
            for part in entry.get("parts", []):
                _render_part(role, part)


def _render_part(role, part):
    """Render a single message part: text, image, PDF text, audio or video."""
    if 'text' in part:
        st.markdown(f"**{role.title()}:** {part['text']}")
        return
    if 'data' not in part:
        return
    mime_type = part.get('mime_type', '')
    raw = base64.b64decode(part['data'])
    if mime_type.startswith('image'):
        st.image(Image.open(io.BytesIO(raw)), caption='Uploaded Image', use_column_width=True)
    elif mime_type == 'application/pdf':
        st.write("**PDF Content:**")
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(raw))
        for page in pdf_reader.pages:
            st.write(page.extract_text())
    elif mime_type.startswith('audio'):
        st.audio(io.BytesIO(raw), format=mime_type)
    elif mime_type.startswith('video'):
        st.video(io.BytesIO(raw))
# --- Send Message Function ---
def send_message(audio_data=None):
    """Send the pending text, uploads and optional audio to the selected Gemini model.

    Intended to run as a Streamlit ``on_click`` callback. It clears
    ``st.session_state.user_input``, and Streamlit only allows that before the
    text-area widget with that key is created in the current run.

    Args:
        audio_data: Optional WAV bytes, attached as an ``audio/wav`` part.

    Side effects:
        - Appends the user's parts and the model's reply to
          ``st.session_state['chat_history']``.
        - Clears the pending input and uploads, and rotates ``file_uploader_key``.
        - Reports failures with ``st.error``.

    The history is rendered by the main script only. Rendering it here as
    well made it appear twice on screen.
    """
    prompt_parts = _collect_prompt_parts(
        st.session_state.user_input,
        st.session_state.uploaded_files,
        audio_data,
    )
    if not prompt_parts:
        # Nothing to send: skip the API call instead of requesting with empty contents.
        return

    try:
        model = genai.GenerativeModel(
            model_name=selected_model,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        response = model.generate_content([{"role": "user", "parts": prompt_parts}])
        try:
            response_text = response.text
        except (ValueError, AttributeError):
            # `.text` raises (ValueError) when the reply has no usable text part,
            # e.g. when it was blocked. hasattr() would not have caught that.
            response_text = "No response text found."
        if response_text:
            st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
            if enable_tts:
                _play_tts(response_text)
    except Exception as e:
        st.error(f"An error occurred: {e}")

    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())


def _collect_prompt_parts(user_input, uploaded_files, audio_data):
    """Build prompt parts from the inputs and log each one as a user chat entry."""
    prompt_parts = []
    if user_input:
        prompt_parts.append({"text": user_input})
    for uploaded_file in uploaded_files or []:
        # getvalue() returns the whole upload regardless of stream position;
        # read() would return b"" if the file had already been consumed.
        prompt_parts.append(get_file_base64(uploaded_file.getvalue(), uploaded_file.type))
    if audio_data:
        prompt_parts.append(get_file_base64(audio_data, 'audio/wav'))
    for part in prompt_parts:
        st.session_state['chat_history'].append({"role": "user", "parts": [dict(part)]})
    return prompt_parts


def _play_tts(text):
    """Synthesize *text* with gTTS (English) and play it as MP3 audio."""
    tts_file = BytesIO()
    gTTS(text=text, lang='en').write_to_fp(tts_file)
    tts_file.seek(0)
    st.audio(tts_file, format='audio/mp3')
# --- User Input Area ---
col1, col2 = st.columns([3, 1])
with col1:
    # No `value=` argument: the "user_input" key is already seeded in session
    # state, which supplies the widget's value. Also passing a default makes
    # Streamlit warn that the value was set both ways.
    user_input = st.text_area("Enter your message:", key="user_input")
with col2:
    send_button = st.button("Send", on_click=send_message, type="primary")
# --- File Uploader ---
# The widget key is read from session state, so rotating it produces a new,
# empty uploader.
_UPLOAD_TYPES = ["png", "jpg", "jpeg", "mp4", "pdf", "mp3"]
uploaded_files = st.file_uploader(
    "Upload Files (Images, Videos, PDFs, MP3):",
    type=_UPLOAD_TYPES,
    accept_multiple_files=True,
    key=st.session_state["file_uploader_key"],
)
# --- WebRTC Audio Recording ---
RTC_CONFIGURATION = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})

# Maximum time (seconds) to wait for queued microphone frames in one script run.
_AUDIO_FRAME_TIMEOUT_S = 1.0


def _frames_to_wav(frames):
    """Pack decoded audio frames into in-memory WAV bytes.

    Expects objects with the PyAV ``AudioFrame`` interface (``to_ndarray``,
    ``format``, ``layout``, ``sample_rate``), as returned by
    ``webrtc_ctx.audio_receiver.get_frames``. Channel count, sample width and
    sample rate are taken from the first frame.

    NOTE(review): assumes an integer PCM sample format (e.g. s16). Float
    sample formats would need converting before being written as WAV.
    """
    first = frames[0]
    pcm_chunks = []
    for frame in frames:
        samples = frame.to_ndarray()
        if frame.format.is_planar:
            # PyAV returns planar audio as (channels, samples); WAV stores
            # samples interleaved, so transpose before serializing.
            samples = samples.T
        pcm_chunks.append(samples.tobytes())
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(len(first.layout.channels))
        wav_file.setsampwidth(first.format.bytes)
        wav_file.setframerate(first.sample_rate)
        wav_file.writeframes(b"".join(pcm_chunks))
    return buffer.getvalue()


async def run_webrtc():
    """Render the microphone streamer and offer to send the captured audio.

    Frames queued since the last run are drained, waiting at most
    ``_AUDIO_FRAME_TIMEOUT_S``. They are converted to WAV and stored in
    ``st.session_state.audio_data``. The function is declared ``async``
    because the caller drives it on an event loop; it awaits nothing itself.
    """
    webrtc_ctx = webrtc_streamer(
        key="audio-recorder",
        mode=WebRtcMode.SENDONLY,
        rtc_configuration=RTC_CONFIGURATION,
        audio_receiver_size=256,
        media_stream_constraints={"video": False, "audio": True},
    )
    if not webrtc_ctx.audio_receiver:
        return
    st.write("Recording audio...")
    try:
        audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=_AUDIO_FRAME_TIMEOUT_S)
    except queue.Empty:
        # No frames yet; don't block the whole script waiting on the mic.
        audio_frames = []
    if audio_frames:
        # The frames hold raw PCM samples, not a WAV file. Wrap them in a WAV
        # container so the 'audio/wav' MIME label attached downstream is accurate.
        st.session_state.audio_data = _frames_to_wav(audio_frames)
    if st.session_state.audio_data:
        # Use on_click so send_message runs at the start of the next run, before
        # the text area exists. Calling it inline here would assign
        # st.session_state.user_input after its widget was created, which
        # Streamlit rejects with StreamlitAPIException.
        st.button(
            "Send Recording",
            on_click=send_message,
            kwargs={"audio_data": st.session_state.audio_data},
        )
# --- Other Buttons ---
# Resets chat history, pending input, uploads and recorded audio via callback.
st.button(label="Clear Conversation", on_click=clear_conversation)
# --- Ensure file_uploader state ---
# Mirror the uploader's current selection into session state. send_message
# reads it from there when a button callback fires on the next run.
st.session_state["uploaded_files"] = uploaded_files
# --- JavaScript for Ctrl+Enter ---
# A <script> passed to st.markdown is never executed: script elements inserted
# as HTML markup don't run. Instead, the shortcut is injected through a
# zero-height components.html iframe that attaches a listener in the parent page.
components.html(
    """
    <script>
    const doc = window.parent.document;
    const textarea = doc.querySelector('.stTextArea textarea');
    // Guard against attaching a second listener if this iframe is re-rendered.
    if (textarea && !textarea.dataset.ctrlEnterBound) {
        textarea.dataset.ctrlEnterBound = 'true';
        textarea.addEventListener('keydown', (e) => {
            if (e.key === 'Enter' && e.ctrlKey) {
                e.preventDefault();
                const sendButton = Array.from(doc.querySelectorAll('.stButton button'))
                    .find((b) => b.innerText.trim() === 'Send');
                // Defer the click until this keydown finishes dispatching, so
                // the text area can apply its value first.
                // NOTE(review): relies on Streamlit applying text-area input on Ctrl+Enter.
                if (sendButton) {
                    setTimeout(() => sendButton.click(), 0);
                }
            }
        });
    }
    </script>
    """,
    height=0,
)
# --- Run WebRTC and display chat history ---
# run_webrtc is a coroutine function. Drive it on a private event loop and
# close that loop afterwards, so a new, never-closed loop isn't leaked on
# every script rerun.
_webrtc_loop = asyncio.new_event_loop()
try:
    _webrtc_loop.run_until_complete(run_webrtc())
finally:
    _webrtc_loop.close()
display_chat_history()