"""Gradio demo for the "Erasing the Invisible" watermark-removal challenge.

Participants download a watermarked image, upload their attempt at removing
the watermark, and are scored by ``kit`` on two axes (both lower is better):

* watermark detection performance, and
* image quality degradation relative to the original.

Each submission is pushed onto a Redis list and drawn on a shared leaderboard
plot; the overall score is the distance from the ideal point (0, 0).
"""

import json
import os
from datetime import datetime

import dotenv
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import redis
from PIL import Image

from kit import compute_performance, compute_quality

dotenv.load_dotenv()

# The watermarked image offered for download and used as the quality reference.
ORIGINAL_IMAGE_PATH = "./image.png"
# Redis list holding one JSON document per submission (newest first).
SUBMISSIONS_KEY = "submissions"

CSS = """
.tabs button{
    font-size: 20px;
}
#download_btn {
    height: 91.6px;
}
#submit_btn {
    height: 91.6px;
}
#original_image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#uploaded_image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#leaderboard_plot {
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 512px; /* Adjust width as needed */
    height: 512px; /* Adjust height as needed */
}
"""

# Connect to Redis. Settings come from the environment (optionally loaded from
# a .env file above). Host/port fall back to redis-py's own defaults rather
# than passing None through, and the port is converted to int explicitly.
redis_client = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    username=os.getenv("REDIS_USERNAME"),
    password=os.getenv("REDIS_PASSWORD"),
    decode_responses=True,
)


def save_to_redis(name, performance, quality):
    """Record one scored submission in Redis.

    The scores are coerced to built-in ``float`` in case the scoring functions
    return NumPy scalars (e.g. ``float32``), which ``json.dumps`` rejects.
    """
    submission = {
        "name": name,
        "performance": float(performance),
        "quality": float(quality),
        "timestamp": datetime.now().isoformat(),
    }
    # lpush prepends, so the newest submission is at index 0.
    redis_client.lpush(SUBMISSIONS_KEY, json.dumps(submission))


def get_submissions_from_redis():
    """Return every stored submission as a list of dicts, newest first."""
    submissions = redis_client.lrange(SUBMISSIONS_KEY, 0, -1)
    return [json.loads(submission) for submission in submissions]


def _overall_score(performance, quality):
    """Return the overall score: Euclidean distance from the ideal (0, 0)."""
    return float(np.hypot(quality, performance))


def _marker_for(name, current_name):
    """Pick the scatter-marker style for one leaderboard entry."""
    if name == current_name:
        return dict(symbol="star", size=15, color="blue")
    if name.startswith("Baseline: "):
        return dict(symbol="square", size=10, color="grey")
    return dict(symbol="circle", size=10, color="green")


def update_plot(
    submissions,
    current_name=None,
):
    """Build the leaderboard scatter plot.

    Args:
        submissions: Dicts with at least ``name``, ``performance`` and
            ``quality`` keys, as returned by ``get_submissions_from_redis``.
        current_name: If given, entries with this name are drawn as a blue
            star to highlight the current participant.

    Returns:
        A ``plotly.graph_objects.Figure`` with quality degradation on the
        x-axis and detection performance on the y-axis, overlaid with dashed
        arcs of equal overall score.
    """
    fig = go.Figure()
    for sub in submissions:
        name = sub["name"]
        performance = float(sub["performance"])
        quality = float(sub["quality"])
        # Baselines already carry a descriptive "Baseline: ..." name.
        label = name if name.startswith("Baseline: ") else "Name: " + name
        fig.add_trace(
            go.Scatter(
                x=[quality],
                y=[performance],
                mode="markers+text",
                text=[name],
                textposition="top center",
                name=name,
                marker=_marker_for(name, current_name),
                # <extra></extra> suppresses the secondary hover box, which
                # would otherwise repeat the trace name.
                hovertemplate=(
                    f"{label}<br>"
                    f"(Performance, Quality) = ({performance:.3f}, {quality:.3f})"
                    "<extra></extra>"
                ),
            )
        )

    # Dashed reference arcs: all points on one arc share the same overall
    # score (distance from the origin). Radius 0 is skipped because it is a
    # single point.
    theta = np.linspace(0, 2 * np.pi, 100)
    for radius in np.linspace(0, 1, 5)[1:]:
        fig.add_trace(
            go.Scatter(
                x=radius * np.cos(theta),
                y=radius * np.sin(theta),
                mode="lines",
                line=dict(color="gray", dash="dash"),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # Axis titles use title.font; the old `titlefont` attribute was removed
    # in Plotly 6.
    axis_title_font = dict(size=16)  # Adjust this value as needed
    fig.update_layout(
        xaxis=dict(
            title=dict(text="Image Quality Degradation", font=axis_title_font),
            range=[0, 1.1],
        ),
        yaxis=dict(
            title=dict(
                text="Watermark Detection Performance", font=axis_title_font
            ),
            range=[0, 1.1],
        ),
        width=512,
        height=512,
        showlegend=False,  # Remove legend
    )
    return fig


def process_submission(name, image, progress=None):
    """Score an uploaded image, store the result and build the UI outputs.

    Args:
        name: Display name for the leaderboard.
        image: The participant's watermark-removed image (a PIL image, given
            the upload component's ``type="pil"``).
        progress: Optional ``gr.Progress`` tracker. Gradio only connects a
            tracker declared as a default argument of the event handler, so
            ``upload_and_evaluate`` forwards its own. If omitted, a new one is
            created.

    Returns:
        Tuple of (plot, rank text, name, performance, quality, overall score),
        in the order of the outputs wired up in ``create_interface``.
    """
    if progress is None:
        progress = gr.Progress()

    progress(0, desc="Detecting Watermark")
    performance = compute_performance(image)
    progress(0.4, desc="Evaluating Image Quality")
    # Close the reference image file once the comparison is done.
    with Image.open(ORIGINAL_IMAGE_PATH) as original_image:
        quality = compute_quality(image, original_image)
    progress(1.0, desc="Uploading Results")
    save_to_redis(name, performance, quality)

    submissions = get_submissions_from_redis()
    leaderboard_plot = update_plot(submissions, current_name=name)

    # Every score is "lower is better", so rank 1 is the smallest overall
    # score. Counting strictly-better entries gives tied scores the same rank,
    # and it does not depend on finding our own float again in a sorted list.
    overall = _overall_score(performance, quality)
    rank = 1 + sum(
        _overall_score(float(s["performance"]), float(s["quality"])) < overall
        for s in submissions
    )
    total = len(submissions)

    gr.Info(f"You ranked {rank} out of {total}!")
    return (
        leaderboard_plot,
        f"{rank} out of {total}",
        name,
        f"{performance:.3f}",
        f"{quality:.3f}",
        f"{overall:.3f}",
    )


# `progress=gr.Progress()` as a default is the Gradio idiom that lets the
# framework inject a progress tracker bound to this event.
def upload_and_evaluate(name, image, progress=gr.Progress()):
    """Validate the form inputs, then evaluate and record the submission.

    Raises:
        gr.Error: If the name is empty or blank, or no image was uploaded.
    """
    if not name or not name.strip():
        raise gr.Error("Please enter your name before submitting.")
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    return process_submission(name.strip(), image, progress=progress)


def create_interface():
    """Build the three-tab Gradio Blocks app and wire up its events.

    The tabs are: download the original watermarked image, upload a
    watermark-removed image, and view the leaderboard with this submission's
    scores. Returns the ``gr.Blocks`` object without launching it.
    """
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(
            """
            # Erasing the Invisible Demo
            TODO: Improve title and add description, add icon.jpg, also
            improve configs in README.md
            """
        )
        with gr.Tabs(elem_classes=["tabs"]) as tabs:
            with gr.Tab("Original Watermarked Image", id="download"):
                gr.Markdown(
                    """
                    TODO: Add descriptions
                    """
                )
                with gr.Column():
                    original_image = gr.Image(
                        value=ORIGINAL_IMAGE_PATH,
                        format="png",
                        label="Original Watermarked Image",
                        show_label=True,
                        height=512,
                        width=512,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        container=True,
                        elem_id="original_image",
                    )
                    with gr.Row():
                        gr.DownloadButton(
                            "Download Watermarked Image",
                            value=ORIGINAL_IMAGE_PATH,
                            elem_id="download_btn",
                        )
                        submit_btn = gr.Button(
                            "Submit Your Removal", elem_id="submit_btn"
                        )
            with gr.Tab(
                "Submit Watermark Removed Image",
                id="submit",
                elem_classes="gr-tab-header",
            ):
                gr.Markdown(
                    """
                    TODO: Add descriptions
                    """
                )
                with gr.Column():
                    uploaded_image = gr.Image(
                        label="Your Watermark Removed Image",
                        format="png",
                        show_label=True,
                        height=512,
                        width=512,
                        sources=["upload"],
                        type="pil",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        container=True,
                        placeholder="Upload your watermark removed image",
                        elem_id="uploaded_image",
                    )
                    with gr.Row():
                        name_input = gr.Textbox(
                            label="Your Name", placeholder="Anonymous"
                        )
                        upload_btn = gr.Button("Upload and Evaluate")
            with gr.Tab(
                "Evaluation Results and Your Ranking",
                id="leaderboard",
                elem_classes="gr-tab-header",
            ):
                gr.Markdown(
                    """
                    TODO: Add descriptions
                    """
                )
                with gr.Column():
                    leaderboard_plot = gr.Plot(
                        value=update_plot(get_submissions_from_redis()),
                        show_label=False,
                        elem_id="leaderboard_plot",
                    )
                    with gr.Row():
                        rank_output = gr.Textbox(label="Your Ranking")
                        name_output = gr.Textbox(label="Your Name")
                        performance_output = gr.Textbox(
                            label="Watermark Performance (lower is better)"
                        )
                        quality_output = gr.Textbox(
                            label="Quality Degradation (lower is better)"
                        )
                        overall_output = gr.Textbox(
                            label="Overall Score (lower is better)"
                        )

        # "Submit Your Removal" only switches to the upload tab.
        submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)
        # Switch to the leaderboard tab, then run the evaluation and fill in
        # the result outputs.
        upload_btn.click(
            lambda: gr.Tabs(selected="leaderboard"), None, tabs
        ).then(
            upload_and_evaluate,
            inputs=[name_input, uploaded_image],
            outputs=[
                leaderboard_plot,
                rank_output,
                name_output,
                performance_output,
                quality_output,
                overall_output,
            ],
        )
        # On each page load, re-read the submissions so the plot is not frozen
        # at the snapshot taken when create_interface() ran.
        demo.load(
            lambda: [
                gr.Image(value=ORIGINAL_IMAGE_PATH, height=512, width=512),
                gr.Plot(update_plot(get_submissions_from_redis())),
            ],
            outputs=[original_image, leaderboard_plot],
        )
    return demo


# Create the demo object
demo = create_interface()

# Launch the app
if __name__ == "__main__":
    demo.launch(share=False)