# mcding
# minor fixes
# 2bbda54
# raw
# history blame
# 9.71 kB
import os
import gradio as gr
import numpy as np
import json
import redis
import plotly.graph_objects as go
from datetime import datetime
from PIL import Image
from kit import compute_performance, compute_quality
import dotenv
# Populate os.environ from a local .env file (python-dotenv) before any
# REDIS_* settings are read below.
dotenv.load_dotenv()
# Stylesheet passed to gr.Blocks(css=...) in create_interface(). `.tabs`
# matches the Tabs' elem_classes; the #ids match the components' elem_id
# values: 20px tab labels, fixed-height download/submit buttons, and
# horizontally centered images and leaderboard plot.
CSS = """
.tabs button{
font-size: 20px;
}
#download_btn {
height: 91.6px;
}
#submit_btn {
height: 91.6px;
}
#original_image {
display: block;
margin-left: auto;
margin-right: auto;
}
#uploaded_image {
display: block;
margin-left: auto;
margin-right: auto;
}
#leaderboard_plot {
display: block;
margin-left: auto;
margin-right: auto;
width: 512px; /* Adjust width as needed */
height: 512px; /* Adjust height as needed */
}
"""
# Connect to Redis
# All connection settings come from the environment (see load_dotenv above).
# decode_responses=True makes the client return str rather than bytes, so the
# stored JSON strings can go straight into json.loads.
# NOTE(review): os.getenv returns None for unset variables (and the port is a
# string when set) — confirm the deployment always provides REDIS_HOST and
# REDIS_PORT.
redis_client = redis.Redis(
    host=os.getenv("REDIS_HOST"),
    port=os.getenv("REDIS_PORT"),
    username=os.getenv("REDIS_USERNAME"),
    password=os.getenv("REDIS_PASSWORD"),
    decode_responses=True,
)
def save_to_redis(name, performance, quality):
    """Prepend one evaluation result to the Redis ``submissions`` list.

    The record is serialized as a JSON object with the keys ``name``,
    ``performance``, ``quality`` and ``timestamp``. The timestamp is the
    server's naive local time in ISO-8601 form.
    """
    record = dict(
        name=name,
        performance=performance,
        quality=quality,
        timestamp=datetime.now().isoformat(),
    )
    payload = json.dumps(record)
    # LPUSH puts the newest submission at the head of the list.
    redis_client.lpush("submissions", payload)
def get_submissions_from_redis():
    """Return all stored submissions as a list of dicts.

    Entries come back in list order, head first. Because ``save_to_redis``
    uses LPUSH, that means newest first.
    """
    raw_entries = redis_client.lrange("submissions", 0, -1)
    return list(map(json.loads, raw_entries))
def update_plot(
    submissions,
    current_name=None,
):
    """Build the leaderboard scatter plot from stored submissions.

    Each submission becomes one point at (quality degradation, detection
    performance). Entries named ``current_name`` are drawn as a blue star,
    names starting with ``"Baseline: "`` as grey squares, and everything else
    as green circles. Dashed circles centred on the origin mark equal values
    of the overall score ``sqrt(quality**2 + performance**2)``.

    Args:
        submissions: Iterable of dicts with ``"name"``, ``"performance"`` and
            ``"quality"`` keys. The two scores are converted with ``float``.
            The iterable is traversed only once, so generators work too.
        current_name: Name whose points should be highlighted, or ``None``.

    Returns:
        A 512x512 ``plotly.graph_objects.Figure`` with both axes spanning
        [0, 1.1] and no legend.
    """
    fig = go.Figure()

    # One trace per submission so each point gets its own marker and label.
    for sub in submissions:
        name = sub["name"]
        performance = float(sub["performance"])
        quality = float(sub["quality"])
        is_baseline = name.startswith("Baseline: ")

        if name == current_name:
            marker = dict(symbol="star", size=15, color="blue")
        elif is_baseline:
            marker = dict(symbol="square", size=10, color="grey")
        else:
            marker = dict(symbol="circle", size=10, color="green")

        label = name if is_baseline else "Name: " + name
        fig.add_trace(
            go.Scatter(
                x=[quality],
                y=[performance],
                mode="markers+text",
                text=[name],
                textposition="top center",
                name=name,
                marker=marker,
                hovertemplate=(
                    f"{label}<br>"
                    f"(Performance, Quality) = ({performance:.3f}, {quality:.3f})"
                ),
            )
        )

    # Dashed guide circles at radii 0.25, 0.5, 0.75 and 1.0. Radius 0 is
    # skipped: it would only be a zero-length line at the origin. The angle
    # grid is loop-invariant, so it is computed once. hoverinfo="skip" keeps
    # the guides from producing tooltips.
    theta = np.linspace(0, 2 * np.pi, 100)
    for radius in np.linspace(0, 1, 5)[1:]:
        fig.add_trace(
            go.Scatter(
                x=radius * np.cos(theta),
                y=radius * np.sin(theta),
                mode="lines",
                line=dict(color="gray", dash="dash"),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # `title_font` replaces the deprecated `titlefont` alias, which newer
    # Plotly releases no longer accept.
    fig.update_layout(
        xaxis_title="Image Quality Degradation",
        yaxis_title="Watermark Detection Performance",
        xaxis=dict(range=[0, 1.1], title_font=dict(size=16)),
        yaxis=dict(range=[0, 1.1], title_font=dict(size=16)),
        width=512,
        height=512,
        showlegend=False,
    )
    return fig
def process_submission(name, image):
    """Score an uploaded image, store the result, and rank it on the leaderboard.

    Args:
        name: Submitter's display name. It is stored with the result and
            highlighted in the plot.
        image: The uploaded watermark-removed image. The upload component in
            create_interface uses ``type="pil"``.

    Returns:
        A 6-tuple matching the results-tab outputs:
        (leaderboard figure, ``"<rank> out of <total>"``, name,
        performance, quality, overall score). The three scores are
        formatted to three decimals.
    """
    progress = gr.Progress()
    progress(0, desc="Detecting Watermark")
    performance = compute_performance(image)

    progress(0.4, desc="Evaluating Image Quality")
    # Use a context manager so the reference image's file handle is closed
    # once compute_quality is done. Pillow opens images lazily, so a bare
    # Image.open() keeps the file open.
    with Image.open("./image.png") as original_image:
        quality = compute_quality(image, original_image)

    progress(1.0, desc="Uploading Results")
    save_to_redis(name, performance, quality)
    submissions = get_submissions_from_redis()
    leaderboard_plot = update_plot(submissions, current_name=name)

    # Overall score = Euclidean distance from the ideal point (0, 0). The UI
    # labels every score "lower is better", so rank 1 is the smallest
    # distance. Counting strictly better entries, instead of looking our own
    # score up with list.index, cannot raise on a float mismatch. It also
    # gives tied submissions the same rank.
    overall = np.sqrt(quality**2 + performance**2)
    total = len(submissions)
    rank = 1 + sum(
        1
        for s in submissions
        if np.sqrt(float(s["quality"]) ** 2 + float(s["performance"]) ** 2) < overall
    )

    gr.Info(f"You ranked {rank} out of {total}!")
    return (
        leaderboard_plot,
        f"{rank} out of {total}",
        name,
        f"{performance:.3f}",
        f"{quality:.3f}",
        f"{overall:.3f}",
    )
def upload_and_evaluate(name, image):
    """Validate the submission form, then evaluate and record the submission.

    Args:
        name: Value of the "Your Name" textbox.
        image: The uploaded image, or ``None`` if nothing was uploaded.

    Returns:
        The 6-tuple produced by ``process_submission``.

    Raises:
        gr.Error: If the name is missing, empty or whitespace-only, or if no
            image was uploaded. Gradio displays the message to the user.
    """
    # Whitespace-only names would otherwise be stored and plotted as an
    # invisible label. `not name` also covers None.
    if not name or not name.strip():
        raise gr.Error("Please enter your name before submitting.")
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    return process_submission(name, image)
def create_interface():
    """Build the three-tab Gradio app. The caller launches it.

    Tabs:
      * "Original Watermarked Image": shows ./image.png, offers it for
        download, and has a button that jumps to the submission tab.
      * "Submit Watermark Removed Image": image upload plus name field.
      * "Evaluation Results and Your Ranking": leaderboard plot plus the
        submitter's rank and scores.

    Clicking "Upload and Evaluate" first switches to the results tab. It then
    runs ``upload_and_evaluate`` and fills the six result components. On
    every page load the original image and the leaderboard plot are re-sent.
    This keeps the plot from staying at the state it had when the UI was
    built.

    Returns:
        The configured ``gr.Blocks`` instance, not yet launched.
    """
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(
            """
# Erasing the Invisible Demo
TODO: Improve title and add description, add icon.jpg, also improve configs in README.md
"""
        )
        with gr.Tabs(elem_classes=["tabs"]) as tabs:
            # Tab 1: the reference image participants should clean.
            with gr.Tab("Original Watermarked Image", id="download"):
                gr.Markdown(
                    """
TODO: Add descriptions
"""
                )
                with gr.Column():
                    original_image = gr.Image(
                        value="./image.png",
                        format="png",
                        label="Original Watermarked Image",
                        show_label=True,
                        height=512,
                        width=512,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        container=True,
                        elem_id="original_image",
                    )
                    with gr.Row():
                        download_btn = gr.DownloadButton(
                            "Download Watermarked Image",
                            value="./image.png",
                            elem_id="download_btn",
                        )
                        submit_btn = gr.Button(
                            "Submit Your Removal", elem_id="submit_btn"
                        )
            # Tab 2: upload form for the participant's cleaned image.
            with gr.Tab(
                "Submit Watermark Removed Image",
                id="submit",
                elem_classes="gr-tab-header",
            ):
                gr.Markdown(
                    """
TODO: Add descriptions
"""
                )
                with gr.Column():
                    uploaded_image = gr.Image(
                        label="Your Watermark Removed Image",
                        format="png",
                        show_label=True,
                        height=512,
                        width=512,
                        sources=["upload"],
                        type="pil",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        container=True,
                        placeholder="Upload your watermark removed image",
                        elem_id="uploaded_image",
                    )
                    with gr.Row():
                        name_input = gr.Textbox(
                            label="Your Name", placeholder="Anonymous"
                        )
                        upload_btn = gr.Button("Upload and Evaluate")
            # Tab 3: leaderboard and the latest submission's scores.
            with gr.Tab(
                "Evaluation Results and Your Ranking",
                id="leaderboard",
                elem_classes="gr-tab-header",
            ):
                gr.Markdown(
                    """
TODO: Add descriptions
"""
                )
                with gr.Column():
                    # This initial value is computed once, when the UI is
                    # built. demo.load below refreshes it on every page load.
                    leaderboard_plot = gr.Plot(
                        value=update_plot(get_submissions_from_redis()),
                        show_label=False,
                        elem_id="leaderboard_plot",
                    )
                    with gr.Row():
                        rank_output = gr.Textbox(label="Your Ranking")
                        name_output = gr.Textbox(label="Your Name")
                        performance_output = gr.Textbox(
                            label="Watermark Performance (lower is better)"
                        )
                        quality_output = gr.Textbox(
                            label="Quality Degradation (lower is better)"
                        )
                        overall_output = gr.Textbox(
                            label="Overall Score (lower is better)"
                        )

        # "Submit Your Removal" only switches to the upload tab.
        submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)
        # Switch to the results tab first, then evaluate. The six outputs are
        # listed in the order of the tuple upload_and_evaluate returns.
        upload_btn.click(lambda: gr.Tabs(selected="leaderboard"), None, tabs).then(
            upload_and_evaluate,
            inputs=[name_input, uploaded_image],
            outputs=[
                leaderboard_plot,
                rank_output,
                name_output,
                performance_output,
                quality_output,
                overall_output,
            ],
        )
        # Re-send the original image and rebuild the plot from Redis on each
        # page load, so visitors see submissions made after start-up.
        demo.load(
            lambda: [
                gr.Image(value="./image.png", height=512, width=512),
                gr.Plot(update_plot(get_submissions_from_redis())),
            ],
            outputs=[original_image, leaderboard_plot],
        )
    return demo
# Create the demo object
# Built at import time, so importing this module already constructs the UI.
# This includes the Redis read in create_interface() for the initial plot.
# NOTE(review): presumably module-level so a host that imports this file can
# find `demo` — confirm before moving this under the __main__ guard.
demo = create_interface()
# Launch the app
# Only when executed as a script. share=False serves the app without creating
# a public Gradio share link.
if __name__ == "__main__":
    demo.launch(share=False)