# obipix-bird-ID / app.py
# Initial commit 02a29e3 by dhanesh123in (1.27 kB).
import streamlit as st
import pandas as pd
from io import BytesIO
from PIL import Image
import time
from transformers import AutoImageProcessor, ViTForImageClassification
import torch
# Hub repository that holds both the fine-tuned ViT classifier and the
# preprocessing config it was trained with.
MODEL_ID = "dhanesh123in/image_classification_obipix_birdID"


@st.cache_resource
def _load_model():
    """Load the image processor and the bird-species classifier.

    Streamlit re-executes this whole script on every widget interaction
    (e.g. answering the "Was this correct?" radio). Caching with
    ``st.cache_resource`` keeps one loaded copy shared across reruns and
    sessions instead of rebuilding the model each time.

    Returns:
        A ``(image_processor, model)`` tuple.
    """
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = ViTForImageClassification.from_pretrained(MODEL_ID)
    return processor, model


# Module-level names kept unchanged for the rest of the script.
image_processor, model_s = _load_model()
st.title("Welcome to Bird Species Identification App!")
uploaded_file = st.file_uploader("Upload Image")
if uploaded_file is not None:
    # Read the upload once as bytes so it can be both decoded and displayed.
    bytes_data = uploaded_file.getvalue()
    try:
        # The processor normalises 3-channel input; convert RGBA (transparent
        # PNG), grayscale or palette images to RGB first. convert() forces the
        # decode, so corrupt data surfaces here too.
        image = Image.open(BytesIO(bytes_data)).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError and truncated-image errors are both
        # OSError subclasses; report them instead of crashing with a traceback.
        st.error("The uploaded file could not be read as an image.")
        st.stop()
    # Run the real inference while the spinner is shown (previously the
    # spinner wrapped only an artificial 5 s sleep that repeated on every
    # rerun, e.g. each click on the feedback radio below).
    with st.spinner('Our well trained AI assistant is looking at your image...'):
        inputs = image_processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = model_s(**inputs).logits
        # Map the highest-scoring class index to a species name using the
        # fine-tuned model's own label table (config.id2label).
        predicted_label = logits.argmax(-1).item()
        prediction = model_s.config.id2label[predicted_label]
    st.success("Prediction is " + prediction)
    st.image(bytes_data)
    x = st.radio("Was this correct?", ["Yes", "No"], horizontal=True)
    if x == "No":
        st.write("Oops.. more to learn I guess")