|
import gc |
|
|
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
import torchvision |
|
from controlnet_aux import ( |
|
CannyDetector, |
|
ContentShuffleDetector, |
|
HEDdetector, |
|
LineartAnimeDetector, |
|
LineartDetector, |
|
MidasDetector, |
|
MLSDdetector, |
|
NormalBaeDetector, |
|
OpenposeDetector, |
|
PidiNetDetector, |
|
) |
|
from controlnet_aux.util import HWC3 |
|
|
|
from cv_utils import resize_image |
|
from depth_estimator import DepthEstimator |
|
from image_segmentor import ImageSegmentor |
|
|
|
from kornia.core import Tensor |
|
from kornia.filters import canny |
|
|
|
|
|
class Canny: |
|
|
|
def __call__( |
|
self, |
|
images: np.array, |
|
low_threshold: float = 0.1, |
|
high_threshold: float = 0.2, |
|
kernel_size: tuple[int, int] | int = (5, 5), |
|
sigma: tuple[float, float] | Tensor = (1, 1), |
|
hysteresis: bool = True, |
|
eps: float = 1e-6 |
|
) -> torch.Tensor: |
|
|
|
assert low_threshold is not None, "low_threshold must be provided" |
|
assert high_threshold is not None, "high_threshold must be provided" |
|
|
|
images = torch.from_numpy(images).permute(2, 0, 1).unsqueeze(0) / 255.0 |
|
|
|
images_tensor = canny(images, low_threshold, high_threshold, kernel_size, sigma, hysteresis, eps)[1] |
|
images_tensor = (images_tensor[0][0].numpy() * 255).astype(np.uint8) |
|
return images_tensor |
|
|
|
|
|
class Preprocessor:
    """Holds one annotator model at a time and applies it to images.

    Call ``load(name)`` to select an annotator, then call the instance with
    an image. Loading a different name replaces the previously held model.
    """

    # Identifier handed to every detector's ``from_pretrained``.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        # ``model`` is the currently loaded annotator (None until ``load``
        # succeeds); ``name`` is the key it was loaded under ("" if none).
        self.model = None
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator registered under ``name``.

        Does nothing when ``name`` is already the loaded annotator.

        Args:
            name: One of "HED", "Midas", "MLSD", "Openpose", "PidiNet",
                "NormalBae", "Lineart", "LineartAnime", "Canny",
                "ContentShuffle", "DPT" or "UPerNet".

        Raises:
            ValueError: If ``name`` is not a known annotator. The previously
                loaded model and name are left untouched in that case.
        """
        if name == self.name:
            return

        # Zero-argument factories so that only the requested annotator is
        # ever constructed.
        factories = {
            "HED": lambda: HEDdetector.from_pretrained(self.MODEL_ID),
            "Midas": lambda: MidasDetector.from_pretrained(self.MODEL_ID),
            "MLSD": lambda: MLSDdetector.from_pretrained(self.MODEL_ID),
            "Openpose": lambda: OpenposeDetector.from_pretrained(self.MODEL_ID),
            "PidiNet": lambda: PidiNetDetector.from_pretrained(self.MODEL_ID),
            "NormalBae": lambda: NormalBaeDetector.from_pretrained(self.MODEL_ID),
            "Lineart": lambda: LineartDetector.from_pretrained(self.MODEL_ID),
            "LineartAnime": lambda: LineartAnimeDetector.from_pretrained(self.MODEL_ID),
            "Canny": lambda: Canny(),
            "ContentShuffle": lambda: ContentShuffleDetector(),
            "DPT": lambda: DepthEstimator(),
            "UPerNet": lambda: ImageSegmentor(),
        }
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f"Unknown preprocessor name: {name!r}")

        self.model = factory()

        # Collect garbage first so that memory held by the replaced model is
        # actually free by the time the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run the loaded annotator on ``image``.

        Args:
            image: Input image.
            **kwargs: Options forwarded to the annotator. For "Canny" an
                optional ``detect_resolution`` is consumed here; for "Midas"
                ``detect_resolution`` and ``image_resolution`` (both default
                512) are consumed here.

        Returns:
            For "Canny" and "Midas", a PIL image built from the annotator's
            array output; for every other annotator, whatever the annotator
            itself returns.
        """
        if self.name == "Canny":
            # The Canny wrapper feeds its input to ``torch.from_numpy``, so
            # the image must always be converted to an array, whether or not
            # a resize was requested.
            detect_resolution = kwargs.pop("detect_resolution", None)
            array = HWC3(np.array(image))
            if detect_resolution is not None:
                array = resize_image(array, resolution=detect_resolution)
            edges = self.model(array, **kwargs)
            return PIL.Image.fromarray(edges).convert('RGB')

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            array = HWC3(np.array(image))
            array = resize_image(array, resolution=detect_resolution)
            result = self.model(array, **kwargs)
            # Bring the annotator output to the requested output resolution.
            result = HWC3(result)
            result = resize_image(result, resolution=image_resolution)
            return PIL.Image.fromarray(result)

        # All other annotators take the image and options directly.
        return self.model(image, **kwargs)
|
|