|
import gc |
|
|
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, |
|
LineartAnimeDetector, LineartDetector, |
|
MidasDetector, MLSDdetector, NormalBaeDetector, |
|
OpenposeDetector, PidiNetDetector) |
|
from controlnet_aux.util import HWC3 |
|
|
|
from cv_utils import resize_image |
|
from depth_estimator import DepthEstimator |
|
from image_segmentor import ImageSegmentor |
|
|
|
|
|
class Preprocessor:
    """Lazily loaded wrapper around one image annotator model at a time.

    Call :meth:`load` with a preprocessor name to select the model, then call
    the instance itself on a PIL image. ``self.model`` is ``None`` until the
    first successful :meth:`load`, so calling the instance before loading
    fails because ``None`` is not callable.
    """

    MODEL_ID = 'lllyasviel/Annotators'

    # Detectors created with ``<class>.from_pretrained(MODEL_ID)``.
    _PRETRAINED_DETECTORS = {
        'HED': HEDdetector,
        'Midas': MidasDetector,
        'MLSD': MLSDdetector,
        'Openpose': OpenposeDetector,
        'PidiNet': PidiNetDetector,
        'NormalBae': NormalBaeDetector,
        'Lineart': LineartDetector,
        'LineartAnime': LineartAnimeDetector,
    }

    # Models created by calling their class with no arguments.
    _PLAIN_MODELS = {
        'Canny': CannyDetector,
        'ContentShuffle': ContentShuffleDetector,
        'DPT': DepthEstimator,
        'UPerNet': ImageSegmentor,
    }

    def __init__(self):
        # No model is loaded yet; ``name`` is the key of the loaded model.
        self.model = None
        self.name = ''

    def load(self, name: str) -> None:
        """Make ``name`` the active preprocessor, replacing any previous one.

        Does nothing when ``name`` is already the active preprocessor.

        Args:
            name: One of the keys of ``_PRETRAINED_DETECTORS`` or
                ``_PLAIN_MODELS`` (for example ``'Canny'`` or ``'Midas'``).

        Raises:
            ValueError: If ``name`` is not a supported preprocessor. In that
                case ``self.model`` and ``self.name`` are left untouched.
        """
        if name == self.name:
            return

        if name in self._PRETRAINED_DETECTORS:
            detector_cls = self._PRETRAINED_DETECTORS[name]
            # Read MODEL_ID through ``self`` so subclass overrides are honoured.
            self.model = detector_cls.from_pretrained(self.MODEL_ID)
        elif name in self._PLAIN_MODELS:
            self.model = self._PLAIN_MODELS[name]()
        else:
            supported = sorted({**self._PRETRAINED_DETECTORS,
                                **self._PLAIN_MODELS})
            raise ValueError(
                f'Unknown preprocessor name: {name!r}. '
                f'Supported names: {supported}')

        # Collect garbage first: the replaced model must actually be freed
        # before empty_cache() can release the cached memory it was using.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the active preprocessor on ``image``.

        Args:
            image: Input image.
            **kwargs: Forwarded to the underlying model, except for the keys
                consumed here:

                * ``detect_resolution`` -- for ``'Canny'`` (optional, no
                  default) and ``'Midas'`` (default 512): passed as
                  ``resolution`` to ``resize_image`` before the model runs.
                * ``image_resolution`` -- for ``'Midas'`` only (default 512):
                  passed as ``resolution`` to ``resize_image`` for the model
                  output.

        Returns:
            For ``'Canny'`` and ``'Midas'``, the model output wrapped with
            ``PIL.Image.fromarray``. For every other preprocessor, whatever
            the underlying model returns.
        """
        if self.name == 'Canny':
            # The input is converted and resized only when the caller asks for
            # a detection resolution; otherwise it reaches the model as given.
            if 'detect_resolution' in kwargs:
                detect_resolution = kwargs.pop('detect_resolution')
                image = np.array(image)
                image = HWC3(image)
                image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)

        if self.name == 'Midas':
            detect_resolution = kwargs.pop('detect_resolution', 512)
            image_resolution = kwargs.pop('image_resolution', 512)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            # The output gets the same HWC3 + resize treatment as the input,
            # this time at the requested output resolution.
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)

        # All other preprocessors take the image and kwargs directly.
        return self.model(image, **kwargs)
|
|