Model description
This is a scikit-learn LogisticRegressionCV model trained on averages of patch embeddings of images from the Imagenette dataset. It forms the generalized additive model (GAM) component of an Emb-GAM extended to images. Patch embeddings are meant to be extracted with the google/vit-base-patch16-224 ViT checkpoint.
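The additive structure can be written out explicitly. The following is a sketch, assuming the patch embeddings $e_1, \dots, e_N$ are pooled by summation as in the usage code later in this card; the symbols $w_c$ and $b_c$ are notation introduced here for the coefficient row and intercept of class $c$, and $f_c$ for the resulting class score.

```latex
\[
  f_c(x) \;=\; w_c^\top \Big( \sum_{i=1}^{N} e_i \Big) + b_c
         \;=\; \sum_{i=1}^{N} w_c^\top e_i \;+\; b_c
\]
```

Each term $w_c^\top e_i$ can be read as the contribution of patch $i$ to class $c$. If the embeddings are averaged instead of summed, every term is scaled by $1/N$ and the decomposition is otherwise unchanged.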
Intended uses & limitations
This model is not intended to be used in production.
Training Procedure
Hyperparameters
The model was trained with the hyperparameters below.
Hyperparameter | Value |
---|---|
Cs | 10 |
class_weight | |
cv | StratifiedKFold(n_splits=5, random_state=1, shuffle=True) |
dual | False |
fit_intercept | True |
intercept_scaling | 1.0 |
l1_ratios | |
max_iter | 100 |
multi_class | auto |
n_jobs | |
penalty | l2 |
random_state | 1 |
refit | False |
scoring | |
solver | lbfgs |
tol | 0.0001 |
verbose | 0 |
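Most values in the table match the scikit-learn defaults, so an equivalent estimator can likely be built by setting only `cv`, `random_state`, and `refit`. The snippet below is a minimal sketch under that assumption; it builds the unfitted estimator only and does not reproduce the training data pipeline.

```python
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold

# Cross-validation splitter as listed in the hyperparameter table
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)

# Only the non-default settings are passed; everything else
# (Cs=10, lbfgs solver, L2 penalty, max_iter=100, ...) is left at its default
clf = LogisticRegressionCV(cv=cv, random_state=1, refit=False)

params = clf.get_params()
print(params["cv"], params["refit"], params["random_state"])
```

With `refit=False`, scikit-learn averages the coefficients of the best-scoring fits across folds instead of refitting on the full training set.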
Model Plot
The model plot is below.
LogisticRegressionCV(cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True), random_state=1, refit=False)
Evaluation Results
The evaluation results are reported in the table below.
Metric | Value |
---|---|
accuracy | 0.99465 |
f1 score | 0.99465 |
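Metrics of this kind are typically computed with `sklearn.metrics`. The sketch below uses toy labels purely for illustration; the F1 averaging mode (`"weighted"`) is an assumption, since this card does not state which one was used.

```python
from sklearn.metrics import accuracy_score, f1_score

# Toy labels for illustration only (NOT the Imagenette evaluation data)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]

accuracy = accuracy_score(y_true, y_pred)

# The averaging mode is an assumption; the card does not specify it
f1 = f1_score(y_true, y_pred, average="weighted")

print(f"accuracy={accuracy:.5f}, f1={f1:.5f}")
```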
How to Get Started with the Model
Use the code below to get started with the model.
from PIL import Image
from skops import hub_utils
import torch
from transformers import AutoFeatureExtractor, AutoModel
import pickle
import os
# load embedding model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModel.from_pretrained("google/vit-base-patch16-224").eval().to(device)
# load logistic regression
os.mkdir("emb-gam-vit")
hub_utils.download(repo_id="Ramos-Ramos/emb-gam-vit", dst="emb-gam-vit")
with open("emb-gam-vit/model.pkl", "rb") as file:
    logistic_regression = pickle.load(file)
# load image
img = Image.open("examples/english_springer.png")
# preprocess image
inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt').items()}
# extract patch embeddings
with torch.no_grad():
    patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
# classify
pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))
# get patch contributions
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
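The patch contributions computed above have one row per class and one column per patch. The sketch below shows one way they might be arranged into a spatial map. It uses random stand-in arrays in place of the real coefficients and embeddings, and the names `coef`, `intercept`, `contribution_maps`, and `heatmap` are introduced here for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins (hypothetical data, not the trained model's weights)
num_classes, hidden_dim = 10, 768   # Imagenette has 10 classes; ViT-Base hidden size is 768
grid = 224 // 16                    # 14 patches per side for 224x224 inputs with 16x16 patches
coef = rng.normal(size=(num_classes, hidden_dim))      # stands in for logistic_regression.coef_
intercept = rng.normal(size=num_classes)               # stands in for logistic_regression.intercept_
patch_embeddings = rng.normal(size=(grid * grid, hidden_dim))

# Per-class, per-patch contributions, shape (10, 196)
patch_contributions = coef @ patch_embeddings.T

# The classifier is linear, so the contributions plus the intercept
# add up to the logits of the summed embedding
logits_from_patches = patch_contributions.sum(axis=1) + intercept
logits_from_sum = coef @ patch_embeddings.sum(axis=0) + intercept

# Patches are ordered row by row, so a reshape recovers the 14x14 layout
contribution_maps = patch_contributions.reshape(num_classes, grid, grid)
predicted_class = int(np.argmax(logits_from_sum))
heatmap = contribution_maps[predicted_class]
print(heatmap.shape)
```

With the real model, `heatmap` could be upsampled to 224x224 and overlaid on the input image to show which regions push the prediction toward the chosen class.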
Model Card Authors
This model card was written by the following authors:
Patrick Ramos and Ryan Ramos
Model Card Contact
You can contact the model card authors through the following channels: [More Information Needed]
Citation
Below you can find information related to citation.
BibTeX:
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}
Additional Content
confusion_matrix