# Author: sergeipetrov — commit "Update handler.py" (034b0b6, verified)
from typing import Dict, List, Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import torch
import base64
import logging
import numpy as np
from PIL import Image
from io import BytesIO
# Root logger at DEBUG; loggers without their own level inherit this threshold.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Inference device: CUDA when torch reports it available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Inference-endpoint handler: upscale one image with Swin2SR and return it as base64 JPEG."""

    # Checkpoint loaded when the caller does not pass ``model_id``.
    DEFAULT_MODEL_ID = "caidas/swin2SR-classical-sr-x2-64"

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Load the image processor and the super-resolution model, then move the model to ``device``.

        Args:
            path: Supplied by the serving toolkit; not used here — the weights
                are loaded from ``model_id`` instead.
            model_id: Hub id (or local directory) of a Swin2SR checkpoint.
        """
        self.processor = AutoImageProcessor.from_pretrained(model_id)
        self.model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
        self.model.to(device)

    def __call__(self, data: Any) -> str:
        """Upscale the request image.

        Args:
            data: Request payload. A dict must carry the image under
                ``"inputs"``; any other value is treated as the image itself.

        Returns:
            The upscaled image, JPEG-encoded and then base64-encoded, as a ``str``.

        Raises:
            KeyError: ``data`` is a dict without an ``"inputs"`` key.
            ValueError: the processor produced more than one image.
        """
        image = data["inputs"] if isinstance(data, dict) else data
        inputs = self.processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # The image processor pads the input before inference, so the
        # reconstruction can extend past input_size * upscale. Crop back to the
        # exact upscaled size when the original size is known (PIL input);
        # other input types are returned uncropped, as before.
        out_height = out_width = None
        if isinstance(image, Image.Image):
            scale = self.model.config.upscale
            out_height, out_width = image.height * scale, image.width * scale

        pixels = self._reconstruction_to_uint8(outputs.reconstruction, out_height, out_width)

        buffered = BytesIO()
        Image.fromarray(pixels).save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()

    @staticmethod
    def _reconstruction_to_uint8(reconstruction, height=None, width=None):
        """Convert a (1, C, H, W) tensor with values in [0, 1] into an (H, W, C) uint8 array.

        Args:
            reconstruction: Model output tensor holding exactly one image.
            height: Optional number of rows to keep (``None`` keeps all).
            width: Optional number of columns to keep (``None`` keeps all).

        Returns:
            ``numpy.ndarray`` of dtype uint8, channels last.

        Raises:
            ValueError: the batch dimension is not 1.
        """
        if reconstruction.shape[0] != 1:
            raise ValueError(
                f"expected a single image, got a batch of {reconstruction.shape[0]}"
            )
        # Out-of-place clamp so the model's output tensor is never mutated.
        chw = reconstruction[0].float().cpu().clamp(0, 1).numpy()
        hwc = np.moveaxis(chw, source=0, destination=-1)[:height, :width]
        return (hwc * 255.0).round().astype(np.uint8)