from typing import Union

import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Pipeline,
)

class SiglipTaggerPipe(Pipeline):
    """Tag images with the p1atdev/siglip-tagger-test-3 SigLIP tagger."""

    def __init__(self, **kwargs):
        # Default to bfloat16 weights unless the caller requests another dtype.
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = torch.bfloat16
        super().__init__(**kwargs)
        # Load the image processor after the base Pipeline is initialized so it is
        # not overwritten by the base class's own attribute setup.
        self.processor = AutoImageProcessor.from_pretrained("p1atdev/siglip-tagger-test-3")

    def _sanitize_parameters(self, **kwargs):
        # Route user-supplied options to postprocess(); preprocess and forward
        # take no extra arguments.
        postprocess_kwargs = {}
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        if "return_scores" in kwargs:
            postprocess_kwargs["return_scores"] = kwargs["return_scores"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Image.Image, np.ndarray]):
        # Accept a file path, a PIL image, or a NumPy array.
        if isinstance(inputs, str):
            img = Image.open(inputs)
        elif isinstance(inputs, Image.Image):
            img = inputs
        else:
            img = Image.fromarray(inputs)

        # Move the pixel values to the model's device and dtype (bfloat16 by default).
        inputs = self.processor(img, return_tensors="pt").to(self.model.device, self.model.dtype)
        return inputs

    def _forward(self, inputs):
        logits = self.model(**inputs).logits.detach().cpu().float()[0]
        # Clip the raw logits to [0, 1] so they can be read as per-tag scores.
        logits = np.clip(logits, 0.0, 1.0)
        return logits

    def postprocess(self, logits, threshold: float = 0.0, return_scores: bool = False):
        # Keep only tags with a positive score, sorted from most to least confident.
        results = {
            self.model.config.id2label[i]: logit for i, logit in enumerate(logits) if logit > 0
        }
        results = sorted(results.items(), key=lambda x: x[1], reverse=True)

        # Apply the threshold and format each score as a percentage string.
        out = {}
        for tag, score in results:
            if score >= threshold:
                out[tag] = f"{score * 100:.2f}"

        if return_scores:
            return out
        return ", ".join(out.keys())
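

# Example usage (a minimal sketch, not part of the pipeline itself). It assumes the
# "p1atdev/siglip-tagger-test-3" checkpoint referenced above can be loaded with
# AutoModelForImageClassification; whether trust_remote_code=True is needed depends
# on how the model is published, so treat that flag and the image path as assumptions.
if __name__ == "__main__":
    model = AutoModelForImageClassification.from_pretrained(
        "p1atdev/siglip-tagger-test-3",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,  # assumption: the checkpoint may ship custom modeling code
    )
    pipe = SiglipTaggerPipe(model=model)
    # "sample.jpg" is a placeholder path; threshold and return_scores are the options
    # routed through _sanitize_parameters above.
    print(pipe("sample.jpg", threshold=0.3, return_scores=False))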