# ProPainter / scripts/evaluate_flow_completion.py
# Author: sczhou
# Commit: 320e465 ("init code")
# Original file size: 8.2 kB
# -*- coding: utf-8 -*-
import sys
sys.path.append(".")
import cv2
import os
import numpy as np
import argparse
from PIL import Image
import torch
from torch.utils.data import DataLoader
from core.dataset import TestDataset
from model.modules.flow_comp_raft import RAFT_bi
from model.recurrent_flow_completion import RecurrentFlowCompleteNet
from RAFT.utils.flow_viz_pt import flow_to_image
import cvbase
import imageio
from time import time
import warnings
warnings.filterwarnings("ignore")
def create_dir(dir):
    """Create directory ``dir`` (including missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent and avoids the race between
    an existence check and ``os.makedirs`` when another process (e.g. a
    parallel evaluation run) creates the same directory in between.

    Args:
        dir: path of the directory to create.
    """
    os.makedirs(dir, exist_ok=True)
def save_flows(output, videoFlowF, videoFlowB):
    """Write forward/backward flow visualisations as numbered PNG files.

    Frames are read from the last axis of ``videoFlowF`` (its length sets the
    frame count) and the same index of ``videoFlowB``. Each one is colour-coded
    with ``cvbase.flow2rgb``, scaled by 255 and cast to ``uint8``, and then
    saved as ``<output>/forward_png/NNNNN.png`` and
    ``<output>/backward_png/NNNNN.png``.

    Args:
        output: directory that receives the two PNG sub-directories.
        videoFlowF: forward flow volume with time on the last axis.
        videoFlowB: backward flow volume with time on the last axis.
    """
    fwd_dir = os.path.join(output, 'forward_png')
    bwd_dir = os.path.join(output, 'backward_png')
    for sub_dir in (fwd_dir, bwd_dir):
        create_dir(sub_dir)

    for idx in range(videoFlowF.shape[-1]):
        # Colour-code both directions before writing either file.
        rgb_fwd, rgb_bwd = [
            (cvbase.flow2rgb(volume[..., idx]) * 255.0).astype(np.uint8)
            for volume in (videoFlowF, videoFlowB)
        ]
        file_name = '{:05d}.png'.format(idx)
        imageio.imwrite(os.path.join(fwd_dir, file_name), rgb_fwd)
        imageio.imwrite(os.path.join(bwd_dir, file_name), rgb_bwd)
def tensor2np(array):
    """Stack a sequence of per-frame tensors into one NumPy volume, time last.

    Args:
        array: sequence of equally-shaped tensors. At the call site in this
            file each item appears to be a ``(1, 2, H, W)`` flow -- TODO confirm.

    Returns:
        ``numpy.ndarray`` built by stacking along a new last axis, dropping a
        leading size-1 axis and moving axis 0 behind the two following axes.
        For ``(1, 2, H, W)`` inputs this gives ``(H, W, 2, T)``.
    """
    volume = torch.stack(array, dim=-1)
    volume = volume.squeeze(0)
    volume = volume.permute(1, 2, 0, 3)
    volume = volume.cpu()
    return volume.numpy()
def _synchronize(device):
    """Wait for pending CUDA work on ``device``; do nothing on CPU.

    ``torch.cuda.synchronize()`` raises when CUDA is unavailable. The
    evaluation falls back to the CPU in that case, so the timing code must
    only synchronize when it actually runs on a GPU.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _estimate_gt_flows(fix_raft, frames, raft_iter, device, chunk_len=60):
    """Estimate reference bidirectional flows for one clip with RAFT.

    Clips longer than ``chunk_len`` frames are split into chunks to bound
    memory. Every chunk after the first starts one frame early, so the flow
    across each chunk boundary is still computed and the concatenated result
    has one flow per adjacent frame pair.

    Args:
        fix_raft: the ``RAFT_bi`` wrapper. It is called as
            ``fix_raft(clip, iters=...)`` and returns a ``(forward, backward)``
            pair.
        frames: clip tensor with time on dim 1.
        raft_iter: number of RAFT refinement iterations.
        device: device the RAFT model runs on; clips are moved there first.
        chunk_len: maximum number of new frames per chunk.

    Returns:
        Tuple ``(forward_flows, backward_flows)`` concatenated along dim 1.
    """
    video_length = frames.size(1)
    if video_length <= chunk_len:
        return fix_raft(frames.to(device), iters=raft_iter)

    chunks_f, chunks_b = [], []
    for start in range(0, video_length, chunk_len):
        end = min(video_length, start + chunk_len)
        # Overlap one frame with the previous chunk to keep the boundary flow.
        first = start - 1 if start > 0 else 0
        flow_f, flow_b = fix_raft(frames[:, first:end].to(device), iters=raft_iter)
        chunks_f.append(flow_f)
        chunks_b.append(flow_b)
    return torch.cat(chunks_f, dim=1), torch.cat(chunks_b, dim=1)


def main_worker(args):
    """Evaluate recurrent flow completion (EPE and runtime) on a test set.

    For every clip yielded by ``TestDataset``, this function:

    1. Obtains reference bidirectional flows. They are loaded from the
       dataset when ``args.load_flow`` is set, otherwise estimated with RAFT
       using ``args.raft_iter`` iterations.
    2. Completes the masked regions with ``RecurrentFlowCompleteNet``.
    3. Reports the end-point error (EPE) between the completed and reference
       flows, and the running average per-frame time of the completion step.

    Per-clip lines and a final summary are printed and written to
    ``results_flow/<dataset>/<dataset>_metrics.txt``. When
    ``args.save_results`` is set, the completed flows are also saved as PNGs.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    # set up datasets and data loader
    args.size = (args.width, args.height)
    test_dataset = TestDataset(vars(args))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up models (frozen, inference only)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    fix_raft = RAFT_bi(args.raft_model_path, device)
    fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    total_frame_epe = []
    time_all = []
    avg_time = 0.0

    print('Start evaluation...')
    # create results directory
    result_path = os.path.join('results_flow', f'{args.dataset}')
    os.makedirs(result_path, exist_ok=True)

    metrics_file = os.path.join(result_path, f"{args.dataset}_metrics.txt")
    with open(metrics_file, "w") as eval_summary, torch.no_grad():
        for index, items in enumerate(test_loader):
            frames, masks, flows_f, flows_b, video_name, frames_PIL = items
            local_masks = masks.float().to(device)
            video_length = frames.size(1)

            if args.load_flow:
                gt_flows_bi = (flows_f.to(device), flows_b.to(device))
            else:
                gt_flows_bi = _estimate_gt_flows(fix_raft, frames, args.raft_iter, device)

            # Time only the flow-completion step, amortised per frame.
            _synchronize(device)
            time_start = time()
            pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
            pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)
            _synchronize(device)
            time_i = (time() - time_start) / video_length
            time_all.extend([time_i] * video_length)

            # Compare against the reference flows that were actually fed to
            # the model. The loader's flows_f/flows_b are only the reference
            # when args.load_flow is set. dim=2 is assumed to be the (u, v)
            # channel of (B, T, 2, H, W) flows -- TODO confirm.
            gt_f = gt_flows_bi[0].cpu()
            gt_b = gt_flows_bi[1].cpu()
            epe1 = torch.mean(torch.sum((gt_f - pred_flows_bi[0].cpu()) ** 2, dim=2).sqrt()).item()
            epe2 = torch.mean(torch.sum((gt_b - pred_flows_bi[1].cpu()) ** 2, dim=2).sqrt()).item()
            num_flows = gt_f.size(1)
            # Weight each clip by its number of flow fields in the global mean.
            total_frame_epe.extend([epe1] * num_flows)
            total_frame_epe.extend([epe2] * num_flows)

            cur_epe = (epe1 + epe2) / 2
            avg_time = sum(time_all) / len(time_all)
            line = f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}'
            print(line)
            eval_summary.write(line + '\n')

            # saving images for evaluating warping errors
            if args.save_results:
                forward_flows = pred_flows_bi[0].cpu().permute(1, 0, 2, 3, 4)
                backward_flows = pred_flows_bi[1].cpu().permute(1, 0, 2, 3, 4)
                videoFlowF = tensor2np(list(forward_flows))
                videoFlowB = tensor2np(list(backward_flows))
                save_frame_path = os.path.join(result_path, video_name[0])
                save_flows(save_frame_path, videoFlowF, videoFlowB)

        # An empty dataset would otherwise divide by zero here.
        if total_frame_epe:
            avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe)
        else:
            avg_frame_epe = float('nan')
        summary = f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}'
        print(summary)
        eval_summary.write(summary + '\n')
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    With ``type=bool``, argparse turns every non-empty string (including
    ``"False"``) into ``True``, so ``--load_flow False`` used to enable flow
    loading. This converter makes the flag behave as written.

    Args:
        value: the raw argument string (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str)
    parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str)
    parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str)
    parser.add_argument('--video_root', default='dataset_root', type=str)
    parser.add_argument('--mask_root', default='mask_root', type=str)
    parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str)
    # Accepts "--load_flow", "--load_flow True" and "--load_flow False".
    parser.add_argument('--load_flow', default=False, type=str2bool, nargs='?', const=True)
    parser.add_argument("--raft_iter", type=int, default=20)
    parser.add_argument('--save_results', action='store_true')
    parser.add_argument('--num_workers', default=4, type=int)
    args = parser.parse_args()
    main_worker(args)