Spaces:
Runtime error
Runtime error
File size: 6,330 Bytes
fb9d4c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from densepose.structures.mesh import create_mesh
from .embed_utils import PackedCseAnnotations
from .utils import BilinearInterpolationHelper
class SoftEmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, loss is computed as cross-entropy for
    unnormalized scores given ground truth mesh vertex ids.
    Scores are based on:
     1) squared distances between estimated vertex embeddings
        and mesh vertex embeddings;
     2) geodesic distances between vertices of a mesh
    """

    def __init__(self, cfg: CfgNode) -> None:
        """
        Initialize embedding loss from config

        Args:
            cfg (CfgNode): config; the following entries are read:
                * MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA - divisor
                  applied to squared embedding distances before log-softmax
                * MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA - divisor
                  applied to geodesic distances before softmax
        """
        self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
        self.geodist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as cross-entropy for unnormalized scores given
        ground truth vertex IDs. Scores are based on:
         1) squared distances between estimated vertex embeddings
            and mesh vertex embeddings;
         2) geodesic distances between vertices of a mesh

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch (not read by this method)
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): losses keyed by mesh name; every name in
                `embedder.mesh_names` gets an entry, with a zero-valued
                placeholder (see `fake_value`) for meshes that have no valid
                annotated points in this batch
        """
        losses = {}
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)
            # valid points are those that fall into estimated bbox
            # and correspond to the current mesh
            j_valid = interpolator.j_valid * (  # pyre-ignore[16]
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue
            # extract estimated embeddings for valid points
            # -> tensor [J, D]
            vertex_embeddings_i = normalize_embeddings(
                interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],  # pyre-ignore[16]
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],  # pyre-ignore[16]
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],  # pyre-ignore[16]
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],  # pyre-ignore[16]
                )[j_valid, :]
            )
            # extract vertex ids for valid points
            # -> tensor [J]
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
            # embeddings for all mesh vertices
            # -> tensor [K, D]
            mesh_vertex_embeddings = embedder(mesh_name)
            # softmax values of geodesic distances for GT mesh vertices
            # (these act as the soft target distribution over mesh vertices)
            # -> tensor [J, K]
            mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device)
            geodist_softmax_values = F.softmax(
                mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma), dim=1
            )
            # logsoftmax values for valid points
            # (log-probabilities derived from negated squared embedding distances)
            # -> tensor [J, K]
            embdist_logsoftmax_values = F.log_softmax(
                squared_euclidean_distance_matrix(vertex_embeddings_i, mesh_vertex_embeddings)
                / (-self.embdist_gauss_sigma),
                dim=1,
            )
            # cross-entropy between the two distributions: summed over mesh
            # vertices, averaged over valid points
            losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean()

        # meshes without a computed loss still get an entry, so the set of
        # returned keys does not depend on the batch contents
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Produces zero-valued placeholder losses for all meshes of the embedder.

        Args:
            densepose_predictor_outputs: predictor outputs; must have the
                `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings; its
                `mesh_names` attribute defines the keys of the result
        Return:
            dict(str -> tensor): a placeholder value (see `fake_value`) for
                each mesh name in `embedder.mesh_names`
        """
        return {
            mesh_name: self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
            for mesh_name in embedder.mesh_names
        }

    def fake_value(
        self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str
    ) -> torch.Tensor:
        """
        Produces a zero-valued placeholder loss for a single mesh.

        The value is computed from both the predicted embeddings and the mesh
        vertex embeddings (each summed and multiplied by 0) rather than being
        a constant.
        NOTE(review): presumably this keeps both the predictor and the embedder
        in the autograd graph when a mesh has no annotations in the batch -
        confirm against the training setup.

        Args:
            densepose_predictor_outputs: predictor outputs; must have the
                `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings
            mesh_name (str): name of the mesh to compute embeddings for
        Return:
            tensor: sum of the two zero-scaled sums
        """
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
|