watchtowerss committed
Commit • 508b599
1 Parent(s): 5fd7b77

Reduce RAM and VRAM usage
Files changed:
- app.py +72 -55
- inpainter/base_inpainter.py +122 -17
- test_sample/test-sample13.mp4 +2 -2
- test_sample/test-sample4.mp4 +0 -0
- test_sample/test-sample8.mp4 +2 -2
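Taken together, the changes below cut peak memory in two ways: app.py now defaults the resize ratio to 0.6 instead of 1, and base_inpainter.py inpaints the video in fixed-size splits (45 frames plus a few strided temporal-context frames) rather than pushing every frame to the GPU at once. A rough back-of-the-envelope sketch of the effect (illustrative only; the constants come from the diff below, the 300-frame clip is a made-up example, and real usage also depends on E2FGVI's activations):

# Illustrative arithmetic only; interval/context_range/resize_ratio are taken
# from the diff below, the 300-frame video length is hypothetical.
interval = 45          # frames per split (BaseInpainter.inpaint)
context_range = 10     # strided temporal-context frames per side
resize_ratio = 0.6     # new default in app.py (was 1)

print(f"resized frames keep {resize_ratio ** 2:.0%} of the original pixels")  # 36%

video_length = 300     # hypothetical clip
frames_per_split = interval + 2 * context_range
print(f"at most ~{frames_per_split} frames on the GPU instead of {video_length}")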
app.py CHANGED
@@ -341,7 +341,6 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
 operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
 inpainted_frames = video_state["origin_images"]
 video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
-
 return video_output, operation_log


@@ -423,7 +422,7 @@ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
 xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
 e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
 # args.port = 12213
-# args.device = "cuda:
+# args.device = "cuda:8"
 # args.mask_save = True

 # initialize sam, xmem, e2fgvi models
@@ -432,7 +431,7 @@ model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args

 title = """<p><h1 align="center">Track-Anything</h1></p>
 """
-description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting.
+description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a> If you stuck in unknown errors, please feel free to watch the Tutorial video.</p>"""


 with gr.Blocks() as iface:
@@ -450,7 +449,7 @@ with gr.Blocks() as iface:
 "masks": []
 },
 "track_end_number": None,
-"resize_ratio":
+"resize_ratio": 0.6
 }
 )

@@ -470,48 +469,78 @@ with gr.Blocks() as iface:
 gr.Markdown(title)
 gr.Markdown(description)
 with gr.Row():
-
-# for user video input
 with gr.Column():
-with gr.
-with gr.Column():
-video_info = gr.Textbox(label="Video Info")
-resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
-Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
-resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
-
-
-with gr.Row():
-# put the template frame under the radio button
+with gr.Tab("Test"):
+# for user video input
 with gr.Column():
+with gr.Row(scale=0.4):
+video_input = gr.Video(autosize=True)
+with gr.Column():
+video_info = gr.Textbox(label="Video Info")
+resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
+Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
+resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=0.6, label="Resize ratio", visible=True)
+

-# click points settins, negative or positive, mode continuous or single
 with gr.Row():
+# put the template frame under the radio button
+with gr.Column():
+# extract frames
+with gr.Column():
+extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
+
+# click points settins, negative or positive, mode continuous or single
+with gr.Row():
+with gr.Row():
+point_prompt = gr.Radio(
+choices=["Positive", "Negative"],
+value="Positive",
+label="Point Prompt",
+interactive=True,
+visible=False)
+remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
+clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
+Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
+template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
+image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
+track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
+
+with gr.Column():
+run_status = gr.HighlightedText(value=[("Run","Error"),("Status","Normal")], visible=True)
+mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
+video_output = gr.Video(autosize=True, visible=False).style(height=360)
+with gr.Row():
+tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
+inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
+# set example
+gr.Markdown("## Examples")
+gr.Examples(
+examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
+"test-sample2.mp4","test-sample13.mp4"]],
+fn=run_example,
+inputs=[
+video_input
+],
+outputs=[video_input],
+# cache_examples=True,
+)
+
+with gr.Tab("Tutorial"):
 with gr.Column():
+with gr.Row(scale=0.4):
+video_demo_operation = gr.Video(autosize=True)
+
+# set example
+gr.Markdown("## Operation tutorial video")
+gr.Examples(
+examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["huggingface_demo_operation.mp4"]],
+fn=run_example,
+inputs=[
+video_demo_operation
+],
+outputs=[video_demo_operation],
+# cache_examples=True,
+)

 # first step: get the video information
 extract_frames_button.click(
@@ -601,7 +630,7 @@ with gr.Blocks() as iface:
 "masks": []
 },
 "track_end_number": 0,
-"resize_ratio":
+"resize_ratio": 0.6
 },
 [[],[]],
 None,
@@ -609,7 +638,7 @@ with gr.Blocks() as iface:
 gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
 gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
 gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
-gr.update(visible=False), gr.update(visible=
+gr.update(visible=False), gr.update(visible=True)

 ),
 [],
@@ -631,18 +660,6 @@ with gr.Blocks() as iface:
 inputs = [video_state, click_state,],
 outputs = [template_frame,click_state, run_status],
 )
-# set example
-gr.Markdown("## Examples")
-gr.Examples(
-examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
-"test-sample2.mp4","test-sample13.mp4"]],
-fn=run_example,
-inputs=[
-video_input
-],
-outputs=[video_input],
-# cache_examples=True,
-)
 iface.queue(concurrency_count=1)
 # iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
 iface.launch(debug=True, enable_queue=True)
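Both gr.Examples blocks above are wired through run_example with the same component used as input and output. run_example itself is not part of this diff, so the following stand-in is only an assumption about its behavior, consistent with that wiring:

def run_example(example):
    # Hypothetical stand-in (not from the commit): echo the selected example
    # video back to the same gr.Video component that gr.Examples reads it from.
    return example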
inpainter/base_inpainter.py CHANGED
@@ -9,9 +9,9 @@ import numpy as np
 from tqdm import tqdm
 from inpainter.util.tensor_util import resize_frames, resize_masks

-def
+def read_image_from_split(videp_split_path):
 # if type:
-image =
+image = np.asarray([np.asarray(Image.open(path)) for path in videp_split_path])
 # else:
 # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
 return image
@@ -56,29 +56,30 @@ class BaseInpainter:
 break
 ref_index.append(i)
 return ref_index
-
-def
+
+def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, ratio=1):
 """
+Perform Inpainting for video subsets
 frames: numpy array, T, H, W, 3
 masks: numpy array, T, H, W
+num_tcb: constant, number of temporal context before, frames
+num_tca: constant, number of temporal context after, frames
 dilate_radius: radius when applying dilation on masks
 ratio: down-sample ratio

 Output:
 inpainted_frames: numpy array, T, H, W, 3
 """
-frames = []
-for file in frames_path:
-frames.append(read_image_from_userfolder(file))
-frames = np.asarray(frames)
-
 assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
 assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
+
+# --------------------
+# pre-processing
+# --------------------
 masks = masks.copy()
 masks = np.clip(masks, 0, 1)
 kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
 masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
-
 T, H, W = masks.shape
 masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
 # size: (w, h)
@@ -96,14 +97,37 @@ class BaseInpainter:
 frames = resize_frames(frames, tuple(size)) # T, H, W, 3
 # frames and binary_masks are numpy arrays
 h, w = frames.shape[1:3]
-video_length = T
-
+video_length = T - (num_tca + num_tcb) # real video length
 # convert to tensor
 imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
 masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
-
 imgs, masks = imgs.to(self.device), masks.to(self.device)
 comp_frames = [None] * video_length
+tcb_imgs = None
+tca_imgs = None
+tcb_masks = None
+tca_masks = None
+# --------------------
+# end of pre-processing
+# --------------------
+
+# separate tc frames/masks from imgs and masks
+if num_tcb > 0:
+tcb_imgs = imgs[:, :num_tcb]
+tcb_masks = masks[:, :num_tcb]
+tcb_binary = binary_masks[:num_tcb]
+if num_tca > 0:
+tca_imgs = imgs[:, -num_tca:]
+tca_masks = masks[:, -num_tca:]
+tca_binary = binary_masks[-num_tca:]
+end_idx = -num_tca
+else:
+end_idx = T
+
+imgs = imgs[:, num_tcb:end_idx]
+masks = masks[:, num_tcb:end_idx]
+binary_masks = binary_masks[num_tcb:end_idx] # only neighbor area are involved
+frames = frames[num_tcb:end_idx] # only neighbor area are involved

 for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
 neighbor_ids = [
@@ -111,8 +135,24 @@ class BaseInpainter:
 min(video_length, f + self.neighbor_stride + 1))
 ]
 ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
+
+# selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
+# selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
+
+selected_imgs = imgs[:, neighbor_ids]
+selected_masks = masks[:, neighbor_ids]
+# pad before
+if tcb_imgs is not None:
+selected_imgs = torch.concat([selected_imgs, tcb_imgs], dim=1)
+selected_masks = torch.concat([selected_masks, tcb_masks], dim=1)
+# integrate ref frames
+selected_imgs = torch.concat([selected_imgs, imgs[:, ref_ids]], dim=1)
+selected_masks = torch.concat([selected_masks, masks[:, ref_ids]], dim=1)
+# pad after
+if tca_imgs is not None:
+selected_imgs = torch.concat([selected_imgs, tca_imgs], dim=1)
+selected_masks = torch.concat([selected_masks, tca_masks], dim=1)
+
 with torch.no_grad():
 masked_imgs = selected_imgs * (1 - selected_masks)
 mod_size_h = 60
@@ -138,10 +178,75 @@ class BaseInpainter:
 else:
 comp_frames[idx] = comp_frames[idx].astype(
 np.float32) * 0.5 + img.astype(np.float32) * 0.5
-
+torch.cuda.empty_cache()
 inpainted_frames = np.stack(comp_frames, 0)
 return inpainted_frames.astype(np.uint8)

+def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
+"""
+Perform Inpainting for video subsets
+frames: numpy array, T, H, W, 3
+masks: numpy array, T, H, W
+dilate_radius: radius when applying dilation on masks
+ratio: down-sample ratio
+
+Output:
+inpainted_frames: numpy array, T, H, W, 3
+"""
+# assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
+assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
+
+# set interval
+interval = 45
+context_range = 10 # for each split, consider its temporal context [-context_range] frames and [context_range] frames
+# split frames into subsets
+video_length = len(frames_path)
+num_splits = video_length // interval
+id_splits = [[i*interval, (i+1)*interval] for i in range(num_splits)] # id splits
+# if remaining split > interval/2, add a new split, else, append to the last split
+if video_length - id_splits[-1][-1] > interval / 2:
+id_splits.append([num_splits*interval, video_length])
+else:
+id_splits[-1][-1] = video_length
+
+# perform inpainting for each split
+inpainted_splits = []
+for id_split in id_splits:
+video_split_path = frames_path[id_split[0]:id_split[1]]
+video_split = read_image_from_split(video_split_path)
+mask_split = masks[id_split[0]:id_split[1]]
+
+# | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after |
+# add temporal context
+id_before = max(0, id_split[0] - self.step * context_range)
+try:
+tcb_frames = np.stack([np.array(Image.open(frames_path[idb])) for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
+tcb_masks = np.stack([masks[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
+num_tcb = len(tcb_frames)
+except:
+num_tcb = 0
+id_after = min(video_length, id_split[1] + self.step * context_range)
+try:
+tca_frames = np.stack([np.array(Image.open(frames_path[ida])) for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
+tca_masks = np.stack([masks[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
+num_tca = len(tca_frames)
+except:
+num_tca = 0
+
+# concatenate temporal context frames/masks with input frames/masks (for parallel pre-processing)
+if num_tcb > 0:
+video_split = np.concatenate([tcb_frames, video_split], 0)
+mask_split = np.concatenate([tcb_masks, mask_split], 0)
+if num_tca > 0:
+video_split = np.concatenate([video_split, tca_frames], 0)
+mask_split = np.concatenate([mask_split, tca_masks], 0)
+
+# inpaint each split
+inpainted_splits.append(self.inpaint_efficient(video_split, mask_split, num_tcb, num_tca, dilate_radius, ratio))
+
+inpainted_frames = np.concatenate(inpainted_splits, 0)
+return inpainted_frames.astype(np.uint8)
+
 if __name__ == '__main__':

 frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
@@ -179,4 +284,4 @@ if __name__ == '__main__':
 # save
 for ti, inpainted_frame in enumerate(inpainted_frames):
 frame = Image.fromarray(inpainted_frame).convert('RGB')
-frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
+frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
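The memory saving comes from the new inpaint entry point above: it splits frames_path into roughly 45-frame chunks, attaches up to context_range strided context frames on each side, inpaints each chunk with inpaint_efficient, and calls torch.cuda.empty_cache() after every chunk. A small sketch of the split rule, with the logic copied from the diff (the 100-frame video length is just an example value):

interval = 45
video_length = 100  # example value, not from the commit

num_splits = video_length // interval
id_splits = [[i * interval, (i + 1) * interval] for i in range(num_splits)]
# a remainder longer than interval/2 becomes its own split,
# otherwise it is folded into the last split
if video_length - id_splits[-1][-1] > interval / 2:
    id_splits.append([num_splits * interval, video_length])
else:
    id_splits[-1][-1] = video_length

print(id_splits)  # [[0, 45], [45, 100]]

Note that this rule appears to assume at least interval frames: for a shorter clip, id_splits would be empty when the remainder check indexes id_splits[-1].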
test_sample/test-sample13.mp4 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:54f24a0aaae482aff7ff3555256f60ad1931d478dc5694fb37624cac85479eee
+size 2528426
test_sample/test-sample4.mp4 CHANGED
Binary files a/test_sample/test-sample4.mp4 and b/test_sample/test-sample4.mp4 differ
test_sample/test-sample8.mp4 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:2414d24cc1ddfe1619c17e9876a7c3ed0f1f37da234c63c08af2cecbbb16c1ed
+size 8714250