Spaces:
Runtime error
Runtime error
init commit
Browse files- .gitignore +5 -0
- README.md +2 -2
- app.py +47 -0
- check.py +81 -0
- requirements.txt +12 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/venv
|
2 |
+
__pycache__
|
3 |
+
/test_*
|
4 |
+
/.idea
|
5 |
+
.env
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
+
title: Shit Checker
|
3 |
+
emoji: 💩
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import textwrap
|
3 |
+
import time
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from check import DEFAULT_MODEL, predict, MODELS
|
9 |
+
|
10 |
+
|
11 |
+
def _predict_fn(image: Image.Image, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8):
|
12 |
+
start_time = time.time()
|
13 |
+
result = predict(image, model_name, max_batch_size)
|
14 |
+
duration = time.time() - start_time
|
15 |
+
info = f'Time cost: **{duration:.3f}s**'
|
16 |
+
return result, info
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
with gr.Blocks() as demo:
|
21 |
+
with gr.Row():
|
22 |
+
gr_info = gr.Markdown(textwrap.dedent("""
|
23 |
+
Quickly check if the image is **glazed or misted** (we call it shat💩).
|
24 |
+
|
25 |
+
And then you can just remove these shit with
|
26 |
+
[mf666/mist-fucker](https://huggingface.co/spaces/mf666/mist-fucker),
|
27 |
+
without fucking the normal images (no detail losses).
|
28 |
+
""").strip())
|
29 |
+
|
30 |
+
with gr.Row():
|
31 |
+
with gr.Column():
|
32 |
+
gr_input = gr.Image(label='Image To Check', image_mode='RGB', type='pil')
|
33 |
+
gr_models = gr.Dropdown(choices=MODELS, value=DEFAULT_MODEL, label='Models')
|
34 |
+
gr_max_batch_size = gr.Slider(minimum=1, maximum=16, value=8, step=1, label='Max Batch Size')
|
35 |
+
gr_submit = gr.Button(value='Check The Shit', variant='primary')
|
36 |
+
|
37 |
+
with gr.Column():
|
38 |
+
gr_label = gr.Label(label='Check Result')
|
39 |
+
gr_time_cost = gr.Markdown(label='Information')
|
40 |
+
|
41 |
+
gr_submit.click(
|
42 |
+
_predict_fn,
|
43 |
+
inputs=[gr_input, gr_models, gr_max_batch_size],
|
44 |
+
outputs=[gr_label, gr_time_cost],
|
45 |
+
)
|
46 |
+
|
47 |
+
demo.queue(os.cpu_count()).launch()
|
check.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from huggingface_hub import HfFileSystem, HfApi
|
9 |
+
from imgutils.utils import open_onnx_model
|
10 |
+
from natsort import natsorted
|
11 |
+
|
12 |
+
hf_token = os.environ.get('HF_TOKEN')
|
13 |
+
hf_fs = HfFileSystem(token=hf_token)
|
14 |
+
hf_client = HfApi(token=hf_token)
|
15 |
+
|
16 |
+
REPOSITORY = 'mf666/shit-checker'
|
17 |
+
MODELS = natsorted([
|
18 |
+
os.path.splitext(os.path.relpath(file, REPOSITORY))[0]
|
19 |
+
for file in hf_fs.glob(f'{REPOSITORY}/*.onnx')
|
20 |
+
])
|
21 |
+
DEFAULT_MODEL = 'mobilenet.small'
|
22 |
+
|
23 |
+
|
24 |
+
@lru_cache()
|
25 |
+
def _open_model(model_name):
|
26 |
+
return open_onnx_model(hf_client.hf_hub_download(REPOSITORY, f'{model_name}.onnx'))
|
27 |
+
|
28 |
+
|
29 |
+
_DEFAULT_ORDER = 'HWC'
|
30 |
+
|
31 |
+
|
32 |
+
def _get_hwc_map(order_):
|
33 |
+
return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper())
|
34 |
+
|
35 |
+
|
36 |
+
def _encode_channels(image, channels_order='CHW', is_float=True):
|
37 |
+
array = np.asarray(image.convert('RGB'))
|
38 |
+
array = np.transpose(array, _get_hwc_map(channels_order))
|
39 |
+
if not is_float:
|
40 |
+
assert array.dtype == np.uint8
|
41 |
+
else:
|
42 |
+
array = (array / 255.0).astype(np.float32)
|
43 |
+
assert array.dtype == np.float32
|
44 |
+
return array
|
45 |
+
|
46 |
+
|
47 |
+
def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
|
48 |
+
image = image.resize(size, Image.BILINEAR)
|
49 |
+
data = _encode_channels(image, channels_order='CHW')
|
50 |
+
|
51 |
+
if normalize is not None:
|
52 |
+
mean_, std_ = normalize
|
53 |
+
mean = np.asarray([mean_]).reshape((-1, 1, 1))
|
54 |
+
std = np.asarray([std_]).reshape((-1, 1, 1))
|
55 |
+
data = (data - mean) / std
|
56 |
+
|
57 |
+
return data.astype(np.float32)
|
58 |
+
|
59 |
+
|
60 |
+
def _raw_predict(images, model_name=DEFAULT_MODEL):
|
61 |
+
items = []
|
62 |
+
for image in images:
|
63 |
+
items.append(_img_encode(image.convert('RGB')))
|
64 |
+
input_ = np.stack(items)
|
65 |
+
output, = _open_model(model_name).run(['output'], {'input': input_})
|
66 |
+
return output.mean(axis=0)
|
67 |
+
|
68 |
+
|
69 |
+
def predict(image, model_name=DEFAULT_MODEL, max_batch_size=8):
|
70 |
+
area = image.width * image.height
|
71 |
+
batch_size = int(max(min(math.ceil(area / (384 * 384)) + 1, max_batch_size), 1))
|
72 |
+
blocks = []
|
73 |
+
for _ in range(batch_size):
|
74 |
+
x0 = random.randint(0, max(0, image.width - 384))
|
75 |
+
y0 = random.randint(0, max(0, image.height - 384))
|
76 |
+
x1 = min(x0 + 384, image.width)
|
77 |
+
y1 = min(y0 + 384, image.height)
|
78 |
+
blocks.append(image.crop((x0, y0, x1, y1)))
|
79 |
+
|
80 |
+
scores = _raw_predict(blocks, model_name)
|
81 |
+
return dict(zip(['shat', 'normal'], map(lambda x: x.item(), scores)))
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.12.0
|
2 |
+
numpy
|
3 |
+
pillow
|
4 |
+
onnxruntime
|
5 |
+
huggingface_hub>=0.14.0
|
6 |
+
scikit-image
|
7 |
+
pandas
|
8 |
+
opencv-python>=4.6.0
|
9 |
+
hbutils>=0.9.0
|
10 |
+
dghs-imgutils>=0.1.0
|
11 |
+
httpx==0.23.0
|
12 |
+
natsort
|