Spaces: Running on A10G
Commit 5500fcd • 1 Parent(s): 0e674dd
set_max_key_frames_env (#10)
Browse files
- add midas depth and env for MAX_KEYFRAME (c962acd2ca442aebd3c83d20cc91ebabc2b1ab75)
Co-authored-by: Radamés Ajna <[email protected]>
app.py
CHANGED
@@ -18,6 +18,7 @@ from skimage import exposure
|
|
18 |
import src.import_util # noqa: F401
|
19 |
from ControlNet.annotator.canny import CannyDetector
|
20 |
from ControlNet.annotator.hed import HEDdetector
|
|
|
21 |
from ControlNet.annotator.util import HWC3
|
22 |
from ControlNet.cldm.model import create_model, load_state_dict
|
23 |
from gmflow_module.gmflow.gmflow import GMFlow
|
@@ -61,7 +62,7 @@ class ProcessingState(Enum):
|
|
61 |
KEY_IMGS = 2
|
62 |
|
63 |
|
64 |
-
MAX_KEYFRAME = 8
|
65 |
|
66 |
|
67 |
class GlobalState:
|
@@ -111,6 +112,12 @@ class GlobalState:
|
|
111 |
load_state_dict(huggingface_hub.hf_hub_download(
|
112 |
'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
|
113 |
location=device))
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
model.to(device)
|
115 |
sd_model_path = model_dict[sd_model]
|
116 |
if len(sd_model_path) > 0:
|
@@ -162,6 +169,15 @@ class GlobalState:
|
|
162 |
|
163 |
self.detector = apply_canny
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
global_state = GlobalState()
|
167 |
global_video_path = None
|
@@ -716,7 +732,7 @@ with block:
|
|
716 |
value=0,
|
717 |
step=1)
|
718 |
with gr.Row():
|
719 |
-
control_type = gr.Dropdown(['HED', 'canny'],
|
720 |
label='Control type',
|
721 |
value='HED')
|
722 |
low_threshold = gr.Slider(label='Canny low threshold',
|
@@ -756,14 +772,14 @@ with block:
|
|
756 |
interval = gr.Slider(
|
757 |
label='Key frame frequency (K)',
|
758 |
minimum=1,
|
759 |
-
maximum=
|
760 |
value=1,
|
761 |
step=1,
|
762 |
info='Uniformly sample the key frames every K frames')
|
763 |
keyframe_count = gr.Slider(
|
764 |
label='Number of key frames',
|
765 |
minimum=1,
|
766 |
-
maximum=
|
767 |
value=1,
|
768 |
step=1,
|
769 |
info='To avoid overload, maximum 8 key frames')
|
|
|
18 |
import src.import_util # noqa: F401
|
19 |
from ControlNet.annotator.canny import CannyDetector
|
20 |
from ControlNet.annotator.hed import HEDdetector
|
21 |
+
from ControlNet.annotator.midas import MidasDetector
|
22 |
from ControlNet.annotator.util import HWC3
|
23 |
from ControlNet.cldm.model import create_model, load_state_dict
|
24 |
from gmflow_module.gmflow.gmflow import GMFlow
|
|
|
62 |
KEY_IMGS = 2
|
63 |
|
64 |
|
65 |
+
MAX_KEYFRAME = float(os.environ.get('MAX_KEYFRAME', 8))
|
66 |
|
67 |
|
68 |
class GlobalState:
|
|
|
112 |
load_state_dict(huggingface_hub.hf_hub_download(
|
113 |
'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
|
114 |
location=device))
|
115 |
+
elif control_type == 'depth':
|
116 |
+
model.load_state_dict(
|
117 |
+
load_state_dict(huggingface_hub.hf_hub_download(
|
118 |
+
'lllyasviel/ControlNet', 'models/control_sd15_depth.pth'),
|
119 |
+
location=device))
|
120 |
+
|
121 |
model.to(device)
|
122 |
sd_model_path = model_dict[sd_model]
|
123 |
if len(sd_model_path) > 0:
|
|
|
169 |
|
170 |
self.detector = apply_canny
|
171 |
|
172 |
+
elif control_type == 'depth':
|
173 |
+
midas = MidasDetector()
|
174 |
+
|
175 |
+
def apply_midas(x):
|
176 |
+
detected_map, _ = midas(x)
|
177 |
+
return detected_map
|
178 |
+
|
179 |
+
self.detector = apply_midas
|
180 |
+
|
181 |
|
182 |
global_state = GlobalState()
|
183 |
global_video_path = None
|
|
|
732 |
value=0,
|
733 |
step=1)
|
734 |
with gr.Row():
|
735 |
+
control_type = gr.Dropdown(['HED', 'canny', 'depth'],
|
736 |
label='Control type',
|
737 |
value='HED')
|
738 |
low_threshold = gr.Slider(label='Canny low threshold',
|
|
|
772 |
interval = gr.Slider(
|
773 |
label='Key frame frequency (K)',
|
774 |
minimum=1,
|
775 |
+
maximum=MAX_KEYFRAME,
|
776 |
value=1,
|
777 |
step=1,
|
778 |
info='Uniformly sample the key frames every K frames')
|
779 |
keyframe_count = gr.Slider(
|
780 |
label='Number of key frames',
|
781 |
minimum=1,
|
782 |
+
maximum=MAX_KEYFRAME,
|
783 |
value=1,
|
784 |
step=1,
|
785 |
info='To avoid overload, maximum 8 key frames')
|