
This is a simple detector for AI-generated images: a Vision Transformer (ViT) classifier, built on google/vit-base-patch16-224 and trained on the CIFAKE dataset to label images as REAL or FAKE.

Example usage:

import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification

# Load the ViT-Base backbone, replace its 1000-class ImageNet head with a 2-class head, then load the trained weights
model_path = 'vit_model.pth'
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Define the image preprocessing pipeline
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def predict(image_path, model, preprocess):
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    inputs = preprocess(image).unsqueeze(0)

    # Perform inference
    with torch.no_grad():
        outputs = model(inputs).logits
        predicted_label = outputs.argmax(dim=-1).item()

    # Map the predicted label to the corresponding class
    label_map = {0: 'FAKE', 1: 'REAL'}
    predicted_class = label_map[predicted_label]
    return predicted_class

# Example usage
image_paths = [
    'path/to/image1.jpg',
    'path/to/image2.jpg',
    'path/to/image3.jpg',
]

for image_path in image_paths:
    predicted_class = predict(image_path, model, preprocess)
    print(f'{image_path}: {predicted_class}')
Model size: 85.8M params
Tensor type: F32
Weights format: Safetensors
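The 85.8M figure can be sanity-checked with a short calculation. The sketch below counts parameters for a ViT-Base/16 backbone at 224×224 input, which is the google/vit-base-patch16-224 configuration used in the example, plus the 2-class head. It assumes the following layout, intended to mirror the Hugging Face `ViTForImageClassification` model:

- a Conv2d patch embedding with bias
- a learned CLS token and learned position embeddings
- biased query/key/value/output projections
- two LayerNorms per block
- a final LayerNorm
- no pooler

The function name `vit_param_count` is hypothetical.

```python
def vit_param_count(hidden=768, layers=12, mlp=3072, patch=16,
                    image=224, channels=3, num_labels=2):
    """Count trainable parameters of a ViT image classifier (assumed layout).

    Assumes a Conv2d patch embedding with bias, a learned CLS token and
    position embeddings, biased q/k/v/output projections, two LayerNorms per
    block, a final LayerNorm, no pooler, and a linear classification head.
    """
    def linear(n_in, n_out):
        return n_in * n_out + n_out          # weight + bias

    layernorm = 2 * hidden                   # scale + shift

    num_patches = (image // patch) ** 2      # 14 * 14 = 196 for 224 / 16
    embeddings = (
        linear(channels * patch * patch, hidden)  # patch projection (Conv2d as a linear map)
        + hidden                                  # CLS token
        + (num_patches + 1) * hidden              # position embeddings (196 patches + CLS)
    )
    block = (
        4 * linear(hidden, hidden)           # query, key, value, attention output
        + linear(hidden, mlp)                # MLP expansion
        + linear(mlp, hidden)                # MLP projection
        + 2 * layernorm                      # pre-attention and pre-MLP LayerNorms
    )
    head = layernorm + linear(hidden, num_labels)  # final LayerNorm + classifier
    return embeddings + layers * block + head


print(vit_param_count())  # 85800194, i.e. about 85.8M
```

Nearly all of the parameters sit in the 12 encoder blocks, at about 7.09M each. Swapping the original 1000-class head for the 2-class head changes the total by less than 1%.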