File size: 2,694 Bytes
3fe1151
 
 
 
a61ca6d
3fe1151
a61ca6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fe1151
 
 
7fd17d1
3fe1151
293637a
9f992a7
 
 
293637a
9f992a7
 
 
 
 
293637a
9f992a7
293637a
 
 
9f992a7
293637a
0cdaffb
9f992a7
 
293637a
9f992a7
 
293637a
 
9f992a7
293637a
9f992a7
 
293637a
 
9f992a7
0cdaffb
293637a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cdaffb
293637a
0cdaffb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python

from __future__ import annotations

import os
import pathlib
import shlex
import subprocess

if os.getenv("SYSTEM") == "spaces":
    # Hugging Face Spaces bootstrap (runs only when SYSTEM == "spaces"):
    # pin dependency versions and patch the vendored CBNetV2 checkout before
    # gradio and the local `model` module are imported further down.
    #
    # Every step is best-effort, exactly as before: subprocess.run is called
    # without check=True, so a failing command is reported but does not abort
    # start-up (e.g. a package that is already absent, or a re-applied patch).
    import sys

    def _pip(args: str) -> None:
        """Run ``python -m pip <args>`` with the interpreter executing this script.

        Using ``sys.executable -m pip`` (instead of a bare ``pip`` from PATH)
        guarantees packages are installed into / removed from the environment
        this app actually runs in.
        """
        subprocess.run([sys.executable, "-m", "pip", *shlex.split(args)])

    _pip("install click==7.1.2")
    _pip("install typer==0.9.4")

    import mim

    # Replace whatever mmcv-full is present with the pinned 1.5.0 build.
    mim.uninstall("mmcv-full", confirm_yes=True)
    mim.install("mmcv-full==1.5.0", is_yes=True)

    # Remove both OpenCV wheels, then install the pinned headless build.
    _pip("uninstall -y opencv-python")
    _pip("uninstall -y opencv-python-headless")
    _pip("install opencv-python-headless==4.8.0.74")

    # Apply ./patch to the CBNetV2 tree (patch -p1 strips one leading path
    # component), then move the replacement palette module into place.
    with open("patch") as f:
        subprocess.run(shlex.split("patch -p1"), cwd="CBNetV2", stdin=f)
    subprocess.run(shlex.split("mv palette.py CBNetV2/mmdet/core/visualization/"))


import gradio as gr

from model import Model

# Markdown header rendered at the top of the demo page.
DESCRIPTION: str = "# [CBNetV2](https://github.com/VDIGPKU/CBNetV2)"

# A single Model instance (from the local `model` module) that supplies the
# detector list and default selection and handles every UI event.
model: Model = Model()

# Gradio UI: a two-column row (inputs on the left, rendered result on the
# right), an examples row, then the event wiring. Gradio lays components out
# in creation order within each Row/Column context, so reordering statements
# here changes the page layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: input image, detector selector, Detect button.
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="numpy")
            with gr.Row():
                # Choices are the keys of model.models; the initial selection
                # is model.model_name.
                detector_name = gr.Dropdown(
                    label="Detector", choices=list(model.models.keys()), value=model.model_name
                )
            with gr.Row():
                detect_button = gr.Button("Detect")
                # Holds the first output of detect_and_visualize so that Redraw
                # can pass it to visualize_detection_results without running
                # detection again.
                detection_results = gr.State()
        # Right column: rendered result, score threshold, Redraw button.
        with gr.Column():
            with gr.Row():
                detection_visualization = gr.Image(label="Detection Result", type="numpy")
            with gr.Row():
                visualization_score_threshold = gr.Slider(
                    label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.3
                )
            with gr.Row():
                redraw_button = gr.Button("Redraw")

    with gr.Row():
        # Every *.jpg under ./images (searched recursively, sorted by path)
        # becomes a one-click example that fills input_image.
        paths = sorted(pathlib.Path("images").rglob("*.jpg"))
        gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)

    # Changing the dropdown calls model.set_model_name (no UI outputs).
    # NOTE(review): `model` is one module-level instance, so this presumably
    # switches the detector for every concurrent user as well — confirm.
    detector_name.change(fn=model.set_model_name, inputs=detector_name)
    # Detect: run model.detect_and_visualize on the image at the current
    # threshold; its two return values go to the state and the result image.
    detect_button.click(
        fn=model.detect_and_visualize,
        inputs=[
            input_image,
            visualization_score_threshold,
        ],
        outputs=[
            detection_results,
            detection_visualization,
        ],
    )
    # Redraw: re-render the stored detection_results at the current threshold
    # via model.visualize_detection_results (detection is not re-run).
    redraw_button.click(
        fn=model.visualize_detection_results,
        inputs=[
            input_image,
            detection_results,
            visualization_score_threshold,
        ],
        outputs=detection_visualization,
    )

if __name__ == "__main__":
    # Enable Gradio's request queue (at most 10 pending events), then start
    # serving the app built above.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch()