hysts HF staff commited on
Commit
9f992a7
β€’
1 Parent(s): d25c07c
Files changed (5) hide show
  1. .pre-commit-config.yaml +4 -13
  2. README.md +4 -1
  3. app.py +57 -95
  4. model.py +8 -6
  5. requirements.txt +1 -1
.pre-commit-config.yaml CHANGED
@@ -1,4 +1,4 @@
1
- exclude: ^(CBNetV2/|patch)
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
  rev: v4.2.0
@@ -21,26 +21,17 @@ repos:
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
- rev: 5.10.1
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.812
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
 
32
  - repo: https://github.com/google/yapf
33
  rev: v0.32.0
34
  hooks:
35
  - id: yapf
36
  args: ['--parallel', '--in-place']
37
- - repo: https://github.com/kynan/nbstripout
38
- rev: 0.5.0
39
- hooks:
40
- - id: nbstripout
41
- args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
- - repo: https://github.com/nbQA-dev/nbQA
43
- rev: 1.3.1
44
- hooks:
45
- - id: nbqa-isort
46
- - id: nbqa-yapf
 
1
+ exclude: ^patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
  rev: v4.2.0
 
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
  - repo: https://github.com/google/yapf
34
  rev: v0.32.0
35
  hooks:
36
  - id: yapf
37
  args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -4,9 +4,12 @@ emoji: πŸ“‰
4
  colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.0.17
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
4
  colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: t4-small
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
14
+
15
+ https://arxiv.org/abs/2107.00420
app.py CHANGED
@@ -2,104 +2,66 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import pathlib
7
 
8
  import gradio as gr
9
 
10
  from model import Model
11
 
12
- DESCRIPTION = '''# CBNetV2
13
-
14
- This is an unofficial demo for [https://github.com/VDIGPKU/CBNetV2](https://github.com/VDIGPKU/CBNetV2).'''
15
- FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.cbnetv2" />'
16
-
17
-
18
- def parse_args() -> argparse.Namespace:
19
- parser = argparse.ArgumentParser()
20
- parser.add_argument('--device', type=str, default='cpu')
21
- parser.add_argument('--theme', type=str)
22
- parser.add_argument('--share', action='store_true')
23
- parser.add_argument('--port', type=int)
24
- parser.add_argument('--disable-queue',
25
- dest='enable_queue',
26
- action='store_false')
27
- return parser.parse_args()
28
-
29
-
30
- def set_example_image(example: list) -> dict:
31
- return gr.Image.update(value=example[0])
32
-
33
-
34
- def main():
35
- args = parse_args()
36
- model = Model(args.device)
37
-
38
- with gr.Blocks(theme=args.theme, css='style.css') as demo:
39
- gr.Markdown(DESCRIPTION)
40
-
41
- with gr.Row():
42
- with gr.Column():
43
- with gr.Row():
44
- input_image = gr.Image(label='Input Image', type='numpy')
45
- with gr.Row():
46
- detector_name = gr.Dropdown(list(model.models.keys()),
47
- value=model.model_name,
48
- label='Detector')
49
- with gr.Row():
50
- detect_button = gr.Button(value='Detect')
51
- detection_results = gr.Variable()
52
- with gr.Column():
53
- with gr.Row():
54
- detection_visualization = gr.Image(
55
- label='Detection Result', type='numpy')
56
- with gr.Row():
57
- visualization_score_threshold = gr.Slider(
58
- 0,
59
- 1,
60
- step=0.05,
61
- value=0.3,
62
- label='Visualization Score Threshold')
63
- with gr.Row():
64
- redraw_button = gr.Button(value='Redraw')
65
-
66
- with gr.Row():
67
- paths = sorted(pathlib.Path('images').rglob('*.jpg'))
68
- example_images = gr.Dataset(components=[input_image],
69
- samples=[[path.as_posix()]
70
- for path in paths])
71
-
72
- gr.Markdown(FOOTER)
73
-
74
- detector_name.change(fn=model.set_model_name,
75
- inputs=[detector_name],
76
- outputs=None)
77
- detect_button.click(fn=model.detect_and_visualize,
78
- inputs=[
79
- input_image,
80
- visualization_score_threshold,
81
- ],
82
- outputs=[
83
- detection_results,
84
- detection_visualization,
85
- ])
86
- redraw_button.click(fn=model.visualize_detection_results,
87
- inputs=[
88
- input_image,
89
- detection_results,
90
- visualization_score_threshold,
91
- ],
92
- outputs=[detection_visualization])
93
- example_images.click(fn=set_example_image,
94
- inputs=[example_images],
95
- outputs=[input_image])
96
-
97
- demo.launch(
98
- enable_queue=args.enable_queue,
99
- server_port=args.port,
100
- share=args.share,
101
- )
102
-
103
-
104
- if __name__ == '__main__':
105
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import pathlib
6
 
7
  import gradio as gr
8
 
9
  from model import Model
10
 
11
+ DESCRIPTION = '# [CBNetV2](https://github.com/VDIGPKU/CBNetV2)'
12
+
13
+ model = Model()
14
+
15
+ with gr.Blocks(css='style.css') as demo:
16
+ gr.Markdown(DESCRIPTION)
17
+
18
+ with gr.Row():
19
+ with gr.Column():
20
+ with gr.Row():
21
+ input_image = gr.Image(label='Input Image', type='numpy')
22
+ with gr.Row():
23
+ detector_name = gr.Dropdown(label='Detector',
24
+ choices=list(model.models.keys()),
25
+ value=model.model_name)
26
+ with gr.Row():
27
+ detect_button = gr.Button('Detect')
28
+ detection_results = gr.Variable()
29
+ with gr.Column():
30
+ with gr.Row():
31
+ detection_visualization = gr.Image(label='Detection Result',
32
+ type='numpy')
33
+ with gr.Row():
34
+ visualization_score_threshold = gr.Slider(
35
+ label='Visualization Score Threshold',
36
+ minimum=0,
37
+ maximum=1,
38
+ step=0.05,
39
+ value=0.3)
40
+ with gr.Row():
41
+ redraw_button = gr.Button('Redraw')
42
+
43
+ with gr.Row():
44
+ paths = sorted(pathlib.Path('images').rglob('*.jpg'))
45
+ gr.Examples(examples=[[path.as_posix()] for path in paths],
46
+ inputs=input_image)
47
+
48
+ detector_name.change(fn=model.set_model_name,
49
+ inputs=[detector_name],
50
+ outputs=None)
51
+ detect_button.click(fn=model.detect_and_visualize,
52
+ inputs=[
53
+ input_image,
54
+ visualization_score_threshold,
55
+ ],
56
+ outputs=[
57
+ detection_results,
58
+ detection_visualization,
59
+ ])
60
+ redraw_button.click(fn=model.visualize_detection_results,
61
+ inputs=[
62
+ input_image,
63
+ detection_results,
64
+ visualization_score_threshold,
65
+ ],
66
+ outputs=[detection_visualization])
67
+ demo.queue(max_size=10).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import os
4
  import pathlib
 
5
  import subprocess
6
  import sys
7
 
@@ -11,12 +12,12 @@ if os.getenv('SYSTEM') == 'spaces':
11
  mim.uninstall('mmcv-full', confirm_yes=True)
12
  mim.install('mmcv-full==1.5.0', is_yes=True)
13
 
14
- subprocess.run('pip uninstall -y opencv-python'.split())
15
- subprocess.run('pip uninstall -y opencv-python-headless'.split())
16
- subprocess.run('pip install opencv-python-headless==4.5.5.64'.split())
17
 
18
  with open('patch') as f:
19
- subprocess.run('patch -p1'.split(), cwd='CBNetV2', stdin=f)
20
  subprocess.run('mv palette.py CBNetV2/mmdet/core/visualization/'.split())
21
 
22
  import numpy as np
@@ -31,8 +32,9 @@ from mmdet.apis import inference_detector, init_detector
31
 
32
 
33
  class Model:
34
- def __init__(self, device: str | torch.device):
35
- self.device = torch.device(device)
 
36
  self.models = self._load_models()
37
  self.model_name = 'Improved HTC (DB-Swin-B)'
38
 
 
2
 
3
  import os
4
  import pathlib
5
+ import shlex
6
  import subprocess
7
  import sys
8
 
 
12
  mim.uninstall('mmcv-full', confirm_yes=True)
13
  mim.install('mmcv-full==1.5.0', is_yes=True)
14
 
15
+ subprocess.run(shlex.split('pip uninstall -y opencv-python'))
16
+ subprocess.run(shlex.split('pip uninstall -y opencv-python-headless'))
17
+ subprocess.run(shlex.split('pip install opencv-python-headless==4.8.0.74'))
18
 
19
  with open('patch') as f:
20
+ subprocess.run(shlex.split('patch -p1'), cwd='CBNetV2', stdin=f)
21
  subprocess.run('mv palette.py CBNetV2/mmdet/core/visualization/'.split())
22
 
23
  import numpy as np
 
32
 
33
 
34
  class Model:
35
+ def __init__(self):
36
+ self.device = torch.device(
37
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
38
  self.models = self._load_models()
39
  self.model_name = 'Improved HTC (DB-Swin-B)'
40
 
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  mmcv-full==1.5.0
2
  mmdet==2.24.1
3
  numpy==1.22.4
4
- opencv-python-headless==4.5.5.64
5
  openmim==0.1.5
6
  timm==0.5.4
7
  torch==1.11.0
 
1
  mmcv-full==1.5.0
2
  mmdet==2.24.1
3
  numpy==1.22.4
4
+ opencv-python-headless==4.8.0.74
5
  openmim==0.1.5
6
  timm==0.5.4
7
  torch==1.11.0