lpetrov's picture
Initial working tutorial version with gradio downgraded to 3.50.2 due to image formatting problems
44fb66d
raw
history blame contribute delete
No virus
1.41 kB
import warnings
from pathlib import Path

import gradio as gr
import torch
from torch import nn
# Class names, one per line. A name's position in the file is the class index:
# the output layer has len(LABELS) units and predictions are looked up by index.
LABELS = Path("class_names.txt").read_text(encoding="utf-8").splitlines()

# Small CNN: three conv -> ReLU -> max-pool stages, then a two-layer classifier.
# Keep the layer order unchanged: load_state_dict matches parameters by the
# positional keys nn.Sequential assigns ("0.weight", "3.weight", ...).
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    # 1152 = 128 channels * 3 * 3, i.e. what is left of a 28x28 input after
    # three 2x poolings (28 -> 14 -> 7 -> 3).
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (or pass weights_only=True where
# the installed torch version supports it).
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# strict=False tolerates key mismatches, but any mismatch means some layers keep
# their random initial weights -- report it instead of ignoring it silently.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "pytorch_model.bin does not fully match the model: "
        f"missing keys {load_result.missing_keys}, "
        f"unexpected keys {load_result.unexpected_keys}",
        RuntimeWarning,
    )

# Switch to inference mode before serving predictions.
model.eval()
def predict(im):
    """Classify a sketch and return the most likely class names.

    Args:
        im: 2-D array-like of grayscale pixel values (divided by 255 before
            being fed to the model), or ``None`` when nothing was drawn.
            Presumably the 28x28 array produced by Gradio's sketchpad --
            TODO confirm against the installed Gradio version.

    Returns:
        A dict mapping up to five class names to their softmax probability,
        most likely first. Empty when ``im`` is ``None``.
    """
    if im is None:
        # An empty or cleared canvas has nothing to classify.
        return {}

    # Add batch and channel dimensions -> (1, 1, H, W), as Conv2d expects.
    x = torch.as_tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
    with torch.no_grad():
        out = model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    values, indices = torch.topk(probabilities, k)

    # .tolist() yields plain ints/floats, so the dict holds no tensors.
    return {LABELS[i]: v for i, v in zip(indices.tolist(), values.tolist())}
# Gradio UI: a sketchpad input wired to `predict`, with a label output that
# displays the returned class -> probability mapping.
interface = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="label",
    # Do not re-run the prediction automatically whenever the input changes.
    live=False,
    title="Sketch Recognition",
    description=(
        "Who wants to play Pictionary? Draw a common object like a shovel or a "
        "laptop, and the algorithm will guess in real time!"
    ),
    article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
)

# Start the web server locally, without creating a public share link.
interface.launch(share=False)