# API-SDXL-Flash / handler.py
# Author: UAI-Software
# Commit 868b5ab (verified): "Upload folder using huggingface_hub"
import os
from typing import Dict, List, Any
import sys
rootDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(rootDir)
from uaiDiffusers.common.imageRequest import ImageRequest
from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler
import torch
from uaiDiffusers.uaiDiffusers import ImagesToBase64
import torch
class EndpointHandler:
    """Inference Endpoint handler that serves a Stable Diffusion XL pipeline.

    The default model is loaded once at construction time and reused for every
    request. Each request dict is parsed into an ``ImageRequest``, run through
    the pipeline, and the resulting images are returned via ``ImagesToBase64``.
    """

    # Model loaded at start-up; also substituted when a request asks for "default".
    DEFAULT_MODEL = "sd-community/sdxl-flash"

    def __init__(self, path=""):
        """Preload the default model so the first request does not pay the load cost.

        Args:
            path: Path handed in by the endpoint runtime; currently unused.
        """
        self.pipe = None
        # Name of the model currently held in ``self.pipe`` ("" while nothing is loaded).
        self.modelName = ""
        baseReq = ImageRequest()
        baseReq.model = self.DEFAULT_MODEL
        self.LoadModel(baseReq)

    def LoadModel(self, request):
        """Ensure a pipeline is loaded, loading ``request.model`` if none is resident.

        A ``request.model`` of ``"default"`` is rewritten in place to
        ``DEFAULT_MODEL``. When no pipeline is loaded yet, the requested model is
        loaded onto CUDA and its scheduler is swapped for a DPM-Solver
        single-step scheduler with "trailing" timestep spacing. When a pipeline
        is already loaded it is reused as-is.

        Args:
            request: ``ImageRequest``-like object exposing a ``model`` attribute.
        """
        if request.model == "default":
            request.model = self.DEFAULT_MODEL

        if self.pipe is not None:
            # NOTE(review): a request naming a model other than ``self.modelName``
            # is served by the already-loaded pipeline. Per-request switching
            # would need to release the old pipe first, and because
            # ``request.model`` is client-supplied it should be gated behind an
            # allow-list before loading arbitrary Hub repositories.
            return

        base = request.model
        print(f"Loading model: {base}")
        pipe = StableDiffusionXLPipeline.from_pretrained(base).to("cuda")
        # Ensure sampler uses "trailing" timesteps. Done once per load rather than
        # per request; the pipeline resets scheduler timesteps on every call.
        pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        # Publish the pipeline only after it is fully configured.
        self.pipe = pipe
        self.modelName = base

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate images for one endpoint request.

        Args:
            data: Request payload, parsed with ``ImageRequest.FromDict``.
                Documented keys: input, seed, prompt, negative_prompt,
                num_images_per_prompt, steps, guidance_scale, width, height,
                plus extra kwargs.
                NOTE(review): only prompt, negative_prompt, steps,
                guidance_scale and num_images_per_prompt are currently
                forwarded to the pipeline; seed, width, height and input are
                not used.

        Returns:
            ``{"media": [{"media": ...}, ...]}`` with one entry per generated
            image, each holding what ``ImagesToBase64`` returns for that image
            (presumably a base64 string). It is serialized by the endpoint
            runtime.
        """
        request = ImageRequest.FromDict(data)
        return self.__runProcess__(request)

    def __runProcess__(self, request: ImageRequest) -> Dict[str, Any]:
        """Run the loaded SDXL pipeline for ``request`` and encode the images.

        Args:
            request: Parsed request carrying prompt and sampling parameters.

        Returns:
            ``{"media": [{"media": ImagesToBase64(img)} for each image]}``.
        """
        self.LoadModel(request)
        # Sampling parameters are taken verbatim from the request
        # (guidance_scale is NOT forced to 0 here).
        images = self.pipe(
            request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
            guidance_scale=request.guidance_scale,
            num_images_per_prompt=request.num_images_per_prompt,
        ).images
        return {"media": [{"media": ImagesToBase64(img)} for img in images]}