from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from PIL import Image
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)

from esrgan_model import UpscalerESRGAN

@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    esrgan: Path | None = None

class ESRGANUpscaler(MultiUpscaler):
    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = self.load_esrgan(checkpoints.esrgan)
    def to(self, device: torch.device, dtype: torch.dtype):
        # The ESRGAN stage is optional, so only move it if it was loaded.
        if self.esrgan is not None:
            self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
    def load_esrgan(self, path: Path | None) -> UpscalerESRGAN | None:
        if path is None:
            return None
        return UpscalerESRGAN(path, device=self.device, dtype=self.dtype)
    def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
        if path is None:
            return ""

        embeddings: torch.Tensor | dict[str, Any] = torch.load(  # type: ignore
            path, weights_only=True, map_location=self.device
        )

        # Some embedding files nest the tensor inside a dict; walk the dotted key to reach it.
        if isinstance(embeddings, dict):
            assert key is not None, "Key must be provided to access the negative embedding."
            key_sequence = key.split(".")
            for key in key_sequence:
                assert (
                    key in embeddings
                ), f"Key {key} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
                embeddings = embeddings[key]

        assert isinstance(
            embeddings, torch.Tensor
        ), f"The negative embedding must be a tensor, found {type(embeddings)}."
        assert (
            embeddings.ndim == 2
        ), f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."

        # Register each embedding vector as a new token (<0>, <1>, ...) in the CLIP text
        # encoder and return the token string to append to the negative prompt.
        extender = ConceptExtender(self.sd.clip_text_encoder)

        negative_embedding_token = ", "
        for i, embedding in enumerate(embeddings):
            embedding = embedding.to(device=self.device, dtype=self.dtype)
            extender.add_concept(token=f"<{i}>", embedding=embedding)
            negative_embedding_token += f"<{i}> "

        extender.inject()

        return negative_embedding_token
    def pre_upscale(
        self,
        image: Image.Image,
        upscale_factor: float,
        use_esrgan: bool = True,
        use_esrgan_tiling: bool = True,
        **_: Any,
    ) -> Image.Image:
        # Fall back to the default pre-upscaling when ESRGAN is disabled or unavailable.
        if self.esrgan is None or not use_esrgan:
            return super().pre_upscale(image=image, upscale_factor=upscale_factor)

        width, height = image.size

        # ESRGAN upscales by its own fixed factor; tiling trades speed for lower memory use.
        if use_esrgan_tiling:
            image = self.esrgan.upscale_with_tiling(image)
        else:
            image = self.esrgan.upscale_without_tiling(image)

        # Resize the ESRGAN output to the exact requested upscale factor.
        return image.resize(
            size=(
                int(width * upscale_factor),
                int(height * upscale_factor),
            ),
            resample=Image.LANCZOS,
        )
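
# Usage sketch (illustrative, not part of the original file): the checkpoint
# path below is a hypothetical placeholder, and the remaining fields of
# UpscalerCheckpoints as well as the exact MultiUpscaler.upscale() signature
# are assumptions about the refiners API, so the sketch is left commented out.
#
# if __name__ == "__main__":
#     checkpoints = ESRGANUpscalerCheckpoints(
#         esrgan=Path("checkpoints/esrgan.safetensors"),  # hypothetical path
#         # ...plus the base checkpoint paths required by UpscalerCheckpoints...
#     )
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     upscaler = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=torch.float16)
#     enhanced = upscaler.upscale(Image.open("input.jpg"))
#     enhanced.save("output.png")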