|
import torch |
|
import torchvision.transforms as tvf |
|
|
|
from ..utils.base_model import BaseModel |
|
|
|
|
|
class OpenIBL(BaseModel): |
|
default_conf = { |
|
"model_name": "vgg16_netvlad", |
|
} |
|
required_inputs = ["image"] |
|
|
|
def _init(self, conf): |
|
self.net = torch.hub.load( |
|
"yxgeee/OpenIBL", conf["model_name"], pretrained=True |
|
).eval() |
|
mean = [0.48501960784313836, 0.4579568627450961, 0.4076039215686255] |
|
std = [0.00392156862745098, 0.00392156862745098, 0.00392156862745098] |
|
self.norm_rgb = tvf.Normalize(mean=mean, std=std) |
|
|
|
def _forward(self, data): |
|
image = self.norm_rgb(data["image"]) |
|
desc = self.net(image) |
|
return { |
|
"global_descriptor": desc, |
|
} |
|
|