import os
from typing import Dict, List, Tuple

import numpy as np
import onnxruntime as ort
from PIL import Image
from PIL.Image import Image as PILImage


class BaseSession:
    """Base class wrapping an ONNX Runtime inference session for one model.

    Subclasses provide the model file path via :meth:`download_models`, an
    identifier via :meth:`name`, and the inference logic via :meth:`predict`.
    """

    def __init__(
        self,
        model_name: str,
        sess_opts: ort.SessionOptions,
        providers=None,
        *args,
        **kwargs
    ):
        """Create the underlying ``ort.InferenceSession``.

        Args:
            model_name: Stored on the instance as ``self.model_name``.
            sess_opts: Session options forwarded to ``ort.InferenceSession``.
            providers: Optional ordered list of execution provider names.
                Names not returned by ``ort.get_available_providers()`` are
                dropped, preserving the caller's order. When falsy, every
                available provider is used.
            *args, **kwargs: Accepted for subclass compatibility; unused here.
        """
        self.model_name = model_name

        available = ort.get_available_providers()
        if providers:
            # Keep the caller's priority order; silently skip providers that
            # this onnxruntime build does not offer.
            # NOTE(review): if none of the requested providers is available,
            # an empty list reaches InferenceSession; how ORT treats that
            # depends on the installed version — confirm this is intended.
            self.providers = [p for p in providers if p in available]
        else:
            self.providers = list(available)

        self.inner_session = ort.InferenceSession(
            str(self.__class__.download_models()),
            providers=self.providers,
            sess_options=sess_opts,
        )

    def normalize(
        self,
        img: PILImage,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
        *args,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """Preprocess an image into an input feed for ``self.inner_session``.

        Steps:

        1. Convert the image to RGB.
        2. Resize it to ``size`` (PIL ``(width, height)`` order) with Lanczos
           resampling.
        3. Scale by the image's own maximum pixel value.
        4. Standardize each channel as ``(x - mean[c]) / std[c]``.
        5. Reorder to channels-first and add a leading batch axis, giving a
           float32 array of shape ``(1, 3, H, W)``.

        Args:
            img: Source image.
            mean: Per-channel (R, G, B) means subtracted after scaling.
            std: Per-channel (R, G, B) divisors applied after the mean.
            size: Target ``(width, height)`` for the resize.

        Returns:
            A single-entry dict mapping the session's first input name to the
            preprocessed array.
        """
        im = img.convert("RGB").resize(size, Image.LANCZOS)
        # np.array always copies, so the in-place division below is safe.
        im_ary = np.array(im, dtype=np.float64)

        max_val = im_ary.max()
        # An all-black image has max 0; dividing would turn every pixel into
        # NaN (0/0). Leave such an image as zeros instead.
        if max_val > 0:
            im_ary /= max_val

        # Broadcast the per-channel statistics over the trailing RGB axis.
        normalized = (im_ary - np.asarray(mean, dtype=np.float64)) / np.asarray(
            std, dtype=np.float64
        )

        # HWC -> CHW, then prepend the batch dimension.
        tensor = np.expand_dims(normalized.transpose((2, 0, 1)), 0)

        input_name = self.inner_session.get_inputs()[0].name
        return {input_name: tensor.astype(np.float32)}

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """Run inference on ``img``; must be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def checksum_disabled(cls, *args, **kwargs):
        """Return True when ``MODEL_CHECKSUM_DISABLED`` is set (any value)."""
        return os.getenv("MODEL_CHECKSUM_DISABLED", None) is not None

    @classmethod
    def u2net_home(cls, *args, **kwargs):
        """Return the model storage directory with ``~`` expanded.

        Uses ``$U2NET_HOME`` when set. Otherwise it uses
        ``$XDG_DATA_HOME/.u2net``, and falls back to ``~/.u2net`` when
        ``XDG_DATA_HOME`` is unset.
        """
        return os.path.expanduser(
            os.getenv(
                "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
            )
        )

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Return the model file path (``str()`` is applied to the result).

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the model identifier; must be implemented by subclasses."""
        raise NotImplementedError