Spaces:
Running
Running
File size: 3,852 Bytes
a660631 b5515fe f521e88 a660631 b5515fe a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 b5515fe f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 b5515fe f521e88 a660631 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gc
import numpy as np
import PIL.Image
import torch
import torchvision
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LineartAnimeDetector,
LineartDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
)
from controlnet_aux.util import HWC3
from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor
from kornia.core import Tensor
from kornia.filters import canny
class Canny:
def __call__(
self,
images: np.array,
low_threshold: float = 0.1,
high_threshold: float = 0.2,
kernel_size: tuple[int, int] | int = (5, 5),
sigma: tuple[float, float] | Tensor = (1, 1),
hysteresis: bool = True,
eps: float = 1e-6
) -> torch.Tensor:
assert low_threshold is not None, "low_threshold must be provided"
assert high_threshold is not None, "high_threshold must be provided"
images = torch.from_numpy(images).permute(2, 0, 1).unsqueeze(0) / 255.0
images_tensor = canny(images, low_threshold, high_threshold, kernel_size, sigma, hysteresis, eps)[1]
images_tensor = (images_tensor[0][0].numpy() * 255).astype(np.uint8)
return images_tensor
class Preprocessor:
    """Holds a single annotator model at a time and applies it to images.

    Call :meth:`load` with a preprocessor name to select the model, then call
    the instance itself with a PIL image to run it.
    """

    # Repository id passed to ``from_pretrained`` for the pretrained detectors.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self) -> None:
        # Currently loaded annotator (None until ``load`` succeeds).
        self.model = None
        # Name of the currently loaded annotator ("" when nothing is loaded).
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator called ``name``, replacing the current one.

        Does nothing when ``name`` is already the loaded annotator.

        Args:
            name: One of "HED", "Midas", "MLSD", "Openpose", "PidiNet",
                "NormalBae", "Lineart", "LineartAnime", "Canny",
                "ContentShuffle", "DPT" or "UPerNet".

        Raises:
            ValueError: If ``name`` is not one of the supported names. The
                previously loaded annotator is left in place in that case.
        """
        if name == self.name:
            return

        if name == "HED":
            self.model = HEDdetector.from_pretrained(self.MODEL_ID)
        elif name == "Midas":
            self.model = MidasDetector.from_pretrained(self.MODEL_ID)
        elif name == "MLSD":
            self.model = MLSDdetector.from_pretrained(self.MODEL_ID)
        elif name == "Openpose":
            self.model = OpenposeDetector.from_pretrained(self.MODEL_ID)
        elif name == "PidiNet":
            self.model = PidiNetDetector.from_pretrained(self.MODEL_ID)
        elif name == "NormalBae":
            self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
        elif name == "Lineart":
            self.model = LineartDetector.from_pretrained(self.MODEL_ID)
        elif name == "LineartAnime":
            self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
        elif name == "Canny":
            self.model = Canny()
        elif name == "ContentShuffle":
            self.model = ContentShuffleDetector()
        elif name == "DPT":
            self.model = DepthEstimator()
        elif name == "UPerNet":
            self.model = ImageSegmentor()
        else:
            raise ValueError(f"Unknown preprocessor name: {name!r}")

        # Collect garbage first so the replaced model's memory is actually
        # freed, then hand the now-unused cached CUDA blocks back.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on ``image``.

        Args:
            image: Input image.
            **kwargs: Options forwarded to the annotator. For "Canny" and
                "Midas", ``detect_resolution`` (default 512) is consumed here
                and used to resize the input before detection; for "Midas",
                ``image_resolution`` (default 512) is also consumed and used
                to resize the output.

        Returns:
            The annotated image. For "Canny" and "Midas" this is a PIL image
            built here; for every other annotator it is whatever the
            underlying model returns.

        Raises:
            RuntimeError: If no annotator has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load(name) first.")

        if self.name == "Canny":
            # Default mirrors the Midas branch; without a default, a call that
            # omitted detect_resolution hit an unbound local variable.
            detect_resolution = kwargs.pop("detect_resolution", 512)
            array = HWC3(np.array(image))
            array = resize_image(array, resolution=detect_resolution)
            edges = self.model(array, **kwargs)
            return PIL.Image.fromarray(edges).convert("RGB")

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            array = HWC3(np.array(image))
            array = resize_image(array, resolution=detect_resolution)
            depth = HWC3(self.model(array, **kwargs))
            depth = resize_image(depth, resolution=image_resolution)
            return PIL.Image.fromarray(depth)

        # Every other annotator takes the PIL image and kwargs directly.
        return self.model(image, **kwargs)
|