Spaces:
Running
on
Zero
Running
on
Zero
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any | |
import torch | |
from PIL import Image | |
from refiners.foundationals.clip.concepts import ConceptExtender | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( | |
MultiUpscaler, | |
UpscalerCheckpoints, | |
) | |
from esrgan_model import UpscalerESRGAN | |
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint paths for ``ESRGANUpscaler``.

    Extends the base upscaler checkpoints with an optional ESRGAN weights file.
    Declared as a dataclass so that ``esrgan`` becomes a real (keyword-only)
    constructor field rather than a plain class attribute. This assumes
    ``UpscalerCheckpoints`` is itself a dataclass; TODO confirm in refiners.
    """

    # Path to the ESRGAN weights. When None, no ESRGAN model is loaded and
    # ESRGANUpscaler.pre_upscale falls back to the base-class pre-upscaling.
    esrgan: Path | None = None
class ESRGANUpscaler(MultiUpscaler):
    """Multi-upscaler that can pre-upscale images with an optional ESRGAN model.

    When an ESRGAN checkpoint is configured and requested, ``pre_upscale`` runs
    the image through ESRGAN and then resizes it to the requested factor.
    Otherwise it defers to ``MultiUpscaler.pre_upscale``.
    """

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Build the base upscaler, then load the optional ESRGAN model.

        Args:
            checkpoints: Model weight paths; ``checkpoints.esrgan`` may be None.
            device: Device the models are placed on.
            dtype: Dtype the models are cast to.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # Loaded after the base initializer because load_esrgan reads
        # self.device / self.dtype (presumably set by MultiUpscaler.__init__).
        self.esrgan: UpscalerESRGAN | None = self.load_esrgan(checkpoints.esrgan)

    def to(self, device: torch.device, dtype: torch.dtype) -> None:
        """Move the ESRGAN model (if any) and ``self.sd`` to ``device``/``dtype``.

        The new device and dtype are also stored on the instance, where
        ``load_esrgan`` and ``load_negative_embedding`` read them.
        """
        # esrgan is None when no ESRGAN checkpoint was configured; calling
        # .to() on it unconditionally would raise AttributeError.
        if self.esrgan is not None:
            self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def load_esrgan(self, path: Path | None) -> UpscalerESRGAN | None:
        """Return an ESRGAN upscaler loaded from ``path``, or None if ``path`` is None."""
        if path is None:
            return None
        return UpscalerESRGAN(path, device=self.device, dtype=self.dtype)

    def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
        """Load a negative embedding and register each of its rows as a token.

        The file at ``path`` is loaded with ``torch.load(weights_only=True)``.
        If it holds a dict, ``key`` is a dot-separated path to the tensor
        inside it (e.g. ``"a.b"`` reads ``data["a"]["b"]``). Each row ``i`` of
        the resulting 2D tensor is added to ``self.sd.clip_text_encoder`` as a
        concept named ``<i>``.

        Args:
            path: File to load, or None to skip loading.
            key: Dot-separated lookup path, required when the file holds a dict.

        Returns:
            ``""`` when ``path`` is None. Otherwise ``", "`` followed by
            ``"<i> "`` for each row (e.g. ``", <0> <1> "``).

        Raises:
            AssertionError: if ``key`` is missing or invalid for a dict file,
                or the loaded value is not a 2D tensor.
        """
        if path is None:
            return ""

        embeddings: torch.Tensor | dict[str, Any] = torch.load(  # type: ignore
            path, weights_only=True, map_location=self.device
        )

        if isinstance(embeddings, dict):
            assert key is not None, "Key must be provided to access the negative embedding."
            # Use a distinct name for each path segment so the full `key`
            # parameter is not shadowed.
            for part in key.split("."):
                # A dotted path may walk past a leaf; fail clearly instead of
                # calling `in` on a tensor.
                assert isinstance(embeddings, dict), (
                    f"Cannot look up {part!r} in key {key!r}: "
                    f"reached a {type(embeddings).__name__}, not a dictionary."
                )
                assert (
                    part in embeddings
                ), f"Key {part} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
                embeddings = embeddings[part]

        assert isinstance(
            embeddings, torch.Tensor
        ), f"The negative embedding must be a tensor, found {type(embeddings)}."
        assert (
            embeddings.ndim == 2
        ), f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."

        extender = ConceptExtender(self.sd.clip_text_encoder)
        tokens: list[str] = []
        for index, embedding in enumerate(embeddings):
            token = f"<{index}>"
            extender.add_concept(
                token=token,
                embedding=embedding.to(device=self.device, dtype=self.dtype),
            )
            tokens.append(token)
        extender.inject()

        # Each token keeps its trailing space, matching ", <0> <1> ".
        return ", " + "".join(f"{token} " for token in tokens)

    def pre_upscale(
        self,
        image: Image.Image,
        upscale_factor: float,
        use_esrgan: bool = True,
        use_esrgan_tiling: bool = True,
        **_: Any,
    ) -> Image.Image:
        """Upscale ``image`` by ``upscale_factor``, using ESRGAN when available.

        Args:
            image: Input image.
            upscale_factor: Target scale applied to the original width/height.
            use_esrgan: If False (or no ESRGAN model is loaded), defer to
                ``MultiUpscaler.pre_upscale``.
            use_esrgan_tiling: Choose the tiled or untiled ESRGAN pass.
            **_: Extra keyword arguments, accepted and ignored.

        Returns:
            The image resized to ``(int(width * upscale_factor),
            int(height * upscale_factor))``.
        """
        if self.esrgan is None or not use_esrgan:
            return super().pre_upscale(image=image, upscale_factor=upscale_factor)

        width, height = image.size
        if use_esrgan_tiling:
            upscaled = self.esrgan.upscale_with_tiling(image)
        else:
            upscaled = self.esrgan.upscale_without_tiling(image)

        # The ESRGAN output size is not necessarily the requested one, so
        # resize to the exact target computed from the original size.
        target_size = (int(width * upscale_factor), int(height * upscale_factor))
        return upscaled.resize(size=target_size, resample=Image.LANCZOS)