import numpy as np
import torch
import torch.nn.functional as F


def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension holds
            the two-channel horizontal (x) and vertical (y) relative offsets.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros', 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether to align corners. Default: True.

    Returns:
        Tensor: Warped image or feature map.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    # Create a mesh grid of absolute pixel coordinates.
    device = flow.device
    # indexing='ij' matches the legacy torch.meshgrid default and silences the
    # deprecation warning in torch >= 1.10.
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device),
                                    torch.arange(0, w, device=device),
                                    indexing='ij')
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)
    grid.requires_grad = False

    # Add the flow offsets; grid (h, w, 2) broadcasts against flow (n, h, w, 2).
    grid_flow = grid + flow
    # Scale grid_flow to [-1, 1] as required by F.grid_sample.
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
    output = F.grid_sample(x,
                           grid_flow,
                           mode=interpolation,
                           padding_mode=padding_mode,
                           align_corners=align_corners)
    return output
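

# A minimal usage sketch (not part of the original module): with zero flow and
# align_corners=True, flow_warp should reproduce its input up to float error.
# The tensor shapes below are assumptions taken from the docstring.
def _demo_flow_warp():
    x = torch.rand(2, 3, 64, 64)      # (n, c, h, w)
    flow = torch.zeros(2, 64, 64, 2)  # (n, h, w, 2), zero pixel offsets
    out = flow_warp(x, flow)
    assert torch.allclose(out, x, atol=1e-5)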


# Alternative warping implementation, kept commented out for reference;
# flow_warp above supersedes it.
# def image_warp(image, flow):
#     b, c, h, w = image.size()
#     device = image.device
#     # Normalize flow to [-1, 1] (from upper-left to lower-right).
#     flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),
#                       flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
#     # grid_sample expects the coordinate channels in the last dimension.
#     flow = flow.permute(0, 2, 3, 1)
#     x = np.linspace(-1, 1, w)
#     y = np.linspace(-1, 1, h)
#     X, Y = np.meshgrid(x, y)
#     grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3),
#                       torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device)
#     output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros')
#     return output


def length_sq(x):
    # Squared L2 norm over the channel dimension, keeping dims: (n, 1, h, w).
    return torch.sum(torch.square(x), dim=1, keepdim=True)


def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
    flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1))  # wb(wf(x))
    flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1))  # wf(wb(x))
    flow_diff_fw = flow_fw + flow_bw_warped  # wf + wb(wf(x))
    flow_diff_bw = flow_bw + flow_fw_warped  # wb + wf(wb(x))

    mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped)  # |wf|^2 + |wb(wf(x))|^2
    mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped)  # |wb|^2 + |wf(wb(x))|^2
    occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
    occ_thresh_bw = alpha1 * mag_sq_bw + alpha2

    fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float()
    fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float()

    # fb_occ_fw: regions of frame2 occluded with respect to frame1;
    # fb_occ_bw: regions of frame1 occluded with respect to frame2.
    return fb_occ_fw, fb_occ_bw
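

# Sketch (hypothetical, not in the original module): zero flows are trivially
# forward-backward consistent, so both occlusion masks come back all zeros.
def _demo_fb_consistency():
    flow_fw = torch.zeros(1, 2, 32, 32)  # (n, 2, h, w) forward flow
    flow_bw = torch.zeros(1, 2, 32, 32)  # (n, 2, h, w) backward flow
    occ_fw, occ_bw = fbConsistencyCheck(flow_fw, flow_bw)
    assert occ_fw.sum() == 0 and occ_bw.sum() == 0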


def rgb2gray(image):
    # ITU-R BT.601 luma weights: 0.299 R + 0.587 G + 0.114 B.
    gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + image[:, 2] * 0.114
    gray_image = gray_image.unsqueeze(1)
    return gray_image


def ternary_transform(image, max_distance=1):
    device = image.device
    patch_size = 2 * max_distance + 1
    intensities = rgb2gray(image) * 255
    out_channels = patch_size * patch_size
    # One-hot conv kernels extract each pixel of the local patch as a channel.
    w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size)
    weights = torch.from_numpy(w).float().to(device)
    patches = F.conv2d(intensities, weights, stride=1, padding=max_distance)
    transf = patches - intensities
    # Soft-sign normalization of the census/ternary descriptor.
    transf_norm = transf / torch.sqrt(0.81 + torch.square(transf))
    return transf_norm


def hamming_distance(t1, t2):
    # Soft (differentiable) Hamming distance between ternary descriptors.
    dist = torch.square(t1 - t2)
    dist_norm = dist / (0.1 + dist)
    dist_sum = torch.sum(dist_norm, dim=1, keepdim=True)
    return dist_sum
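

# Sketch (not in the original module): the ternary/census descriptors of an
# image compared with themselves should have zero soft Hamming distance.
def _demo_ternary_distance():
    img = torch.rand(1, 3, 32, 32)  # (n, 3, h, w) RGB in [0, 1]
    t = ternary_transform(img)      # (1, 9, 32, 32) for max_distance=1
    dist = hamming_distance(t, t)   # (1, 1, 32, 32)
    assert torch.all(dist == 0)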


def create_mask(mask, paddings):
    """Create a binary mask that zeroes out a padded border.

    Args:
        mask (Tensor): Reference tensor of shape (n, c, h, w); only its shape
            and device are used.
        paddings: [[top, bottom], [left, right]] border widths in pixels.
    """
    shape = mask.shape
    inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
    inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
    inner = torch.ones([inner_height, inner_width], device=mask.device)

    # F.pad takes [left, right, top, bottom] for the last two dimensions.
    mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]])
    mask3d = mask2d.unsqueeze(0)
    mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1)
    return mask4d.detach()
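

# Sketch (not in the original module): create_mask zeroes a border of the
# given widths; the reference tensor only supplies the shape and device.
def _demo_create_mask():
    ref = torch.zeros(2, 1, 8, 8)
    m = create_mask(ref, [[1, 1], [2, 2]])  # [[top, bottom], [left, right]]
    assert m.shape == (2, 1, 8, 8)
    assert m[0, 0, 1:-1, 2:-2].min() == 1 and m[0, 0, 0].max() == 0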


def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1):
    """Ternary (census) loss between a frame and its warped counterpart.

    Args:
        frame1: torch tensor, with shape [b * t, c, h, w].
        warp_frame21: torch tensor, with shape [b * t, c, h, w].
        confMask: confidence mask, with shape [b * t, c, h, w].
        masks: torch tensor, with shape [b * t, c, h, w].
        max_distance: maximum patch distance of the ternary transform.

    Returns:
        Ternary loss averaged over the masked region.
    """
    t1 = ternary_transform(frame1, max_distance)
    t21 = ternary_transform(warp_frame21, max_distance)
    dist = hamming_distance(t1, t21)
    loss = torch.mean(dist * confMask * masks) / torch.mean(masks)
    return loss
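

# Sketch (not in the original module): a perfectly warped frame is identical
# to the reference, so the ternary loss should be exactly zero. The all-ones
# confidence/validity masks here are purely illustrative.
def _demo_ternary_loss():
    frame1 = torch.rand(2, 3, 32, 32)
    conf = torch.ones(2, 1, 32, 32)
    masks = torch.ones(2, 1, 32, 32)
    loss = ternary_loss2(frame1, frame1.clone(), conf, masks)
    assert loss.item() == 0.0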