"""Gradio demo that classifies an uploaded image as 'fake' or 'real'.

Uses a fine-tuned ResNet checkpoint together with the stock
``microsoft/resnet-50`` image processor for normalization.
"""
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch
import torchvision.transforms as transforms

# Load the model and processor
model = AutoModelForImageClassification.from_pretrained("1ancelot/base_rn")  # Update with your model path
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
# Make sure inference-time behavior (dropout / batch-norm stats) is used.
model.eval()

# Index -> label mapping for the binary classifier head.
CLASS_NAMES = {0: "fake", 1: "real"}


def _processor_to_tensor(img):
    """Run the HF image processor on one PIL image; return its pixel tensor without the batch dim."""
    return processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)


# Built once at import time instead of on every request.
val_test_resnet_combined_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(_processor_to_tensor),
])


def preprocess_image(image):
    """Resize and normalize a PIL image into a single (unbatched) model input tensor.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        The ``pixel_values`` tensor produced by the processor, with the
        leading batch dimension removed.
    """
    # The processor normalizes with per-channel RGB statistics, so RGBA,
    # grayscale or palette images must be converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return val_test_resnet_combined_transforms(image)


def predict(image):
    """Classify one image and return a human-readable result string.

    Args:
        image: A ``PIL.Image.Image`` supplied by the Gradio ``Image`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted Class: <label>"``, or a prompt to upload an image when
        no image was given.
    """
    if image is None:
        return "Please upload an image."

    # Add the batch dimension the model expects.
    image_tensor = preprocess_image(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(image_tensor).logits
    predicted_class = logits.argmax(-1).item()

    # Fall back to the raw index if the checkpoint emits an unexpected class.
    predicted_label = CLASS_NAMES.get(predicted_class, f"class {predicted_class}")
    return f"Predicted Class: {predicted_label}"


# Create the Gradio interface
image_input = gr.Image(type="pil")
output = gr.Textbox()
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=output,
    title="Image Classification with Base ResNet",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()