# Gemini 1.5 chatbot — Streamlit app.
# (The "Spaces: / Runtime error" banner lines were Hugging Face page residue,
# not part of the program, and have been removed.)
import asyncio
import base64
import io
import os
import uuid
from io import BytesIO

import google.generativeai as genai
import PyPDF2
import streamlit as st
from gtts import gTTS
from PIL import Image
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
# --- API configuration ---
# SECURITY FIX: the previous revision hard-coded (and thereby leaked) a live
# API key in source.  Read it from the environment instead; export
# GOOGLE_API_KEY before launching the app.
api_key = os.environ.get("GOOGLE_API_KEY", "")
if not api_key:
    st.warning("GOOGLE_API_KEY is not set; Gemini requests will fail.")
genai.configure(api_key=api_key)

# Generation parameters shared by every request: high temperature for more
# creative replies, output capped at 3000 tokens.
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000
)
# --- Safety settings ---
# The app opts out of content moderation entirely: every harm category is
# configured with threshold BLOCK_NONE.
_HARM_CATEGORIES = (
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
)
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in _HARM_CATEGORIES
]
# --- Session-state bootstrap ---
# Each key is created only on the first script run.  Factories (rather than
# literal defaults) guarantee every run gets a fresh mutable object.
_STATE_FACTORIES = {
    'chat_history': list,                        # [{"role": ..., "parts": [...]}]
    'file_uploader_key': lambda: str(uuid.uuid4()),
    'uploaded_files': list,
    'user_input': str,
    'audio_data': lambda: None,
}
for _key, _factory in _STATE_FACTORIES.items():
    if _key not in st.session_state:
        st.session_state[_key] = _factory()
# --- Streamlit UI ---
MODEL_CHOICES = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"]

st.title("Gemini Chatbot")
st.write("Interact with the powerful Gemini 1.5 models.")

# Which Gemini 1.5 variant handles the conversation.
selected_model = st.selectbox("Choose a Gemini 1.5 Model:", MODEL_CHOICES)

# When ticked, each model reply is also synthesized to speech via gTTS.
enable_tts = st.checkbox("Enable Text-to-Speech")
# --- Helper Functions ---
def get_file_base64(file_content, mime_type):
    """Wrap raw bytes as a Gemini inline-data part (base64-encoded)."""
    return {
        "mime_type": mime_type,
        "data": base64.b64encode(file_content).decode(),
    }
def clear_conversation():
    """Reset the chat: wipe history, uploads, pending input and audio, and
    rotate the uploader key so the file_uploader widget re-renders empty."""
    st.session_state['chat_history'] = []
    st.session_state['uploaded_files'] = []
    st.session_state['user_input'] = ''
    st.session_state['audio_data'] = None
    st.session_state['file_uploader_key'] = str(uuid.uuid4())
def display_chat_history():
    """Render every chat-history entry (text, image, PDF, audio or video)
    into a single placeholder container.

    History entries have the shape {"role": str, "parts": [part]} where a
    part holds either {"text": ...} or {"mime_type": ..., "data": base64}.
    """
    chat_container = st.empty()
    with chat_container.container():
        for entry in st.session_state['chat_history']:
            role = entry["role"]
            # The send path stores exactly one part per history entry.
            part = entry["parts"][0]
            if 'text' in part:
                st.markdown(f"**{role.title()}:** {part['text']}")
            elif 'data' in part:
                mime_type = part.get('mime_type', '')
                # Decode once and branch on the media type.
                payload = base64.b64decode(part['data'])
                if mime_type.startswith('image'):
                    st.image(Image.open(io.BytesIO(payload)),
                             caption='Uploaded Image', use_column_width=True)
                elif mime_type == 'application/pdf':
                    st.write("**PDF Content:**")
                    pdf_reader = PyPDF2.PdfReader(io.BytesIO(payload))
                    # Iterate pages directly instead of indexing by number.
                    for page in pdf_reader.pages:
                        st.write(page.extract_text())
                elif mime_type.startswith('audio'):
                    st.audio(io.BytesIO(payload), format=mime_type)
                elif mime_type.startswith('video'):
                    st.video(io.BytesIO(payload))
# --- Send Message Function ---
def send_message(audio_data=None):
    """on_click handler for the Send button.

    Collects the typed text, any uploaded files, and optionally raw WAV
    audio bytes into Gemini "parts", calls the selected model, appends both
    the user turn and the model reply to the chat history, then re-renders.

    Args:
        audio_data: optional raw audio bytes captured by the WebRTC recorder.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    prompt_parts = []

    if user_input:
        prompt_parts.append({"text": user_input})
        st.session_state['chat_history'].append(
            {"role": "user", "parts": [{"text": user_input}]}
        )

    if uploaded_files:
        for uploaded_file in uploaded_files:
            # Encode each file once and reuse the part; the previous version
            # read/encoded every file twice (once for the prompt, once for
            # the history).
            file_part = get_file_base64(uploaded_file.read(), uploaded_file.type)
            prompt_parts.append(file_part)
            st.session_state['chat_history'].append(
                {"role": "user", "parts": [file_part]}
            )

    if audio_data:
        audio_part = get_file_base64(audio_data, 'audio/wav')
        prompt_parts.append(audio_part)
        st.session_state['chat_history'].append(
            {"role": "user", "parts": [audio_part]}
        )

    # Guard: don't issue an empty-content API call when nothing was entered.
    if prompt_parts:
        try:
            model = genai.GenerativeModel(
                model_name=selected_model,
                generation_config=generation_config,
                safety_settings=safety_settings
            )
            response = model.generate_content([{"role": "user", "parts": prompt_parts}])
            response_text = response.text if hasattr(response, "text") else "No response text found."
            if response_text:
                st.session_state['chat_history'].append(
                    {"role": "model", "parts": [{"text": response_text}]}
                )
                if enable_tts:
                    # Synthesize the reply to in-memory MP3 and play it.
                    tts = gTTS(text=response_text, lang='en')
                    tts_file = BytesIO()
                    tts.write_to_fp(tts_file)
                    tts_file.seek(0)
                    st.audio(tts_file, format='audio/mp3')
        except Exception as e:
            # Top-level boundary: surface API/TTS failures in the UI.
            st.error(f"An error occurred: {e}")

    # Clear the composer and rotate the uploader key so widgets reset on rerun.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

    # Update the chat history display
    display_chat_history()
# --- User Input Area ---
input_col, button_col = st.columns([3, 1])
with input_col:
    # Bound to st.session_state.user_input via its key; send_message reads it.
    user_input = st.text_area(
        "Enter your message:",
        value="",
        key="user_input",
    )
with button_col:
    send_button = st.button(
        "Send",
        on_click=send_message,
        type="primary",
    )

# --- File Uploader ---
# The widget key is rotated by clear_conversation()/send_message() so the
# uploader renders empty on the next rerun.
uploaded_files = st.file_uploader(
    "Upload Files (Images, Videos, PDFs, MP3):",
    type=["png", "jpg", "jpeg", "mp4", "pdf", "mp3"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key,
)
# --- WebRTC Audio Recording ---
RTC_CONFIGURATION = RTCConfiguration(
    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
)

async def run_webrtc():
    """Render the send-only audio recorder and the "Send Recording" button.

    Captured frames are flattened into raw bytes and stashed in
    st.session_state.audio_data for send_message() to pick up.
    """
    webrtc_ctx = webrtc_streamer(
        key="audio-recorder",
        mode=WebRtcMode.SENDONLY,
        rtc_configuration=RTC_CONFIGURATION,
        audio_receiver_size=256,
        media_stream_constraints={"video": False, "audio": True},
    )
    if webrtc_ctx.audio_receiver:
        st.write("Recording audio...")
        audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=None)
        # BUG FIX: b"".join() cannot join the ndarrays returned by
        # to_ndarray(); serialize each frame to bytes first.
        st.session_state.audio_data = b"".join(
            frame.to_ndarray().tobytes() for frame in audio_frames
        )
    if st.button("Send Recording"):
        send_message(audio_data=st.session_state.audio_data)
# --- Other Buttons ---
st.button("Clear Conversation", on_click=clear_conversation)

# --- Ensure file_uploader state ---
# Mirror the uploader widget's current value into session state so the
# send_message callback can read it on the next run.
st.session_state.uploaded_files = uploaded_files
# --- JavaScript for Ctrl+Enter ---
# NOTE(review): Streamlit sanitizes <script> tags injected through
# st.markdown, so this handler may never actually run in the browser —
# verify, and consider streamlit.components.v1.html if it does not.
_CTRL_ENTER_JS = """
    <script>
    document.addEventListener('DOMContentLoaded', (event) => {
        document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && e.ctrlKey) {
                document.querySelector('.stButton > button').click();
                e.preventDefault();
            }
        });
    });
    </script>
    """

st.markdown(_CTRL_ENTER_JS, unsafe_allow_html=True)
# --- Run WebRTC and display chat history ---
# asyncio.run() creates, runs and *closes* a fresh event loop; the previous
# new_event_loop().run_until_complete() call leaked the loop it created.
asyncio.run(run_webrtc())
display_chat_history()