jsulz HF staff commited on
Commit
746d998
1 Parent(s): 6209161

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.keras filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -10,4 +10,46 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  license: mit
11
  ---
12
 
13
+ # MNIST Streamlit
14
+
15
+ This is a simple Streamlit app that demonstrates the differences between neural nets trained on the MNIST
16
+
17
+ There are three models saved locally available in the `models` directory:
18
+
19
+ - `autokeras_model.keras`
20
+ - `mnist_12.onnx`
21
+ - `mnist_model.keras`
22
+
23
+ The `mnist_model.keras` is a simple 300x300 neural net trained over 35 epochs.
24
+
25
+ The `autokeras_model.keras` is a more complex model generated by running the [Autokeras image classifier class](https://autokeras.com/image_classifier/).
26
+
27
+ Meanwhile, the `mnist_12.onnx` model is a pre-trained model from theOnnx model zoo. Onnx provides detailed information about how the model was created [in the repository on GitHub](https://github.com/onnx/models/blob/main/validated/vision/classification/mnist/README.md).
28
+
29
+ The application allows you to:
30
+
31
+ 1. Select which model you want to use for predicting a handwritten digit
32
+ 2. Select your stroke width of the digit you draw
33
+ 3. Draw a specific digit within a canvas
34
+
35
+ Once you draw a digit, the model will be loaded, asked to make a prediction on your input, and provide:
36
+
37
+ - The name of the model used to make the prediction
38
+ - A prediction (the top prediction from it's probability distribution)
39
+ - The time the model took to predict
40
+ - The time it took to load the model
41
+ - The probability distribution of predictions as a bar chart and table
42
+
43
+ ## Usage
44
+
45
+ To run the Streamlit app locally using Poetry, clone the repository, `cd` into the created directory, and run the following commands:
46
+
47
+ - `poetry shell`
48
+ - `poetry install`
49
+ - `streamlit run app.py`
50
+
51
+ If you don't have Poetry installed, never fear! There is a `requirements.txt` file that you may use to install the necessary packages with Pip. Simply create a new virtual environment and run:
52
+
53
+ ```shell
54
+ pip install -r requirements.txt
55
+ ```
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A small Streamlit app that loads a Keras model trained on the MNIST dataset and allows the user to draw a digit on a canvas and get a predicted digit from the model.
3
+ """
4
+
5
+ import streamlit as st
6
+ from PIL import Image
7
+ from streamlit_drawable_canvas import st_canvas
8
+ import os
9
+ import numpy as np
10
+ from keras import models
11
+ import keras.datasets.mnist as mnist
12
+ import matplotlib.pyplot as plt
13
+ import pandas as pd
14
+ import time
15
+ import onnx
16
+ import onnxruntime
17
+ from scipy.special import softmax
18
+
19
+
20
+ @st.cache_resource
21
+ def load_picture():
22
+ """
23
+ Loads the first 9 images from the mnist dataset and add them to a plot
24
+ to be displayed in streamlit.
25
+ """
26
+ # load the mnist dataset
27
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
28
+ # plot the first 9 images
29
+ for i in range(9):
30
+ plt.subplot(330 + 1 + i)
31
+ image = x_train[i] / 255.0
32
+ plt.imshow(image, cmap=plt.get_cmap("gray"))
33
+
34
+ # Save the plot as a png file and show it in streamlit
35
+ # This is commented out for not because the plot was created and saved in the img directory during the initial run of the app locally
36
+ # plt.savefig("img/show.png")
37
+ st.image("img/show.png", width=250, caption="First 9 images from the MNIST dataset")
38
+
39
+
40
+ def keras_prediction(final, model_path):
41
+ load_time = time.time()
42
+ model = models.load_model(
43
+ os.path.abspath(os.path.join(os.path.dirname(__file__), model_path))
44
+ )
45
+ after_load_curr = time.time()
46
+ curr_time = time.time()
47
+ prediction = model.predict(final[None, ...])
48
+ after_time = time.time()
49
+ return prediction, after_time - curr_time, after_load_curr - load_time
50
+
51
+
52
+ def onnx_prediction(final, model_path):
53
+ im_np = np.expand_dims(final, axis=0) # Add batch dimension
54
+ im_np = np.expand_dims(im_np, axis=0) # Add channel dimension
55
+ im_np = im_np.astype("float32")
56
+ load_curr = time.time()
57
+ session = onnxruntime.InferenceSession(model_path, None)
58
+ input_name = session.get_inputs()[0].name
59
+ output_name = session.get_outputs()[0].name
60
+ after_load_curr = time.time()
61
+
62
+ curr_time = time.time()
63
+ result = session.run([output_name], {input_name: im_np})
64
+ prediction = softmax(np.array(result).squeeze(), axis=0)
65
+ after_time = time.time()
66
+ return prediction, after_time - curr_time, after_load_curr - load_curr
67
+
68
+
69
+ def main():
70
+ """
71
+ The main function/primary entry point of the app
72
+ """
73
+ # write the title of the page as MNIST Digit Recognizer
74
+ st.title("MNIST Digit Recognizer")
75
+
76
+ col1, col2 = st.columns([0.8, 0.2], gap="small")
77
+ with col1:
78
+ st.markdown(
79
+ """
80
+ This Streamlit app loads a Keras neural network trained on the MNIST dataset to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
81
+ - Change the stroke width of the digit using the slider
82
+ - Choose what model you use for predictions
83
+ - Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
84
+ - Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
85
+ - Basic: A simple two layer nueral net where each layer has 300 nodes
86
+
87
+ Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.""",
88
+ unsafe_allow_html=True,
89
+ )
90
+ with col2:
91
+ # Load the first 9 images from the MNIST dataset and show them
92
+ load_picture()
93
+
94
+ col3, col4 = st.columns(2, gap="small")
95
+
96
+ with col4:
97
+ # Stroke width slider to change the width of the canvas stroke
98
+ # Starts at 10 because that's reasonably close to the width of the MNIST digits
99
+ stroke_width = st.slider("Stroke width: ", 1, 25, 10)
100
+ model_choice = st.selectbox(
101
+ "Choose what model to use for predictions:", ("Onnx", "Autokeras", "Basic")
102
+ )
103
+ if "Basic" in model_choice:
104
+ model_path = "models/mnist_model.keras"
105
+
106
+ if "Auto" in model_choice:
107
+ model_path = "models/autokeras_model.keras"
108
+
109
+ if "Onnx" in model_choice:
110
+ model_path = "models/mnist_12.onnx"
111
+
112
+ with col3:
113
+ # Create a canvas component
114
+ canvas_result = st_canvas(
115
+ stroke_width=stroke_width,
116
+ stroke_color="#FFF",
117
+ fill_color="#000",
118
+ background_color="#000",
119
+ background_image=None,
120
+ update_streamlit=True,
121
+ height=200,
122
+ width=200,
123
+ drawing_mode="freedraw",
124
+ point_display_radius=0,
125
+ key="canvas",
126
+ )
127
+
128
+ if canvas_result is not None and canvas_result.image_data is not None:
129
+
130
+ # Get the image data, convert it to grayscale, and resize it to 28x28 (the same size as the MNIST dataset images)
131
+ img_data = canvas_result.image_data
132
+ im = Image.fromarray(img_data.astype("uint8")).convert("L")
133
+ im = im.resize((28, 28))
134
+
135
+ # Convert the image to a numpy array and normalize the values
136
+ final = np.array(im, dtype=np.float32) / 255.0
137
+
138
+ # if final is not all zeros, run the prediction
139
+ if not np.all(final == 0):
140
+
141
+ if model_choice != "Onnx":
142
+ prediction, pred_time, load_time = keras_prediction(final, model_path)
143
+ else:
144
+ prediction, pred_time, load_time = onnx_prediction(final, model_path)
145
+
146
+ # print the prediction
147
+ st.header(f"Using model: {model_choice}")
148
+ st.write(f"Prediction: {np.argmax(prediction)}")
149
+ st.write(f"Load time (in ms): {(load_time) * 1000:.2f}")
150
+ st.write(f"Prediction time (in ms): {(pred_time) * 1000:.2f}")
151
+
152
+ # Create a 2 column dataframe with one column as the digits and the other as the probability
153
+ data = pd.DataFrame(
154
+ {"Digit": list(range(10)), "Probability": np.ravel(prediction)}
155
+ )
156
+
157
+ col1, col2 = st.columns([0.8, 0.2], gap="small")
158
+ # create a bar chart to show the predictions
159
+ with col1:
160
+ st.bar_chart(data, x="Digit", y="Probability", height=500)
161
+
162
+ # show the probability distribution numerically
163
+ with col2:
164
+ data["Probability"] = data["Probability"].apply(lambda x: f"{x:.2%}")
165
+ st.dataframe(data, hide_index=True)
166
+
167
+
168
+ if __name__ == "__main__":
169
+ main()
img/show.png ADDED
models/autokeras_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52ade8cfa53511a4c1bb4718d6aae0c9997537b6aa41cbb0561e9a2ba8776379
3
+ size 1374323
models/mnist_12.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c688690f8bacf667d4c2074af5ad0646ca328d7ab03eccf944a65b320171bdd
3
+ size 26143
models/mnist_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca930b40ac309a75ed3ccea89c9dad0b54e957de648f052e87f5c5799173923f
3
+ size 1333377
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "mnist-streamlit"
3
+ version = "0.1.0"
4
+ description = "A small Streamlit project demoing a Keras-backed neural network trained on the MNIST dataset."
5
+ authors = ["jsulz <[email protected]>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.12"
10
+ streamlit = "^1.35.0"
11
+ keras = "^3.3.3"
12
+ tensorflow = "^2.16.1"
13
+ streamlit-drawable-canvas = "^0.9.3"
14
+ pillow = "^10.3.0"
15
+ numpy = "^1.26.4"
16
+ matplotlib = "^3.9.0"
17
+ onnx = "^1.16.1"
18
+ onnxruntime = "^1.18.1"
19
+ scipy = "1.13.0"
20
+ autokeras = "^2.0.0"
21
+
22
+
23
+ [build-system]
24
+ requires = ["poetry-core"]
25
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ altair==5.3.0
3
+ astunparse==1.6.3
4
+ attrs==23.2.0
5
+ autokeras==2.0.0
6
+ blinker==1.8.2
7
+ cachetools==5.3.3
8
+ certifi==2024.6.2
9
+ charset-normalizer==3.3.2
10
+ click==8.1.7
11
+ colorama==0.4.6
12
+ coloredlogs==15.0.1
13
+ contourpy==1.2.1
14
+ cycler==0.12.1
15
+ dm-tree==0.1.8
16
+ flatbuffers==24.3.25
17
+ fonttools==4.53.0
18
+ gast==0.5.4
19
+ gitdb==4.0.11
20
+ gitpython==3.1.43
21
+ google-pasta==0.2.0
22
+ grpcio==1.64.1
23
+ h5py==3.11.0
24
+ humanfriendly==10.0
25
+ idna==3.7
26
+ jinja2==3.1.4
27
+ jsonschema-specifications==2023.12.1
28
+ jsonschema==4.22.0
29
+ kagglehub==0.2.9
30
+ keras-nlp==0.14.3
31
+ keras-tuner==1.4.7
32
+ keras==3.3.3
33
+ kiwisolver==1.4.5
34
+ kt-legacy==1.0.5
35
+ libclang==18.1.1
36
+ markdown-it-py==3.0.0
37
+ markdown==3.6
38
+ markupsafe==2.1.5
39
+ matplotlib==3.9.0
40
+ mdurl==0.1.2
41
+ ml-dtypes==0.3.2
42
+ mpmath==1.3.0
43
+ namex==0.0.8
44
+ numpy==1.26.4
45
+ onnx==1.16.1
46
+ onnxruntime==1.18.1
47
+ opt-einsum==3.3.0
48
+ optree==0.11.0
49
+ packaging==24.0
50
+ pandas==2.2.2
51
+ pillow==10.3.0
52
+ protobuf==4.25.3
53
+ pyarrow==16.1.0
54
+ pydeck==0.9.1
55
+ pygments==2.18.0
56
+ pyparsing==3.1.2
57
+ pyreadline3==3.4.1
58
+ python-dateutil==2.9.0.post0
59
+ pytz==2024.1
60
+ referencing==0.35.1
61
+ regex==2024.7.24
62
+ requests==2.32.3
63
+ rich==13.7.1
64
+ rpds-py==0.18.1
65
+ scipy==1.13.0
66
+ setuptools==70.0.0
67
+ six==1.16.0
68
+ smmap==5.0.1
69
+ streamlit-drawable-canvas==0.9.3
70
+ streamlit==1.35.0
71
+ sympy==1.12.1
72
+ tenacity==8.3.0
73
+ tensorboard-data-server==0.7.2
74
+ tensorboard==2.16.2
75
+ tensorflow-text==2.16.1
76
+ tensorflow==2.16.1
77
+ termcolor==2.4.0
78
+ toml==0.10.2
79
+ toolz==0.12.1
80
+ tornado==6.4.1
81
+ tqdm==4.66.5
82
+ typing-extensions==4.12.2
83
+ tzdata==2024.1
84
+ urllib3==2.2.1
85
+ watchdog==4.0.1
86
+ werkzeug==3.0.3
87
+ wheel==0.43.0
88
+ wrapt==1.16.0