from typing import Optional, Tuple
from omegaconf import DictConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from tracker.model.transformer.positional_encoding import PositionalEncoding


# @torch.jit.script
def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
                      logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
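    """Pool per-pixel features into a fixed set of summary vectors.

    Weights are the sigmoid of `logits`, gated to zero wherever `masks`
    disallows a pixel; each summary is the weighted sum of `value` over the
    spatial dimensions, returned together with its total weight (`area`).
    """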
    # value: B*num_objects*H*W*value_dim
    # logits: B*num_objects*H*W*num_summaries
    # masks: B*num_objects*H*W*num_summaries: 1 if allowed
    weights = logits.sigmoid() * masks
    # B*num_objects*num_summaries*value_dim
    sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
    # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
    area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)

    # B*num_objects*num_summaries*value_dim
    return sums, area


class ObjectSummarizer(nn.Module):
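    """Compresses each object's pixel features into `num_summaries` weighted
    averages, with half of the summaries pooled inside the object mask and
    half over its complement."""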
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_summarizer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_summaries = this_cfg.num_summaries
        self.add_pe = this_cfg.add_pe
        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature

        if self.add_pe:
            self.pos_enc = PositionalEncoding(self.embed_dim,
                                              scale=self.pixel_pe_scale,
                                              temperature=self.pixel_pe_temperature)

        self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
        self.feature_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.weights_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.num_summaries),
        )

    def forward(self,
                masks: torch.Tensor,
                value: torch.Tensor,
                need_weights: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
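        """Summarize pixel features into per-object summary vectors.

        Returns `(summaries, logits)`, where `summaries` has shape
        B*num_objects*num_summaries*(embed_dim+1) (pooled features plus the
        pooled weight) and `logits` is None unless `need_weights` is set.
        """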
        # masks: B*num_objects*(H0)*(W0)
        # value: B*num_objects*value_dim*H*W
        # -> B*num_objects*H*W*value_dim
        h, w = value.shape[-2:]
        masks = F.interpolate(masks, size=(h, w), mode='area')
        masks = masks.unsqueeze(-1)
        inv_masks = 1 - masks
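        # Half of the summaries are pooled inside the object mask and half
        # over its complement, so num_summaries is assumed to be even.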
        repeated_masks = torch.cat([
            masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
            inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
        ], dim=-1)

        value = value.permute(0, 1, 3, 4, 2)
        value = self.input_proj(value)
        if self.add_pe:
            pe = self.pos_enc(value)
            value = value + pe

        with torch.cuda.amp.autocast(enabled=False):
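            # Run the prediction heads and pooling in fp32 even under
            # autocast; presumably the sigmoid-weighted sums are
            # precision-sensitive in fp16.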
            value = value.float()
            feature = self.feature_pred(value)
            logits = self.weights_pred(value)
            sums, area = _weighted_pooling(repeated_masks, feature, logits)

        summaries = torch.cat([sums, area], dim=-1)

        if need_weights:
            return summaries, logits
        else:
            return summaries, None
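

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module: the config
    # values below are illustrative assumptions, not the project's defaults,
    # and running this file still requires the `tracker` package for the
    # PositionalEncoding import at the top.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'value_dim': 256,
        'pixel_pe_scale': 32,
        'pixel_pe_temperature': 128,
        'object_summarizer': {
            'embed_dim': 256,
            'num_summaries': 16,
            'add_pe': False,  # skip the positional encoding in this sketch
        },
    })
    summarizer = ObjectSummarizer(cfg)

    B, num_objects, h, w = 2, 3, 24, 24
    masks = torch.rand(B, num_objects, h * 16, w * 16)  # input-resolution masks
    value = torch.rand(B, num_objects, cfg.value_dim, h, w)  # stride-16 features
    summaries, logits = summarizer(masks, value, need_weights=True)
    print(summaries.shape)  # (2, 3, 16, 257): embed_dim + 1 pooled-area channel
    print(logits.shape)     # (2, 3, 24, 24, 16)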