|
import os |
|
import gradio as gr |
|
import numpy as np |
|
import json |
|
import redis |
|
import plotly.graph_objects as go |
|
from datetime import datetime |
|
from PIL import Image |
|
from kit import compute_performance, compute_quality |
|
import dotenv |
|
|
|
# Load variables from a .env file (if one is present) into the process
# environment before the Redis settings are read with os.getenv.
dotenv.load_dotenv()

# Stylesheet passed to gr.Blocks(css=CSS) in create_interface.  The selectors
# correspond to the elem_classes / elem_id values assigned to components
# there: .tabs, #download_btn, #submit_btn, #original_image, #uploaded_image
# and #leaderboard_plot.
CSS = """
.tabs button{
    font-size: 20px;
}
#download_btn {
    height: 91.6px;
}
#submit_btn {
    height: 91.6px;
}
#original_image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#uploaded_image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#leaderboard_plot {
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 512px; /* Adjust width as needed */
    height: 512px; /* Adjust height as needed */
}
"""
|
|
|
|
|
|
|
# Module-level Redis connection used by save_to_redis and
# get_submissions_from_redis.  Settings are read from the environment.
# os.getenv returns a string, or None when the variable is unset, so the port
# is converted to int explicitly and host/port fall back to the usual Redis
# defaults ("localhost", 6379) instead of handing None (or an empty string)
# to the client.  decode_responses=True makes the client return str instead
# of bytes, which json.loads consumes directly.
redis_client = redis.Redis(
    host=os.getenv("REDIS_HOST") or "localhost",
    port=int(os.getenv("REDIS_PORT") or 6379),
    username=os.getenv("REDIS_USERNAME"),
    password=os.getenv("REDIS_PASSWORD"),
    decode_responses=True,
)
|
|
|
|
|
def save_to_redis(name, performance, quality):
    """Store one submission on the shared leaderboard.

    The entry is JSON-encoded and pushed onto the head of the Redis list
    ``"submissions"`` together with the current local timestamp.

    Args:
        name: Name shown for this submission.
        performance: Watermark detection score; any value ``float()`` accepts.
        quality: Image quality score; any value ``float()`` accepts.
    """
    submission = {
        "name": name,
        # Coerce to built-in floats: json.dumps raises TypeError for NumPy
        # scalar types such as np.float32, and every reader of these fields
        # converts them with float() anyway.
        "performance": float(performance),
        "quality": float(quality),
        "timestamp": datetime.now().isoformat(),
    }
    redis_client.lpush("submissions", json.dumps(submission))
|
|
|
|
|
def get_submissions_from_redis():
    """Return every stored submission as a list of decoded objects.

    Reads the complete Redis list ``"submissions"`` (indices 0 through -1)
    and JSON-decodes each entry, keeping the order Redis returns them in.
    """
    raw_entries = redis_client.lrange("submissions", 0, -1)
    return list(map(json.loads, raw_entries))
|
|
|
|
|
def update_plot(
    submissions,
    current_name=None,
):
    """Build the leaderboard scatter plot.

    Args:
        submissions: Sequence of dicts with ``"name"``, ``"performance"`` and
            ``"quality"`` keys; both scores are converted with ``float()``.
        current_name: Optional name to highlight. Every submission whose name
            equals it is drawn as a blue star.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one labelled marker per
        submission (x = quality, y = performance) and five dashed reference
        circles centred on the origin with radii from 0 to 1.
    """
    fig = go.Figure()

    for sub in submissions:
        name = sub["name"]
        performance = float(sub["performance"])
        quality = float(sub["quality"])
        is_baseline = name.startswith("Baseline: ")

        # The highlight takes precedence over the baseline styling.
        if name == current_name:
            marker = dict(symbol="star", size=15, color="blue")
        elif is_baseline:
            marker = dict(symbol="square", size=10, color="grey")
        else:
            marker = dict(symbol="circle", size=10, color="green")

        # Baseline names already carry their own prefix; others get "Name: ".
        hover_label = name if is_baseline else "Name: " + name

        fig.add_trace(
            go.Scatter(
                x=[quality],
                y=[performance],
                mode="markers+text",
                text=[name],
                textposition="top center",
                name=name,
                marker=marker,
                hovertemplate=f"{hover_label}<br>(Performance, Quality) = ({performance:.3f}, {quality:.3f})",
            )
        )

    # The unit circle is the same for every radius, so sample it once and
    # scale it inside the loop.
    theta = np.linspace(0, 2 * np.pi, 100)
    unit_x = np.cos(theta)
    unit_y = np.sin(theta)
    for radius in np.linspace(0, 1, 5):
        fig.add_trace(
            go.Scatter(
                x=radius * unit_x,
                y=radius * unit_y,
                mode="lines",
                line=dict(color="gray", dash="dash"),
                showlegend=False,
            )
        )

    # Axis titles are given as title=dict(text=..., font=...): the older
    # `titlefont` shorthand is deprecated and rejected by recent Plotly
    # releases, while this nested form is accepted by old and new versions.
    fig.update_layout(
        xaxis=dict(
            title=dict(text="Image Quality Degredation", font=dict(size=16)),
            range=[0, 1.1],
        ),
        yaxis=dict(
            title=dict(
                text="Watermark Detection Performance", font=dict(size=16)
            ),
            range=[0, 1.1],
        ),
        width=512,
        height=512,
        showlegend=False,
    )

    return fig
|
|
|
|
|
def process_submission(name, image):
    """Score a submission, store it, and build the refreshed leaderboard.

    Args:
        name: Submitter name, stored with the scores and highlighted in the
            plot.
        image: The uploaded image, passed to ``compute_performance`` and
            ``compute_quality``.

    Returns:
        A 6-tuple ``(figure, rank_text, name, performance_text, quality_text,
        overall_text)``, matching the order of the output components wired
        to the upload button in ``create_interface``.
    """
    progress = gr.Progress()
    progress(0, desc="Detecting Watermark")
    performance = compute_performance(image)

    progress(0.4, desc="Evaluating Image Quality")
    # Open the reference image in a context manager so its file handle is
    # released as soon as the comparison is done.
    with Image.open("./image.png") as original_image:
        quality = compute_quality(image, original_image)

    progress(1.0, desc="Uploading Results")
    save_to_redis(name, performance, quality)

    submissions = get_submissions_from_redis()
    leaderboard_plot = update_plot(submissions, current_name=name)

    # Overall score: distance from the origin in (quality, performance) space.
    overall = float(np.sqrt(float(quality) ** 2 + float(performance) ** 2))
    distances = [
        np.sqrt(float(s["quality"]) ** 2 + float(s["performance"]) ** 2)
        for s in submissions
    ]
    # All scores are labelled "lower is better", so rank 1 belongs to the
    # smallest distance.  Counting the strictly better entries gives tied
    # submissions the same rank and avoids an exact floating-point lookup
    # with list.index, which raises ValueError on the slightest mismatch.
    rank = 1 + sum(1 for distance in distances if distance < overall)

    gr.Info(f"You ranked {rank} out of {len(submissions)}!")
    return (
        leaderboard_plot,
        f"{rank} out of {len(submissions)}",
        name,
        f"{performance:.3f}",
        f"{quality:.3f}",
        f"{overall:.3f}",
    )
|
|
|
|
|
def upload_and_evaluate(name, image):
    """Validate the form inputs, then delegate to ``process_submission``.

    Args:
        name: Text entered in the name box.
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        Whatever ``process_submission(name, image)`` returns.

    Raises:
        gr.Error: If the name is missing, empty or only whitespace, or if no
            image was uploaded.
    """
    # `not name` also covers None; strip() rejects whitespace-only names,
    # which would otherwise appear as blank labels on the leaderboard.
    if not name or not name.strip():
        raise gr.Error("Please enter your name before submitting.")
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    return process_submission(name, image)
|
|
|
|
|
def create_interface():
    """Assemble the Gradio app and return it.

    The UI has three tabs: one showing the watermarked image with a download
    button, one for uploading an edited image together with a name, and one
    showing the leaderboard plot and the scores of the latest submission.
    The event wiring that switches tabs and runs the evaluation is set up
    here as well.

    Returns:
        The ``gr.Blocks`` instance; it is not launched by this function.
    """
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(
            """
            # Erasing the Invisible Demo
            TODO: Improve title and add description, add icon.jpg, also improve configs in README.md
            """
        )

        # Tab ids ("download", "submit", "leaderboard") are what the click
        # handlers below select via gr.Tabs(selected=...).
        with gr.Tabs(elem_classes=["tabs"]) as tabs:
            with gr.Tab("Original Watermarked Image", id="download"):
                gr.Markdown(
                    """
                    TODO: Add descriptions
                    """
                )
                with gr.Column():
                    original_image = gr.Image(
                        value="./image.png",
                        format="png",
                        label="Original Watermarked Image",
                        show_label=True,
                        height=512,
                        width=512,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        container=True,
                        elem_id="original_image",
                    )
                    with gr.Row():
                        download_btn = gr.DownloadButton(
                            "Download Watermarked Image",
                            value="./image.png",
                            elem_id="download_btn",
                        )
                        submit_btn = gr.Button(
                            "Submit Your Removal", elem_id="submit_btn"
                        )

            with gr.Tab(
                "Submit Watermark Removed Image",
                id="submit",
                elem_classes="gr-tab-header",
            ):
                gr.Markdown(
                    """
                    TODO: Add descriptions
                    """
                )
                with gr.Column():
                    # type="pil": the handler receives the upload as a PIL
                    # image (or None when nothing has been uploaded).
                    uploaded_image = gr.Image(
                        label="Your Watermark Removed Image",
                        format="png",
                        show_label=True,
                        height=512,
                        width=512,
                        sources=["upload"],
                        type="pil",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        container=True,
                        placeholder="Upload your watermark removed image",
                        elem_id="uploaded_image",
                    )
                    with gr.Row():
                        name_input = gr.Textbox(
                            label="Your Name", placeholder="Anonymous"
                        )
                        upload_btn = gr.Button("Upload and Evaluate")

            with gr.Tab(
                "Evaluation Results and Your Ranking",
                id="leaderboard",
                elem_classes="gr-tab-header",
            ):
                gr.Markdown(
                    """
                    TODO: Add descriptions
                    """
                )
                with gr.Column():
                    # The initial figure is built from Redis while the
                    # interface is being constructed; demo.load below
                    # rebuilds it from the current submissions.
                    leaderboard_plot = gr.Plot(
                        value=update_plot(get_submissions_from_redis()),
                        show_label=False,
                        elem_id="leaderboard_plot",
                    )
                    with gr.Row():
                        rank_output = gr.Textbox(label="Your Ranking")
                        name_output = gr.Textbox(label="Your Name")
                        performance_output = gr.Textbox(
                            label="Watermark Performance (lower is better)"
                        )
                        quality_output = gr.Textbox(
                            label="Quality Degredation (lower is better)"
                        )
                        overall_output = gr.Textbox(
                            label="Overall Score (lower is better)"
                        )

        # "Submit Your Removal" only switches to the upload tab.
        submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)

        # "Upload and Evaluate" first switches to the leaderboard tab, then
        # runs the evaluation.  The outputs list follows the order of the
        # tuple returned by process_submission.
        upload_btn.click(lambda: gr.Tabs(selected="leaderboard"), None, tabs).then(
            upload_and_evaluate,
            inputs=[name_input, uploaded_image],
            outputs=[
                leaderboard_plot,
                rank_output,
                name_output,
                performance_output,
                quality_output,
                overall_output,
            ],
        )

        # Load event: reset the original image and rebuild the leaderboard
        # plot from the submissions currently stored in Redis.
        demo.load(
            lambda: [
                gr.Image(value="./image.png", height=512, width=512),
                gr.Plot(update_plot(get_submissions_from_redis())),
            ],
            outputs=[original_image, leaderboard_plot],
        )

    return demo
|
|
|
|
|
|
|
# Build the app at import time so that `demo` exists as a module-level name.
demo = create_interface()


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=False)
|
|