# SuSy / app.py
# pabberpe's picture
# Update Description
# 056767c
# raw
# history blame
# 3.81 kB
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.feature import graycomatrix, graycoprops
from torchvision import transforms
# Load the TorchScript model once at import time so every request reuses it.
# Inference below always builds its input on the CPU (torch.from_numpy) and
# converts the output with .numpy(), so the weights are pinned to the CPU here
# regardless of the device the checkpoint was saved from.
model = torch.jit.load("SuSy.pt", map_location="cpu")
def process_image(image):
    """Score an image against the SuSy classes using its highest-contrast patches.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    patches (partial tiles at the right/bottom edges are dropped). Each patch
    is ranked by the contrast of its grey-level co-occurrence matrix, the
    ``top_k_patches`` highest-ranked patches are passed to the module-level
    ``model``, and the per-patch outputs are averaged.

    Args:
        image: A PIL image, or ``None`` when no image was supplied.

    Returns:
        A dict mapping each class label to the mean model output for that
        class, ordered from highest to lowest score.

    Raises:
        gr.Error: If ``image`` is ``None`` or is smaller than one patch in
            either dimension, since no patch could be extracted.
    """
    # Set Parameters
    top_k_patches = 5
    patch_size = 224

    if image is None:
        raise gr.Error("Please provide an image to analyse.")

    # The patch buffer and the model input both assume exactly 3 channels, so
    # normalise grayscale / RGBA / palette images up front.
    image = image.convert("RGB")

    # Get the image dimensions
    width, height = image.size

    # Calculate the number of whole patches that fit along each axis
    num_patches_x = width // patch_size
    num_patches_y = height // patch_size
    if num_patches_x == 0 or num_patches_y == 0:
        # Without this guard an empty batch would reach the model.
        raise gr.Error(
            f"Image is too small ({width}x{height}); it must be at least "
            f"{patch_size}x{patch_size} pixels."
        )

    # Divide the image in patches. The x-major ordering (x outer, y inner) is
    # kept so that ties in the ranking below resolve as they always have.
    pixels = np.asarray(image)  # (height, width, 3), uint8
    patches = np.stack(
        [
            pixels[y:y + patch_size, x:x + patch_size]
            for x in range(0, num_patches_x * patch_size, patch_size)
            for y in range(0, num_patches_y * patch_size, patch_size)
        ]
    )

    # Rank patches by GLCM contrast (offset: distance 5, angle 0, 256 levels).
    # The transform is loop-invariant, so it is built once.
    to_grayscale = transforms.Compose(
        [transforms.PILToTensor(), transforms.Grayscale()]
    )
    contrast_scores = []
    for patch in patches:
        grayscale_patch = to_grayscale(Image.fromarray(patch)).squeeze(0)
        # graycomatrix expects a NumPy array, so convert the tensor explicitly.
        glcm = graycomatrix(
            grayscale_patch.numpy(), [5], [0], 256, symmetric=True, normed=True
        )
        contrast_scores.append(float(graycoprops(glcm, "contrast")[0, 0]))

    # Sort patch indices by their contrast score, highest first
    sorted_indices = np.argsort(contrast_scores)[::-1]

    # Extract top k patches (fewer if the image has fewer) and convert them
    # to a float tensor in NCHW layout scaled to [0, 1]
    top_patches = patches[sorted_indices[:top_k_patches]]
    top_patches = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0

    # Predict patches
    model.eval()
    with torch.no_grad():
        preds = model(top_patches)

    # Process results.
    # NOTE(review): assumes the model emits one score per class in exactly
    # this order, and that the scores are probabilities - confirm against the
    # training code.
    classes = ['Authentic', 'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6', 'MJ V1/V2', 'Stable Diffusion XL']
    # tolist() yields plain Python floats rather than NumPy scalars.
    mean_probs = preds.mean(dim=0).tolist()

    # Pair each class with its mean score and sort in descending order
    sorted_probs = dict(
        sorted(zip(classes, mean_probs), key=lambda item: item[1], reverse=True)
    )
    return sorted_probs
# Define Gradio interface: one PIL image in, a ranked label/score list out.
# The description HTML is kept flush-left so its content is passed through
# without added leading whitespace.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    # Show all six classes returned by process_image.
    outputs=gr.Label(num_top_classes=6),
    title="SuSy: Synthetic Image Detector",
    description="""
<table style="border-collapse: collapse; border: none; padding: 20px;">
<tr style="border: none;">
<td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
<img src="https://cdn-uploads.huggingface.co/production/uploads/62f7a16192950415b637e201/NobqlpFbFkTyBi1LsT9JE.png" alt="SuSy Logo" width="120" style="margin-bottom: 10px;">
</td>
<td style="border: none; vertical-align: top; padding: 10px;">
<p style="margin-bottom: 15px;">Detect synthetic images with SuSy! SuSy can distinguish between authentic images and those generated by DALL·E, Midjourney and Stable Diffusion.</p>
<p style="margin-top: 15px;">Learn more about SuSy: <a href="https://arxiv.org/abs/2409.14128">Present and Future Generalization of Synthetic Image Detectors</a></p>
<p style="margin-top: 15px;">
Enter the SuSy-verse!
<a href="https://huggingface.co/HPAI-BSC/SuSy">Model</a> |
<a href="https://github.com/HPAI-BSC/SuSy">Code</a> |
<a href="https://huggingface.co/datasets/HPAI-BSC/SuSy-Dataset">Dataset</a>
</p>
</td>
</tr>
</table>
""",
    examples=[
        ["example_authentic.jpg"],
        ["example_dalle3.jpg"],
        ["example_mjv5.jpg"],
        ["example_sdxl.jpg"],
    ],
)

# Launch the interface only when executed as a script, so importing this
# module (e.g. to reuse process_image) does not start a web server.
if __name__ == "__main__":
    iface.launch()