# app.py (author: JoPmt; commit 89d8c13, "Update app.py")
from PIL import Image
import gradio as gr
import numpy as np
import random, os, gc, base64, io
import cv2
import torch
from accelerate import Accelerator
from transformers import pipeline
from diffusers.utils import load_image
from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
from gradio_client import Client
# Accelerate is told to use the CPU, and both Kandinsky pipelines below are
# loaded in float32 and moved to "cpu" explicitly.
accelerator = Accelerator(cpu=True)

# Depth-estimation pipeline; its output is turned into the ControlNet hint.
depth_estimator = accelerator.prepare(
    pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
)

# Kandinsky 2.2 prior: maps the text prompt to image embeddings.
pipe_prior = accelerator.prepare(
    accelerator.prepare(
        KandinskyV22PriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior",
            torch_dtype=torch.float32,
        )
    ).to("cpu")
)

# Kandinsky 2.2 depth ControlNet pipeline: embeddings + depth hint -> image.
pipe = accelerator.prepare(
    accelerator.prepare(
        KandinskyV22ControlnetPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-controlnet-depth",
            torch_dtype=torch.float32,
        )
    ).to("cpu")
)

# A single CPU RNG, seeded once per process with a random value in
# [1, 867346]; every request draws from this same generator.
generator = torch.Generator("cpu").manual_seed(random.randint(1, 867346))

# Module-level list; `plex` collects its results in a local list of its own.
apol = []
def make_hint(ipage, depth_estimator, depth_images=None):
    """Turn an image into a depth "hint" tensor for the Kandinsky ControlNet.

    Args:
        ipage: Image handed unchanged to ``depth_estimator``.
        depth_estimator: Callable returning a mapping with a ``"depth"``
            entry (in this app, the transformers "depth-estimation"
            pipeline). ``np.array`` of that entry must be a 2-D array.
        depth_images: Optional list. When given, the raw ``"depth"`` value
            is appended to it so the caller can display it. When omitted,
            nothing is stored anywhere, so repeated calls do not accumulate
            state.

    Returns:
        A float32 tensor of shape ``(3, H, W)``: the depth map divided by
        255 (so an 8-bit depth image lands in ``[0, 1]``), copied into
        three identical channels.
    """
    depth_image = depth_estimator(ipage)["depth"]
    if depth_images is not None:
        depth_images.append(depth_image)
    depth = np.array(depth_image)
    # (H, W) -> (H, W, 3): replicate the single depth channel three times.
    rgb_depth = np.repeat(depth[:, :, None], 3, axis=2)
    # Scale and move to channels-first layout: (H, W, 3) -> (3, H, W).
    return torch.from_numpy(rgb_depth).float().div(255.0).permute(2, 0, 1)


# Negative prompt for the prior, steering away from common artifacts.
NEGATIVE_PRIOR_PROMPT = (
    "lowres,text,bad quality,jpeg artifacts,ugly,bad face,extra fingers,"
    "blurry,bad anatomy,extra limbs,fused fingers,long neck,watermark,signature"
)
# Square working resolution for the input, the hint and the output image.
IMAGE_SIZE = 512
# Denoising steps for both the prior and the ControlNet pipeline (kept low
# because everything runs on the CPU).
NUM_INFERENCE_STEPS = 5


def plex(goof, prompt):
    """Render ``prompt`` using the depth structure of the uploaded image.

    Args:
        goof: Input image as accepted by ``load_image``; the Gradio image
            input is configured with ``type="filepath"``, so this is a path.
        prompt: Text prompt for the Kandinsky prior.

    Returns:
        ``[depth_map, generated_image]`` for the two-column gallery.

    Raises:
        gr.Error: if no input image was supplied.
    """
    if goof is None:
        raise gr.Error("Please upload an input image.")
    gc.collect()
    # Fresh per-request list; make_hint appends the depth map to it first.
    # (The old code rebound a local `apol` while make_hint appended to the
    # module-level one, leaking depth maps and dropping them from output.)
    results = []
    source = load_image(goof).resize((IMAGE_SIZE, IMAGE_SIZE)).convert("RGB")
    hint = make_hint(source, depth_estimator, results).unsqueeze(0).to("cpu")
    image_emb, zero_image_emb = pipe_prior(
        prompt=prompt,
        negative_prompt=NEGATIVE_PRIOR_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        generator=generator,
    ).to_tuple()
    generated = pipe(
        image_embeds=image_emb,
        negative_image_embeds=zero_image_emb,
        hint=hint,
        num_inference_steps=NUM_INFERENCE_STEPS,
        width=IMAGE_SIZE,
        height=IMAGE_SIZE,
        generator=generator,
    ).images[0]
    results.append(generated)
    return results
# Gradio UI: an image (handed to `plex` as a file path) and a prompt box in,
# a two-column gallery out.
iface = gr.Interface(
    fn=plex,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(),
    ],
    outputs=gr.Gallery(columns=2),
    title="Img2Img_SkyV22CntrlNet_CPU",
    description="Running on CPU, very slow!",
)

# Queue at most one pending request, with api_open=False, and serve with a
# single worker thread.
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=1)