lpetrov commited on
Commit
44fb66d
1 Parent(s): 40a29c2

Initial working tutorial version with gradio downgraded to 3.50.2 due to image formatting problems

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. app.py +51 -0
  3. class_names.txt +100 -0
  4. pytorch_model.bin +3 -0
  5. requirements.txt +88 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .ipynb_checkpoints/
2
+ __pycache__
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import gradio as gr
4
+ from torch import nn
5
+
6
+ LABELS = Path("class_names.txt").read_text().splitlines()
7
+
8
+ model = nn.Sequential(
9
+ nn.Conv2d(1, 32, 3, padding="same"),
10
+ nn.ReLU(),
11
+ nn.MaxPool2d(2),
12
+ nn.Conv2d(32, 64, 3, padding="same"),
13
+ nn.ReLU(),
14
+ nn.MaxPool2d(2),
15
+ nn.Conv2d(64, 128, 3, padding="same"),
16
+ nn.ReLU(),
17
+ nn.MaxPool2d(2),
18
+ nn.Flatten(),
19
+ nn.Linear(1152, 256),
20
+ nn.ReLU(),
21
+ nn.Linear(256, len(LABELS)),
22
+ )
23
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
24
+ model.load_state_dict(state_dict, strict=False)
25
+ model.eval()
26
+
27
+
28
+ def predict(im):
29
+ x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
30
+ with torch.no_grad():
31
+ out = model(x)
32
+ probabilities = torch.nn.functional.softmax(out[0], dim=0)
33
+ values, indices = torch.topk(probabilities, 5)
34
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
35
+
36
+ # import gradio as gr
37
+
38
+ # from app import predict
39
+
40
+ interface = gr.Interface(
41
+ predict,
42
+ inputs="sketchpad",
43
+ outputs="label",
44
+ # theme="huggingface",
45
+ title="Sketch Recognition",
46
+ description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
47
+ article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
48
+ live=False,
49
+ )
50
+
51
+ interface.launch(share=False)
class_names.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ airplane
2
+ alarm_clock
3
+ anvil
4
+ apple
5
+ axe
6
+ baseball
7
+ baseball_bat
8
+ basketball
9
+ beard
10
+ bed
11
+ bench
12
+ bicycle
13
+ bird
14
+ book
15
+ bread
16
+ bridge
17
+ broom
18
+ butterfly
19
+ camera
20
+ candle
21
+ car
22
+ cat
23
+ ceiling_fan
24
+ cell_phone
25
+ chair
26
+ circle
27
+ clock
28
+ cloud
29
+ coffee_cup
30
+ cookie
31
+ cup
32
+ diving_board
33
+ donut
34
+ door
35
+ drums
36
+ dumbbell
37
+ envelope
38
+ eye
39
+ eyeglasses
40
+ face
41
+ fan
42
+ flower
43
+ frying_pan
44
+ grapes
45
+ hammer
46
+ hat
47
+ headphones
48
+ helmet
49
+ hot_dog
50
+ ice_cream
51
+ key
52
+ knife
53
+ ladder
54
+ laptop
55
+ light_bulb
56
+ lightning
57
+ line
58
+ lollipop
59
+ microphone
60
+ moon
61
+ mountain
62
+ moustache
63
+ mushroom
64
+ pants
65
+ paper_clip
66
+ pencil
67
+ pillow
68
+ pizza
69
+ power_outlet
70
+ radio
71
+ rainbow
72
+ rifle
73
+ saw
74
+ scissors
75
+ screwdriver
76
+ shorts
77
+ shovel
78
+ smiley_face
79
+ snake
80
+ sock
81
+ spider
82
+ spoon
83
+ square
84
+ star
85
+ stop_sign
86
+ suitcase
87
+ sun
88
+ sword
89
+ syringe
90
+ t-shirt
91
+ table
92
+ tennis_racquet
93
+ tent
94
+ tooth
95
+ traffic_light
96
+ tree
97
+ triangle
98
+ umbrella
99
+ wheel
100
+ wristwatch
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:effb6ea6f1593c09e8247944028ed9c309b5ff1cef82ba38b822bee2ca4d0f3c
3
+ size 1656903
requirements.txt ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.3.0
3
+ annotated-types==0.7.0
4
+ anyio==4.4.0
5
+ attrs==23.2.0
6
+ certifi==2024.2.2
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ contourpy==1.2.1
10
+ cycler==0.12.1
11
+ dnspython==2.6.1
12
+ email_validator==2.1.1
13
+ fastapi==0.111.0
14
+ fastapi-cli==0.0.4
15
+ ffmpy==0.3.2
16
+ filelock==3.14.0
17
+ fonttools==4.52.4
18
+ fsspec==2024.5.0
19
+ gradio==3.50.2
20
+ gradio_client==0.6.1
21
+ h11==0.14.0
22
+ httpcore==1.0.5
23
+ httptools==0.6.1
24
+ httpx==0.27.0
25
+ huggingface-hub==0.23.2
26
+ idna==3.7
27
+ importlib_resources==6.4.0
28
+ Jinja2==3.1.4
29
+ jsonschema==4.22.0
30
+ jsonschema-specifications==2023.12.1
31
+ kiwisolver==1.4.5
32
+ markdown-it-py==3.0.0
33
+ MarkupSafe==2.1.5
34
+ matplotlib==3.9.0
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
+ networkx==3.3
38
+ numpy==1.26.4
39
+ nvidia-cublas-cu12==12.1.3.1
40
+ nvidia-cuda-cupti-cu12==12.1.105
41
+ nvidia-cuda-nvrtc-cu12==12.1.105
42
+ nvidia-cuda-runtime-cu12==12.1.105
43
+ nvidia-cudnn-cu12==8.9.2.26
44
+ nvidia-cufft-cu12==11.0.2.54
45
+ nvidia-curand-cu12==10.3.2.106
46
+ nvidia-cusolver-cu12==11.4.5.107
47
+ nvidia-cusparse-cu12==12.1.0.106
48
+ nvidia-nccl-cu12==2.20.5
49
+ nvidia-nvjitlink-cu12==12.5.40
50
+ nvidia-nvtx-cu12==12.1.105
51
+ orjson==3.10.3
52
+ packaging==24.0
53
+ pandas==2.2.2
54
+ pathlib==1.0.1
55
+ pillow==10.3.0
56
+ pydantic==2.7.1
57
+ pydantic_core==2.18.2
58
+ pydub==0.25.1
59
+ Pygments==2.18.0
60
+ pyparsing==3.1.2
61
+ python-dateutil==2.9.0.post0
62
+ python-dotenv==1.0.1
63
+ python-multipart==0.0.9
64
+ pytz==2024.1
65
+ PyYAML==6.0.1
66
+ referencing==0.35.1
67
+ requests==2.32.2
68
+ rich==13.7.1
69
+ rpds-py==0.18.1
70
+ semantic-version==2.10.0
71
+ shellingham==1.5.4
72
+ six==1.16.0
73
+ sniffio==1.3.1
74
+ starlette==0.37.2
75
+ sympy==1.12
76
+ toolz==0.12.1
77
+ torch==2.3.0
78
+ tqdm==4.66.4
79
+ triton==2.3.0
80
+ typer==0.12.3
81
+ typing_extensions==4.12.0
82
+ tzdata==2024.1
83
+ ujson==5.10.0
84
+ urllib3==2.2.1
85
+ uvicorn==0.30.0
86
+ uvloop==0.19.0
87
+ watchfiles==0.22.0
88
+ websockets==11.0.3