wtf741 commited on
Commit
8cb8f64
1 Parent(s): 5f5f7aa

init commit

Browse files
Files changed (5) hide show
  1. .gitignore +5 -0
  2. README.md +2 -2
  3. app.py +47 -0
  4. check.py +81 -0
  5. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ /venv
2
+ __pycache__
3
+ /test_*
4
+ /.idea
5
+ .env
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Mist Checker
3
- emoji: 🌖
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
 
1
  ---
2
+ title: Shit Checker
3
+ emoji: 💩
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import textwrap
3
+ import time
4
+
5
+ import gradio as gr
6
+ from PIL import Image
7
+
8
+ from check import DEFAULT_MODEL, predict, MODELS
9
+
10
+
11
+ def _predict_fn(image: Image.Image, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8):
12
+ start_time = time.time()
13
+ result = predict(image, model_name, max_batch_size)
14
+ duration = time.time() - start_time
15
+ info = f'Time cost: **{duration:.3f}s**'
16
+ return result, info
17
+
18
+
19
+ if __name__ == '__main__':
20
+ with gr.Blocks() as demo:
21
+ with gr.Row():
22
+ gr_info = gr.Markdown(textwrap.dedent("""
23
+ Quickly check if the image is **glazed or misted** (we call it shat💩).
24
+
25
+ And then you can just remove these shit with
26
+ [mf666/mist-fucker](https://huggingface.co/spaces/mf666/mist-fucker),
27
+ without fucking the normal images (no detail losses).
28
+ """).strip())
29
+
30
+ with gr.Row():
31
+ with gr.Column():
32
+ gr_input = gr.Image(label='Image To Check', image_mode='RGB', type='pil')
33
+ gr_models = gr.Dropdown(choices=MODELS, value=DEFAULT_MODEL, label='Models')
34
+ gr_max_batch_size = gr.Slider(minimum=1, maximum=16, value=8, step=1, label='Max Batch Size')
35
+ gr_submit = gr.Button(value='Check The Shit', variant='primary')
36
+
37
+ with gr.Column():
38
+ gr_label = gr.Label(label='Check Result')
39
+ gr_time_cost = gr.Markdown(label='Information')
40
+
41
+ gr_submit.click(
42
+ _predict_fn,
43
+ inputs=[gr_input, gr_models, gr_max_batch_size],
44
+ outputs=[gr_label, gr_time_cost],
45
+ )
46
+
47
+ demo.queue(os.cpu_count()).launch()
check.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ from functools import lru_cache
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from huggingface_hub import HfFileSystem, HfApi
9
+ from imgutils.utils import open_onnx_model
10
+ from natsort import natsorted
11
+
12
+ hf_token = os.environ.get('HF_TOKEN')
13
+ hf_fs = HfFileSystem(token=hf_token)
14
+ hf_client = HfApi(token=hf_token)
15
+
16
+ REPOSITORY = 'mf666/shit-checker'
17
+ MODELS = natsorted([
18
+ os.path.splitext(os.path.relpath(file, REPOSITORY))[0]
19
+ for file in hf_fs.glob(f'{REPOSITORY}/*.onnx')
20
+ ])
21
+ DEFAULT_MODEL = 'mobilenet.small'
22
+
23
+
24
+ @lru_cache()
25
+ def _open_model(model_name):
26
+ return open_onnx_model(hf_client.hf_hub_download(REPOSITORY, f'{model_name}.onnx'))
27
+
28
+
29
+ _DEFAULT_ORDER = 'HWC'
30
+
31
+
32
+ def _get_hwc_map(order_):
33
+ return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper())
34
+
35
+
36
+ def _encode_channels(image, channels_order='CHW', is_float=True):
37
+ array = np.asarray(image.convert('RGB'))
38
+ array = np.transpose(array, _get_hwc_map(channels_order))
39
+ if not is_float:
40
+ assert array.dtype == np.uint8
41
+ else:
42
+ array = (array / 255.0).astype(np.float32)
43
+ assert array.dtype == np.float32
44
+ return array
45
+
46
+
47
+ def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
48
+ image = image.resize(size, Image.BILINEAR)
49
+ data = _encode_channels(image, channels_order='CHW')
50
+
51
+ if normalize is not None:
52
+ mean_, std_ = normalize
53
+ mean = np.asarray([mean_]).reshape((-1, 1, 1))
54
+ std = np.asarray([std_]).reshape((-1, 1, 1))
55
+ data = (data - mean) / std
56
+
57
+ return data.astype(np.float32)
58
+
59
+
60
+ def _raw_predict(images, model_name=DEFAULT_MODEL):
61
+ items = []
62
+ for image in images:
63
+ items.append(_img_encode(image.convert('RGB')))
64
+ input_ = np.stack(items)
65
+ output, = _open_model(model_name).run(['output'], {'input': input_})
66
+ return output.mean(axis=0)
67
+
68
+
69
+ def predict(image, model_name=DEFAULT_MODEL, max_batch_size=8):
70
+ area = image.width * image.height
71
+ batch_size = int(max(min(math.ceil(area / (384 * 384)) + 1, max_batch_size), 1))
72
+ blocks = []
73
+ for _ in range(batch_size):
74
+ x0 = random.randint(0, max(0, image.width - 384))
75
+ y0 = random.randint(0, max(0, image.height - 384))
76
+ x1 = min(x0 + 384, image.width)
77
+ y1 = min(y0 + 384, image.height)
78
+ blocks.append(image.crop((x0, y0, x1, y1)))
79
+
80
+ scores = _raw_predict(blocks, model_name)
81
+ return dict(zip(['shat', 'normal'], map(lambda x: x.item(), scores)))
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.12.0
2
+ numpy
3
+ pillow
4
+ onnxruntime
5
+ huggingface_hub>=0.14.0
6
+ scikit-image
7
+ pandas
8
+ opencv-python>=4.6.0
9
+ hbutils>=0.9.0
10
+ dghs-imgutils>=0.1.0
11
+ httpx==0.23.0
12
+ natsort