Spaces:
Sleeping
Sleeping
initial commit
Browse files- .gitattributes +1 -0
- README.md +43 -1
- app.py +169 -0
- img/show.png +0 -0
- models/autokeras_model.keras +3 -0
- models/mnist_12.onnx +3 -0
- models/mnist_model.keras +3 -0
- poetry.lock +0 -0
- pyproject.toml +25 -0
- requirements.txt +88 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.keras filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -10,4 +10,46 @@ pinned: false
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
# MNIST Streamlit
|
14 |
+
|
15 |
+
This is a simple Streamlit app that demonstrates the differences between neural nets trained on the MNIST
|
16 |
+
|
17 |
+
There are three models saved locally available in the `models` directory:
|
18 |
+
|
19 |
+
- `autokeras_model.keras`
|
20 |
+
- `mnist_12.onnx`
|
21 |
+
- `mnist_model.keras`
|
22 |
+
|
23 |
+
The `mnist_model.keras` is a simple 300x300 neural net trained over 35 epochs.
|
24 |
+
|
25 |
+
The `autokeras_model.keras` is a more complex model generated by running the [Autokeras image classifier class](https://autokeras.com/image_classifier/).
|
26 |
+
|
27 |
+
Meanwhile, the `mnist_12.onnx` model is a pre-trained model from theOnnx model zoo. Onnx provides detailed information about how the model was created [in the repository on GitHub](https://github.com/onnx/models/blob/main/validated/vision/classification/mnist/README.md).
|
28 |
+
|
29 |
+
The application allows you to:
|
30 |
+
|
31 |
+
1. Select which model you want to use for predicting a handwritten digit
|
32 |
+
2. Select your stroke width of the digit you draw
|
33 |
+
3. Draw a specific digit within a canvas
|
34 |
+
|
35 |
+
Once you draw a digit, the model will be loaded, asked to make a prediction on your input, and provide:
|
36 |
+
|
37 |
+
- The name of the model used to make the prediction
|
38 |
+
- A prediction (the top prediction from it's probability distribution)
|
39 |
+
- The time the model took to predict
|
40 |
+
- The time it took to load the model
|
41 |
+
- The probability distribution of predictions as a bar chart and table
|
42 |
+
|
43 |
+
## Usage
|
44 |
+
|
45 |
+
To run the Streamlit app locally using Poetry, clone the repository, `cd` into the created directory, and run the following commands:
|
46 |
+
|
47 |
+
- `poetry shell`
|
48 |
+
- `poetry install`
|
49 |
+
- `streamlit run app.py`
|
50 |
+
|
51 |
+
If you don't have Poetry installed, never fear! There is a `requirements.txt` file that you may use to install the necessary packages with Pip. Simply create a new virtual environment and run:
|
52 |
+
|
53 |
+
```shell
|
54 |
+
pip install -r requirements.txt
|
55 |
+
```
|
app.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A small Streamlit app that loads a Keras model trained on the MNIST dataset and allows the user to draw a digit on a canvas and get a predicted digit from the model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
from PIL import Image
|
7 |
+
from streamlit_drawable_canvas import st_canvas
|
8 |
+
import os
|
9 |
+
import numpy as np
|
10 |
+
from keras import models
|
11 |
+
import keras.datasets.mnist as mnist
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import pandas as pd
|
14 |
+
import time
|
15 |
+
import onnx
|
16 |
+
import onnxruntime
|
17 |
+
from scipy.special import softmax
|
18 |
+
|
19 |
+
|
20 |
+
@st.cache_resource
|
21 |
+
def load_picture():
|
22 |
+
"""
|
23 |
+
Loads the first 9 images from the mnist dataset and add them to a plot
|
24 |
+
to be displayed in streamlit.
|
25 |
+
"""
|
26 |
+
# load the mnist dataset
|
27 |
+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
28 |
+
# plot the first 9 images
|
29 |
+
for i in range(9):
|
30 |
+
plt.subplot(330 + 1 + i)
|
31 |
+
image = x_train[i] / 255.0
|
32 |
+
plt.imshow(image, cmap=plt.get_cmap("gray"))
|
33 |
+
|
34 |
+
# Save the plot as a png file and show it in streamlit
|
35 |
+
# This is commented out for not because the plot was created and saved in the img directory during the initial run of the app locally
|
36 |
+
# plt.savefig("img/show.png")
|
37 |
+
st.image("img/show.png", width=250, caption="First 9 images from the MNIST dataset")
|
38 |
+
|
39 |
+
|
40 |
+
def keras_prediction(final, model_path):
|
41 |
+
load_time = time.time()
|
42 |
+
model = models.load_model(
|
43 |
+
os.path.abspath(os.path.join(os.path.dirname(__file__), model_path))
|
44 |
+
)
|
45 |
+
after_load_curr = time.time()
|
46 |
+
curr_time = time.time()
|
47 |
+
prediction = model.predict(final[None, ...])
|
48 |
+
after_time = time.time()
|
49 |
+
return prediction, after_time - curr_time, after_load_curr - load_time
|
50 |
+
|
51 |
+
|
52 |
+
def onnx_prediction(final, model_path):
|
53 |
+
im_np = np.expand_dims(final, axis=0) # Add batch dimension
|
54 |
+
im_np = np.expand_dims(im_np, axis=0) # Add channel dimension
|
55 |
+
im_np = im_np.astype("float32")
|
56 |
+
load_curr = time.time()
|
57 |
+
session = onnxruntime.InferenceSession(model_path, None)
|
58 |
+
input_name = session.get_inputs()[0].name
|
59 |
+
output_name = session.get_outputs()[0].name
|
60 |
+
after_load_curr = time.time()
|
61 |
+
|
62 |
+
curr_time = time.time()
|
63 |
+
result = session.run([output_name], {input_name: im_np})
|
64 |
+
prediction = softmax(np.array(result).squeeze(), axis=0)
|
65 |
+
after_time = time.time()
|
66 |
+
return prediction, after_time - curr_time, after_load_curr - load_curr
|
67 |
+
|
68 |
+
|
69 |
+
def main():
|
70 |
+
"""
|
71 |
+
The main function/primary entry point of the app
|
72 |
+
"""
|
73 |
+
# write the title of the page as MNIST Digit Recognizer
|
74 |
+
st.title("MNIST Digit Recognizer")
|
75 |
+
|
76 |
+
col1, col2 = st.columns([0.8, 0.2], gap="small")
|
77 |
+
with col1:
|
78 |
+
st.markdown(
|
79 |
+
"""
|
80 |
+
This Streamlit app loads a Keras neural network trained on the MNIST dataset to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
|
81 |
+
- Change the stroke width of the digit using the slider
|
82 |
+
- Choose what model you use for predictions
|
83 |
+
- Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
|
84 |
+
- Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
|
85 |
+
- Basic: A simple two layer nueral net where each layer has 300 nodes
|
86 |
+
|
87 |
+
Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.""",
|
88 |
+
unsafe_allow_html=True,
|
89 |
+
)
|
90 |
+
with col2:
|
91 |
+
# Load the first 9 images from the MNIST dataset and show them
|
92 |
+
load_picture()
|
93 |
+
|
94 |
+
col3, col4 = st.columns(2, gap="small")
|
95 |
+
|
96 |
+
with col4:
|
97 |
+
# Stroke width slider to change the width of the canvas stroke
|
98 |
+
# Starts at 10 because that's reasonably close to the width of the MNIST digits
|
99 |
+
stroke_width = st.slider("Stroke width: ", 1, 25, 10)
|
100 |
+
model_choice = st.selectbox(
|
101 |
+
"Choose what model to use for predictions:", ("Onnx", "Autokeras", "Basic")
|
102 |
+
)
|
103 |
+
if "Basic" in model_choice:
|
104 |
+
model_path = "models/mnist_model.keras"
|
105 |
+
|
106 |
+
if "Auto" in model_choice:
|
107 |
+
model_path = "models/autokeras_model.keras"
|
108 |
+
|
109 |
+
if "Onnx" in model_choice:
|
110 |
+
model_path = "models/mnist_12.onnx"
|
111 |
+
|
112 |
+
with col3:
|
113 |
+
# Create a canvas component
|
114 |
+
canvas_result = st_canvas(
|
115 |
+
stroke_width=stroke_width,
|
116 |
+
stroke_color="#FFF",
|
117 |
+
fill_color="#000",
|
118 |
+
background_color="#000",
|
119 |
+
background_image=None,
|
120 |
+
update_streamlit=True,
|
121 |
+
height=200,
|
122 |
+
width=200,
|
123 |
+
drawing_mode="freedraw",
|
124 |
+
point_display_radius=0,
|
125 |
+
key="canvas",
|
126 |
+
)
|
127 |
+
|
128 |
+
if canvas_result is not None and canvas_result.image_data is not None:
|
129 |
+
|
130 |
+
# Get the image data, convert it to grayscale, and resize it to 28x28 (the same size as the MNIST dataset images)
|
131 |
+
img_data = canvas_result.image_data
|
132 |
+
im = Image.fromarray(img_data.astype("uint8")).convert("L")
|
133 |
+
im = im.resize((28, 28))
|
134 |
+
|
135 |
+
# Convert the image to a numpy array and normalize the values
|
136 |
+
final = np.array(im, dtype=np.float32) / 255.0
|
137 |
+
|
138 |
+
# if final is not all zeros, run the prediction
|
139 |
+
if not np.all(final == 0):
|
140 |
+
|
141 |
+
if model_choice != "Onnx":
|
142 |
+
prediction, pred_time, load_time = keras_prediction(final, model_path)
|
143 |
+
else:
|
144 |
+
prediction, pred_time, load_time = onnx_prediction(final, model_path)
|
145 |
+
|
146 |
+
# print the prediction
|
147 |
+
st.header(f"Using model: {model_choice}")
|
148 |
+
st.write(f"Prediction: {np.argmax(prediction)}")
|
149 |
+
st.write(f"Load time (in ms): {(load_time) * 1000:.2f}")
|
150 |
+
st.write(f"Prediction time (in ms): {(pred_time) * 1000:.2f}")
|
151 |
+
|
152 |
+
# Create a 2 column dataframe with one column as the digits and the other as the probability
|
153 |
+
data = pd.DataFrame(
|
154 |
+
{"Digit": list(range(10)), "Probability": np.ravel(prediction)}
|
155 |
+
)
|
156 |
+
|
157 |
+
col1, col2 = st.columns([0.8, 0.2], gap="small")
|
158 |
+
# create a bar chart to show the predictions
|
159 |
+
with col1:
|
160 |
+
st.bar_chart(data, x="Digit", y="Probability", height=500)
|
161 |
+
|
162 |
+
# show the probability distribution numerically
|
163 |
+
with col2:
|
164 |
+
data["Probability"] = data["Probability"].apply(lambda x: f"{x:.2%}")
|
165 |
+
st.dataframe(data, hide_index=True)
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
main()
|
img/show.png
ADDED
models/autokeras_model.keras
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:52ade8cfa53511a4c1bb4718d6aae0c9997537b6aa41cbb0561e9a2ba8776379
|
3 |
+
size 1374323
|
models/mnist_12.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5c688690f8bacf667d4c2074af5ad0646ca328d7ab03eccf944a65b320171bdd
|
3 |
+
size 26143
|
models/mnist_model.keras
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca930b40ac309a75ed3ccea89c9dad0b54e957de648f052e87f5c5799173923f
|
3 |
+
size 1333377
|
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "mnist-streamlit"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "A small Streamlit project demoing a Keras-backed neural network trained on the MNIST dataset."
|
5 |
+
authors = ["jsulz <[email protected]>"]
|
6 |
+
readme = "README.md"
|
7 |
+
|
8 |
+
[tool.poetry.dependencies]
|
9 |
+
python = "^3.12"
|
10 |
+
streamlit = "^1.35.0"
|
11 |
+
keras = "^3.3.3"
|
12 |
+
tensorflow = "^2.16.1"
|
13 |
+
streamlit-drawable-canvas = "^0.9.3"
|
14 |
+
pillow = "^10.3.0"
|
15 |
+
numpy = "^1.26.4"
|
16 |
+
matplotlib = "^3.9.0"
|
17 |
+
onnx = "^1.16.1"
|
18 |
+
onnxruntime = "^1.18.1"
|
19 |
+
scipy = "1.13.0"
|
20 |
+
autokeras = "^2.0.0"
|
21 |
+
|
22 |
+
|
23 |
+
[build-system]
|
24 |
+
requires = ["poetry-core"]
|
25 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
altair==5.3.0
|
3 |
+
astunparse==1.6.3
|
4 |
+
attrs==23.2.0
|
5 |
+
autokeras==2.0.0
|
6 |
+
blinker==1.8.2
|
7 |
+
cachetools==5.3.3
|
8 |
+
certifi==2024.6.2
|
9 |
+
charset-normalizer==3.3.2
|
10 |
+
click==8.1.7
|
11 |
+
colorama==0.4.6
|
12 |
+
coloredlogs==15.0.1
|
13 |
+
contourpy==1.2.1
|
14 |
+
cycler==0.12.1
|
15 |
+
dm-tree==0.1.8
|
16 |
+
flatbuffers==24.3.25
|
17 |
+
fonttools==4.53.0
|
18 |
+
gast==0.5.4
|
19 |
+
gitdb==4.0.11
|
20 |
+
gitpython==3.1.43
|
21 |
+
google-pasta==0.2.0
|
22 |
+
grpcio==1.64.1
|
23 |
+
h5py==3.11.0
|
24 |
+
humanfriendly==10.0
|
25 |
+
idna==3.7
|
26 |
+
jinja2==3.1.4
|
27 |
+
jsonschema-specifications==2023.12.1
|
28 |
+
jsonschema==4.22.0
|
29 |
+
kagglehub==0.2.9
|
30 |
+
keras-nlp==0.14.3
|
31 |
+
keras-tuner==1.4.7
|
32 |
+
keras==3.3.3
|
33 |
+
kiwisolver==1.4.5
|
34 |
+
kt-legacy==1.0.5
|
35 |
+
libclang==18.1.1
|
36 |
+
markdown-it-py==3.0.0
|
37 |
+
markdown==3.6
|
38 |
+
markupsafe==2.1.5
|
39 |
+
matplotlib==3.9.0
|
40 |
+
mdurl==0.1.2
|
41 |
+
ml-dtypes==0.3.2
|
42 |
+
mpmath==1.3.0
|
43 |
+
namex==0.0.8
|
44 |
+
numpy==1.26.4
|
45 |
+
onnx==1.16.1
|
46 |
+
onnxruntime==1.18.1
|
47 |
+
opt-einsum==3.3.0
|
48 |
+
optree==0.11.0
|
49 |
+
packaging==24.0
|
50 |
+
pandas==2.2.2
|
51 |
+
pillow==10.3.0
|
52 |
+
protobuf==4.25.3
|
53 |
+
pyarrow==16.1.0
|
54 |
+
pydeck==0.9.1
|
55 |
+
pygments==2.18.0
|
56 |
+
pyparsing==3.1.2
|
57 |
+
pyreadline3==3.4.1
|
58 |
+
python-dateutil==2.9.0.post0
|
59 |
+
pytz==2024.1
|
60 |
+
referencing==0.35.1
|
61 |
+
regex==2024.7.24
|
62 |
+
requests==2.32.3
|
63 |
+
rich==13.7.1
|
64 |
+
rpds-py==0.18.1
|
65 |
+
scipy==1.13.0
|
66 |
+
setuptools==70.0.0
|
67 |
+
six==1.16.0
|
68 |
+
smmap==5.0.1
|
69 |
+
streamlit-drawable-canvas==0.9.3
|
70 |
+
streamlit==1.35.0
|
71 |
+
sympy==1.12.1
|
72 |
+
tenacity==8.3.0
|
73 |
+
tensorboard-data-server==0.7.2
|
74 |
+
tensorboard==2.16.2
|
75 |
+
tensorflow-text==2.16.1
|
76 |
+
tensorflow==2.16.1
|
77 |
+
termcolor==2.4.0
|
78 |
+
toml==0.10.2
|
79 |
+
toolz==0.12.1
|
80 |
+
tornado==6.4.1
|
81 |
+
tqdm==4.66.5
|
82 |
+
typing-extensions==4.12.2
|
83 |
+
tzdata==2024.1
|
84 |
+
urllib3==2.2.1
|
85 |
+
watchdog==4.0.1
|
86 |
+
werkzeug==3.0.3
|
87 |
+
wheel==0.43.0
|
88 |
+
wrapt==1.16.0
|