SafeChat / app.py
jeremierostan's picture
Update app.py
d9dee4f verified
raw
history blame
No virus
3.26 kB
import os
import asyncio
import gradio as gr
from groq import AsyncGroq
# Login credentials for the Gradio UI, read from the environment (None when unset).
username = os.environ.get('USERNAME')
password = os.environ.get('PASSWORD')

# Groq API key and the shared async client used for every completion request.
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
client = AsyncGroq(api_key=GROQ_API_KEY)

# Model identifiers: one generates chat replies, the other moderates user input.
chat_model = "llama-3.1-70b-versatile"
moderation_model = "llama-guard-3-8b"

# System message placed at the start of every chat completion request.
system_prompt = "You are a helpful, respectful assistant engaging in educational conversations with students."
async def moderate_message(message):
    """Classify a user message with the Llama Guard 3 moderation model.

    Sends ``message`` as a single user turn to ``moderation_model`` and
    inspects the verdict text the model returns.

    Args:
        message: The raw user text to classify.

    Returns:
        True only when the model's reply, stripped and lower-cased, is
        exactly ``"safe"``. Any other reply, including a missing one,
        yields False so the caller refuses the message.
    """
    response = await client.chat.completions.create(
        model=moderation_model,
        messages=[{"role": "user", "content": message}],
        temperature=0.1,
        max_tokens=10,
    )
    verdict = response.choices[0].message.content
    print(verdict)
    # The reply content can be absent (None); fail closed rather than
    # crashing with AttributeError on ``None.strip()``.
    if verdict is None:
        return False
    return verdict.strip().lower() == "safe"
async def chat_response(message, history):
    """Stream an assistant reply to ``message``, moderating it first.

    The message is checked with ``moderate_message``. If it is not judged
    safe, a single fixed refusal is yielded and nothing is sent to the chat
    model. Otherwise the system prompt, the earlier turns and the new
    message are sent to ``chat_model`` and the reply is streamed back.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_msg, assistant_msg)`` pairs;
            ``assistant_msg`` may be empty/None for an unanswered turn.

    Yields:
        The reply text accumulated so far, once per non-empty streamed
        chunk (each yielded value extends the previous one).
    """
    refusal = "I apologize, but I can't respond to that type of message. Let's keep our conversation appropriate."

    is_safe = await moderate_message(message)
    if not is_safe:
        yield refusal
        return

    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        # A turn answered with the refusal was rejected by moderation.
        # Leaving it in the context would hand the rejected text to the chat
        # model on the next turn, so the whole exchange is dropped.
        if assistant_msg == refusal:
            continue
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    stream = await client.chat.completions.create(
        model=chat_model,
        messages=messages,
        temperature=0.7,
        max_tokens=321,
        top_p=1,
        stream=True,
    )

    response_content = ""
    async for chunk in stream:
        # Skip chunks that carry no choices instead of raising IndexError.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            response_content += content
            yield response_content
async def user(message, history):
    """Queue the submitted message as a new, unanswered turn.

    Returns a pair: an empty string (used to clear the input box) and a new
    history list ending with ``[message, None]``. The list passed in is not
    modified.
    """
    pending_turn = [message, None]
    updated_history = [*history, pending_turn]
    return "", updated_history
async def bot(history):
    """Fill in the assistant side of the latest turn as the reply streams.

    Reads the user message from the last history entry, streams a reply via
    ``chat_response`` using all earlier entries as context, and after each
    update writes the partial reply into the last entry and yields the whole
    history.
    """
    latest_turn = history[-1]
    earlier_turns = history[:-1]
    prompt = latest_turn[0]
    async for partial_reply in chat_response(prompt, earlier_turns):
        latest_turn[1] = partial_reply
        yield history
with gr.Blocks() as demo:
    # Page header: logo, title and a short description.
    gr.Image("llama_pic.png", label="Logo", height=200)
    gr.Markdown("# SafeChat 🛡️")
    gr.Markdown(
        "This chat interface uses Llama Guard 3 to moderate messages and ensure safe interactions. It is an adaptation of a code by Martin Bowling"
    )

    # Conversation view, input box and action buttons.
    chatbot = gr.Chatbot(height=400, show_label=False)
    message = gr.Textbox(
        placeholder="Type your message here...",
        label="User Input",
        show_label=False,
        container=False,
    )
    submit = gr.Button("Send", variant="primary")
    clear = gr.Button("Clear")

    # The Send button and pressing Enter in the textbox share the same
    # two-step flow: record the user turn, then stream the bot reply.
    for _trigger in (submit.click, message.submit):
        _trigger(
            fn=user,
            inputs=[message, chatbot],
            outputs=[message, chatbot],
            queue=False,
        ).then(fn=bot, inputs=chatbot, outputs=chatbot)

    # Clear resets the conversation view.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)
if __name__ == "__main__":
    # Both credentials come from os.getenv, which returns None when the
    # variable is unset. Launching with auth=(None, None) would put the app
    # behind a login with no usable credentials, so stop with a clear
    # message instead.
    if not (username and password):
        raise SystemExit(
            "USERNAME and PASSWORD environment variables must be set to launch SafeChat."
        )
    demo.queue()
    demo.launch(auth=(username, password))