import os
import gradio as gr
from groq import AsyncGroq
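
# Basic-auth credentials and the Groq API key come from environment variables
# (configured as Space secrets).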
username = os.getenv('USERNAME')
password = os.getenv('PASSWORD')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

client = AsyncGroq(api_key=GROQ_API_KEY)
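
# Two models: Llama 3.1 70B for the chat replies and Llama Guard 3 8B for moderation.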
chat_model = "llama-3.1-70b-versatile"
moderation_model = "llama-guard-3-8b"

system_prompt = """You are a helpful, respectful assistant engaging in educational conversations with students."""


async def moderate_message(message):
    """Moderate the user message using Llama Guard 3."""
    response = await client.chat.completions.create(
        model=moderation_model,
        messages=[{
            "role": "user",
            "content": message
        }],
        temperature=0.1,
        max_tokens=10,
    )
    print(response.choices[0].message.content)
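    # Llama Guard 3 answers "safe", or "unsafe" followed by the violated
    # category code (e.g. "S1"), so an exact match on "safe" is sufficient.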
    return response.choices[0].message.content.strip().lower() == "safe"


async def chat_response(message, history):
    """Generate a chat response, including moderation."""
    is_safe = await moderate_message(message)
    if not is_safe:
        yield "I apologize, but I can't respond to that type of message. Let's keep our conversation appropriate."
        return

    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
    ]
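    # Replay earlier turns so the model receives the full conversation context.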
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    stream = await client.chat.completions.create(
        model=chat_model,
        messages=messages,
        temperature=0.7,
        max_tokens=321,
        top_p=1,
        stream=True,
    )
    response_content = ""
    async for chunk in stream:
        content = chunk.choices[0].delta.content
        if content:
            response_content += content
            yield response_content


async def user(message, history):
    """Process user input and update history."""
return "", history + [[message, None]] | |


async def bot(history):
    """Process bot response and update history."""
    user_message = history[-1][0]
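    # Stream partial replies into the last history entry so the chat updates live.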
    async for response in chat_response(user_message, history[:-1]):
        history[-1][1] = response
        yield history


with gr.Blocks() as demo:
    gr.Image("llama_pic.png", label="Logo", height=200)
    gr.Markdown("# SafeChat 🛡️")
    gr.Markdown(
        "This chat interface uses Llama Guard 3 to moderate messages and ensure safe interactions. It is adapted from code by Martin Bowling."
    )
    chatbot = gr.Chatbot(height=400, show_label=False)
    message = gr.Textbox(placeholder="Type your message here...", label="User Input", show_label=False, container=False)
    submit = gr.Button("Send", variant="primary")
    clear = gr.Button("Clear")
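
    # Run `user` immediately (queue=False) to echo the message, then chain into
    # `bot`, which streams the moderated reply back into the chatbot.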
    submit.click(user, [message, chatbot], [message, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
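
# launch() assumes the USERNAME and PASSWORD secrets are set; Gradio's basic
# auth expects a (username, password) pair of strings.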
if __name__ == "__main__":
    demo.queue()
    demo.launch(auth=(username, password))