Spaces: Running on A10G

anonymous committed
Commit • 2896183
Parent(s): 0a4007d

update

Browse files:
- app.py +24 -16
- src/ddim_v_hacked.py +5 -3
app.py
CHANGED
@@ -303,6 +303,8 @@ def process1(*args):
     imgs = sorted(os.listdir(cfg.input_dir))
     imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
 
+    model.cond_stage_model.device = device
+
     with torch.no_grad():
         frame = cv2.imread(imgs[0])
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
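The added `model.cond_stage_model.device = device` keeps the text encoder's device attribute in sync with wherever the weights were actually loaded; the module-level `device` referenced here is defined elsewhere in app.py and not shown in this diff. In Stable Diffusion style codebases the cond-stage model is the frozen CLIP text encoder, and its forward pass moves the tokenized prompt onto whatever `self.device` names, so a stale `'cuda'` default crashes on CPU-only hardware. A minimal sketch of that failure mode and the fix, using a hypothetical `TextEncoderStub` rather than the Space's actual model class:

import torch

# Minimal sketch of the pattern behind `model.cond_stage_model.device = device`.
# `TextEncoderStub` is a hypothetical stand-in for the frozen CLIP text encoder
# (`cond_stage_model` in Stable Diffusion-style codebases), whose forward pass
# moves the tokenized prompt to whatever `self.device` currently says.
class TextEncoderStub(torch.nn.Module):

    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device  # hard-coded default; wrong on CPU-only machines
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, token_ids):
        # Tokens are placed on `self.device`, so the attribute must match
        # wherever the module's weights actually live.
        token_ids = token_ids.to(self.device)
        return self.proj(token_ids.float())


device = 'cuda' if torch.cuda.is_available() else 'cpu'

encoder = TextEncoderStub().to(device)
encoder.device = device  # the fix: keep the attribute in sync with the weights

out = encoder(torch.arange(8).unsqueeze(0))
print(out.shape)  # torch.Size([1, 8])

Presumably this is why the assignment happens in `process1` right before inference: it is the first point where both the model and the resolved device are in scope.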
@@ -607,6 +609,7 @@ def process2(*args):
 
     return key_video_path
 
+
 DESCRIPTION = '''
 ## Rerender A Video
 ### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
@@ -644,12 +647,13 @@ with block:
             run_button3 = gr.Button(value='Run Propagation')
         with gr.Accordion('Advanced options for the 1st frame translation',
                           open=False):
-            image_resolution = gr.Slider(
-
-
-
-
-
+            image_resolution = gr.Slider(
+                label='Frame rsolution',
+                minimum=256,
+                maximum=512,
+                value=512,
+                step=64,
+                info='To avoid overload, maximum 512')
             control_strength = gr.Slider(label='ControNet strength',
                                          minimum=0.0,
                                          maximum=2.0,
@@ -734,12 +738,13 @@ with block:
                 value=1,
                 step=1,
                 info='Uniformly sample the key frames every K frames')
-            keyframe_count = gr.Slider(
-
-
-
-
-
+            keyframe_count = gr.Slider(
+                label='Number of key frames',
+                minimum=1,
+                maximum=1,
+                value=1,
+                step=1,
+                info='To avoid overload, maximum 8 key frames')
 
             use_constraints = gr.CheckboxGroup(
                 [
@@ -769,8 +774,10 @@ with block:
                 maximum=100,
                 value=1,
                 step=1,
-                info=
-
+                info=
+                ('Update the key and value for '
+                 'cross-frame attention every N key frames (recommend N*K>=10)'
+                 ))
             with gr.Row():
                 warp_start = gr.Slider(label='Shape-aware fusion start',
                                        minimum=0,
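The slider changes above all use the same Gradio pattern: hard-limit the range with `minimum`/`maximum` so a shared A10G cannot be overloaded, and state the cap in the `info` tooltip. A self-contained sketch of that pattern follows; the demo layout is hypothetical, and it assumes a Gradio release recent enough to support the `info` argument, as the Space itself evidently does:

import gradio as gr

# Minimal sketch of the capped-slider pattern used above: the range itself
# enforces the resource limit, and `info` tells the user why it exists.
with gr.Blocks() as demo:
    resolution = gr.Slider(label='Frame resolution',
                           minimum=256,
                           maximum=512,
                           value=512,
                           step=64,
                           info='To avoid overload, maximum 512')
    # Echo the chosen value so the slider is wired to something runnable.
    chosen = gr.Number(label='Selected resolution')
    resolution.change(fn=lambda r: r, inputs=resolution, outputs=chosen)

if __name__ == '__main__':
    demo.launch()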
@@ -912,8 +919,9 @@ with block:
     run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])
 
     def process3():
-        raise gr.Error(
-
+        raise gr.Error(
+            "Coming Soon. Full code for full video translation will be "
+            "released upon the publication of the paper.")
 
     run_button3.click(fn=process3, outputs=[result_keyframe])
 
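`process3` stays wired to `run_button3` but now only raises `gr.Error`: in Gradio, raising `gr.Error` inside an event handler aborts the run and surfaces the message in the UI instead of returning a value. A minimal runnable sketch of the same placeholder-button pattern, with a hypothetical layout:

import gradio as gr

# Minimal sketch of gating an unreleased feature behind gr.Error: raising it
# inside a callback cancels the event and shows the message to the user.
def process3():
    raise gr.Error('Coming Soon. Full code for full video translation will be '
                   'released upon the publication of the paper.')

with gr.Blocks() as demo:
    run_button3 = gr.Button(value='Run Propagation')
    result = gr.Textbox(label='Result')
    run_button3.click(fn=process3, outputs=[result])

if __name__ == '__main__':
    demo.launch()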
src/ddim_v_hacked.py
CHANGED
@@ -14,6 +14,8 @@ from ControlNet.ldm.modules.diffusionmodules.util import (
 
 _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 
 def register_attention_control(model, controller=None):
 
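The added module-level `device` resolves once, at import time, to `'cuda'` only when a GPU is actually visible, so the rest of the file can reference one name instead of hard-coding CUDA. The same pattern in isolation:

import torch

# Module-level device selection, as in the diff above: resolve the device once
# at import time and reuse it everywhere instead of hard-coding 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Any tensor or module created afterwards can target the same name, so the
# code runs unchanged on CPU-only hosts.
x = torch.randn(4, 4, device=device)
print(x.device)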
@@ -36,7 +38,7 @@ def register_attention_control(model, controller=None):
 
         # force cast to fp32 to avoid overflowing
         if _ATTN_PRECISION == 'fp32':
-            with torch.autocast(enabled=False, device_type='cuda'):
+            with torch.autocast(enabled=False, device_type=device):
                 q, k = q.float(), k.float()
                 sim = torch.einsum('b i d, b j d -> b i j', q,
                                    k) * self.scale
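This hunk keeps ControlNet's overflow guard, which computes the query/key similarity in fp32 inside a region where autocast is explicitly disabled, but parameterizes `device_type` so entering the context no longer assumes CUDA exists. A standalone sketch of the same fp32-attention trick, with hypothetical shapes and scale:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
scale = 1.0 / 8.0  # hypothetical 1/sqrt(d) attention scale

q = torch.randn(2, 16, 64, device=device, dtype=torch.float32)
k = torch.randn(2, 16, 64, device=device, dtype=torch.float32)

# Mirror of the hunk above: disable autocast for this region and upcast q/k to
# fp32 so the similarity matrix is computed in full precision. The .float()
# calls are a no-op here since the inputs are already fp32, but they matter
# when the surrounding model runs under mixed precision.
with torch.autocast(enabled=False, device_type=device):
    q, k = q.float(), k.float()
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale

print(sim.shape)  # torch.Size([2, 16, 16])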
@@ -98,8 +100,8 @@ class DDIMVSampler(object):
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device('cuda'):
-                attr = attr.to(torch.device('cuda'))
+            if attr.device != torch.device(device):
+                attr = attr.to(torch.device(device))
         setattr(self, name, attr)
 
     def make_schedule(self,
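`register_buffer` here stores precomputed sampler tensors (noise schedules and the like) as attributes, and the change normalizes them onto the detected device instead of unconditionally onto CUDA. The idiom in isolation, with a hypothetical `SamplerStub` in place of `DDIMVSampler`:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Minimal sketch of the sampler's register_buffer idiom: schedule tensors are
# moved onto the active device before being stored as attributes.
class SamplerStub:

    def register_buffer(self, name, attr):
        if isinstance(attr, torch.Tensor):  # `type(attr) == torch.Tensor` upstream
            if attr.device != torch.device(device):
                attr = attr.to(torch.device(device))
        setattr(self, name, attr)


sampler = SamplerStub()
sampler.register_buffer('alphas_cumprod', torch.linspace(1.0, 0.01, 50))
print(sampler.alphas_cumprod.device)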
|