|
import torch |
|
from torch import nn |
|
|
|
from transformers import ResNetPreTrainedModel |
|
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention |
|
from transformers.image_processing_utils import BaseImageProcessor |
|
from transformers import ResNetConfig, ResNetModel |
|
from typing import Optional |
|
|
|
class ResNetForZeroBitWatermarkDetection(ResNetPreTrainedModel):
    """ResNet backbone with a single-logit head for zero-bit watermark detection.

    The pooled ResNet features go through a small MLP that produces one score
    per image. That score is mapped through the CDF of a zero-mean, unit-scale
    generalized normal distribution whose shape parameter is the ``exp``
    buffer (``exp == 1`` gives the Laplace CDF). The resulting probability is
    returned in logit form.

    NOTE(review): presumably the probability is "image carries the
    watermark", judging by the class name — confirm against the training code.
    """

    def __init__(self, config):
        """Build the backbone and the detection head.

        Args:
            config: ResNet configuration. ``config.hidden_sizes[-1]`` sets the
                width of the pooled features fed to the head.
        """
        super().__init__(config)

        self.resnet = ResNetModel(config)

        # Pooled backbone features -> one scalar score per image.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.hidden_sizes[-1], 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

        # Shape parameter of the generalized normal CDF applied in ``forward``.
        # It is a buffer rather than a parameter, so it is saved with the model
        # but not trained.
        self.register_buffer('exp', torch.tensor([1.0]))

        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        """Run the backbone and head and return one detection logit per image.

        Args:
            pixel_values: Image batch passed straight to ``self.resnet``.
            output_hidden_states: Forwarded to the backbone. When set, its
                hidden states are included in the output.
            return_dict: Return an ``ImageClassifierOutputWithNoAttention``
                instead of a tuple. Defaults to ``config.use_return_dict``.

        Returns:
            ``ImageClassifierOutputWithNoAttention`` with ``loss=None`` and
            ``logits`` of shape ``(batch, 1)``. If ``return_dict`` is false,
            returns the tuple ``(logits, *backbone_extras)`` instead. The
            logits are ``logit(p)``, where ``p`` is the generalized normal CDF
            of the head's score.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        score = self.classifier(pooled_output)

        # Target: p = 0.5 + sign(s) * 0.5 * P(1/exp, |s|**exp), where P is the
        # regularized lower incomplete gamma (the generalized normal CDF).
        #
        # Building p explicitly and then taking log(p) - log1p(-p) rounds p to
        # exactly 0 or 1 once |s| is moderately large (roughly |s| > 17 in
        # float32 for exp == 1). That yields +/-inf logits and NaN gradients.
        #
        # Work with the upper tail Q = 1 - P instead:
        #     min(p, 1 - p) = Q / 2,   max(p, 1 - p) = 1 - Q / 2
        # so that
        #     logit(p) = sign(s) * (log1p(-Q / 2) - log(Q / 2)),
        # which is exact and keeps full precision far into the tails.
        tail = torch.special.gammaincc(1 / self.exp, torch.abs(score) ** self.exp)
        # Keep Q / 2 a strictly positive *normal* float so log() stays finite
        # even after Q underflows.
        tail = tail.clamp_min(2 * torch.finfo(tail.dtype).tiny)
        half_tail = 0.5 * tail
        logits = torch.sign(score) * (torch.log1p(-half_tail) - torch.log(half_tail))

        # forward() accepts no labels, so no loss is computed.
        loss = None

        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
|
|