This is a simple AI-generated image detection model built on a Vision Transformer (ViT, based on `google/vit-base-patch16-224`) and trained on the CIFAKE dataset to classify images as REAL or FAKE.
Example usage:
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification

# Load the base ViT architecture, replace its head with a 2-class classifier,
# then load the trained weights from the local checkpoint
model_path = 'vit_model.pth'
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Define the image preprocessing pipeline: resize to the 224x224 input size
# ViT expects and convert to a float tensor in [0, 1]
# (this should match the preprocessing used during training)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def predict(image_path, model, preprocess):
    # Load the image as RGB and add a batch dimension: shape (1, 3, 224, 224)
    image = Image.open(image_path).convert('RGB')
    inputs = preprocess(image).unsqueeze(0)
    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(inputs).logits
    predicted_label = torch.argmax(outputs, dim=-1).item()
    # Map the predicted label index to its class name
    label_map = {0: 'FAKE', 1: 'REAL'}
    return label_map[predicted_label]

# Run predictions on a few images (replace the placeholder paths with your own)
image_paths = [
    'path/to/image1.jpg',
    'path/to/image2.jpg',
    'path/to/image3.jpg',
]
for image_path in image_paths:
    predicted_class = predict(image_path, model, preprocess)
    print(f'{image_path}: {predicted_class}')
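The `predict` function above returns only the arg-max label. If you also want a confidence score, you can pass the two logits through a softmax. The sketch below does this in plain Python, so it can be checked without downloading any weights. The helper name `logits_to_prediction` is hypothetical and not part of this model card. It also assumes the same `{0: 'FAKE', 1: 'REAL'}` label order used above.

```python
import math

# Label order assumed to match label_map in the example above.
LABELS = {0: 'FAKE', 1: 'REAL'}


def logits_to_prediction(logits):
    """Hypothetical helper: map a list of raw logits to (label, probability).

    Uses a numerically stable softmax: the largest logit is subtracted
    before exponentiating so large values cannot overflow.
    """
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class (first one wins on ties)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]


# Usage with made-up logits (not real model output):
label, confidence = logits_to_prediction([2.0, -1.0])
print(f'{label} ({confidence:.3f})')  # FAKE (0.953)
```

Inside `predict`, the equivalent PyTorch call would be `torch.softmax(outputs, dim=-1)` applied to the logits before taking the arg-max.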
Evaluation results
- Accuracy (self-reported): 0.980
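The accuracy figure is self-reported, and the card does not state which split or evaluation protocol produced it. As a rough guide, accuracy here is the fraction of images whose predicted class matches the true class. The minimal sketch below assumes you have parallel lists of true and predicted labels. The function names `accuracy` and `confusion_counts` and the sample data are illustrative only and do not reproduce the reported result.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the ground-truth labels."""
    if len(y_true) != len(y_pred):
        raise ValueError('y_true and y_pred must have the same length')
    if not y_true:
        raise ValueError('cannot compute accuracy on an empty list')
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def confusion_counts(y_true, y_pred):
    """Count (true, predicted) label pairs.

    For example, the key ('REAL', 'FAKE') counts real images flagged as fake.
    """
    return Counter(zip(y_true, y_pred))


# Illustrative labels only, not results from this model:
y_true = ['REAL', 'FAKE', 'FAKE', 'REAL']
y_pred = ['REAL', 'FAKE', 'REAL', 'REAL']
print(accuracy(y_true, y_pred))  # 0.75
print(confusion_counts(y_true, y_pred))
```

Reporting the confusion counts alongside accuracy shows whether errors are mostly real images flagged as fake or the reverse. A single accuracy number hides that distinction.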