metadata
base_model: google/mobilenet_v2_1.0_224
datasets:
  - 0-ma/geometric-shapes
license: other
metrics:
  - accuracy
pipeline_tag: image-classification

Model Card for MobileNetV2 Fine-Tuned on the Geometric Shapes Dataset

Training Dataset

  • 0-ma/geometric-shapes

Base Model

  • google/mobilenet_v2_1.0_224

Accuracy

  • Accuracy on the test split of 0-ma/geometric-shapes: 0.7683 (about 76.8%)
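The reported figure is top-1 accuracy: the fraction of test images whose highest-scoring class matches the reference label. The sketch below shows that computation under the assumption that reference labels are integers in the same order as the `labels` list used later in this card. The `accuracy` helper and the toy arrays are hypothetical and are not the author's evaluation code.

```python
import numpy as np


def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Top-1 accuracy: share of rows whose argmax equals the reference label."""
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions == labels))


# Hypothetical toy batch: 3 samples x 6 classes
# (None, Circle, Triangle, Square, Pentagon, Hexagon)
toy_logits = np.array([
    [0.1, 2.0, 0.3, 0.0, 0.0, 0.0],  # argmax 1 -> Circle
    [0.0, 0.1, 0.2, 0.0, 3.0, 0.5],  # argmax 4 -> Pentagon
    [1.0, 0.0, 0.0, 0.9, 0.0, 0.0],  # argmax 0 -> None
])
toy_labels = np.array([1, 4, 3])  # the third sample is actually a Square

print(accuracy(toy_logits, toy_labels))
```

Two of the three toy predictions match, so the score is 2/3. The model card's 0.7683 is the same ratio computed over the full test split.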

Loading and using the model

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Class names, indexed by the model's output position
labels = [
    "None",
    "Circle",
    "Triangle",
    "Square",
    "Pentagon",
    "Hexagon",
]

# Two example images from the project repository
urls = [
    "https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_circle.jpg",
    "https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_pentagone.jpg",
]
images = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in urls]

# Load the preprocessing pipeline and the fine-tuned classifier from the Hub
image_processor = AutoImageProcessor.from_pretrained("0-ma/mobilenet-v2-geometric-shapes")
model = AutoModelForImageClassification.from_pretrained("0-ma/mobilenet-v2-geometric-shapes")
model.eval()

# Preprocess the batch and run inference without tracking gradients
inputs = image_processor(images=images, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits.cpu().numpy()

# Pick the highest-scoring class for each image
predictions = np.argmax(logits, axis=1)
predicted_labels = [labels[prediction] for prediction in predictions]
print(predicted_labels)
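
The example above reports only the winning class. If a confidence score is also useful, the logits can be turned into probabilities with a softmax. The following is a minimal NumPy sketch that assumes `logits` has shape `(batch, 6)` as produced above. The helper names `softmax` and `top_prediction` and the example logits are hypothetical and are not part of the model's API.

```python
import numpy as np

labels = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting each row's max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_prediction(logits: np.ndarray) -> list[tuple[str, float]]:
    """Return (label, probability) of the highest-scoring class for each row."""
    probs = softmax(logits)
    best = np.argmax(probs, axis=1)
    return [(labels[i], float(probs[row, i])) for row, i in enumerate(best)]


# Hypothetical logits for two images, shape (batch, 6)
example_logits = np.array([
    [0.2, 4.0, 0.1, 0.3, 0.0, 0.1],
    [0.0, 0.5, 0.2, 0.1, 3.5, 1.0],
])
print(top_prediction(example_logits))
```

Softmax preserves the ordering of the logits, so the predicted class is the same as with `np.argmax(logits, axis=1)`. Only the attached probability is new.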

Model generation

The model was trained with the script 'train_shape_detector.py.py' from the project https://github.com/0-ma/geometric-shape-detector. No external code sources were used.