Spaces:
Runtime error
Runtime error
angelahzyuan
commited on
Commit
•
6f3e35c
1
Parent(s):
f48abf3
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import cv2
|
3 |
+
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
MODEL="/mnt/bn/ailab-yuningshen-psg/mlx/users/quanquan.gu/playground/trl/iter0_0_2.0e-5_linear200_new/checkpoints/checkpoint_0"
|
10 |
+
MODEL="/mnt/bn/ailab-yuningshen-psg/mlx/users/quanquan.gu/playground/trl/iter2_12h_5.0e-8_beta5_rep_wtie/checkpoints/checkpoint_10"
|
11 |
+
PROMPTS="/mnt/bn/ailab-yuningshen-psg/mlx/users/quanquan.gu/playground/trl/dataset/pickapic/pickapic_v2/validation_unique-00007_filtered.parquet"
|
12 |
+
NUM=1
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def set_seed(seed=5775709):
|
17 |
+
random.seed(seed)
|
18 |
+
np.random.seed(seed)
|
19 |
+
torch.manual_seed(seed)
|
20 |
+
torch.cuda.manual_seed(seed)
|
21 |
+
|
22 |
+
set_seed()
|
23 |
+
|
24 |
+
def get_pipeline(device='cuda'):
|
25 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
26 |
+
#pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker = None, requires_safety_checker = False)
|
27 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
28 |
+
|
29 |
+
# load finetuned model
|
30 |
+
unet_id = MODEL
|
31 |
+
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
|
32 |
+
pipe.unet = unet
|
33 |
+
pipe = pipe.to(device)
|
34 |
+
return pipe
|
35 |
+
|
36 |
+
|
37 |
+
def generate(prompt: str, num_images: int=5, guidance_scale=7.5):
|
38 |
+
pipe = get_pipeline()
|
39 |
+
generator = torch.Generator(pipe.device).manual_seed(5775709)
|
40 |
+
# Ensure num_images is an integer
|
41 |
+
num_images = int(num_images)
|
42 |
+
images = pipe(prompt, generator=generator, guidance_scale=guidance_scale, num_inference_steps=50, num_images_per_prompt=num_images).images
|
43 |
+
images = [x.resize((512, 512)) for x in images]
|
44 |
+
return images
|
45 |
+
|
46 |
+
def gen_image(args, image):
|
47 |
+
output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
48 |
+
return output
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
with gr.Blocks() as demo:
|
54 |
+
gr.Markdown("# SPIN-Diffusion 1.0 Demo")
|
55 |
+
|
56 |
+
with gr.Row():
|
57 |
+
prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type something...", lines=2)
|
58 |
+
generate_btn = gr.Button("Generate images")
|
59 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=50, value=9, step=0.1)
|
60 |
+
num_images_input = gr.Number(label="Number of images", value=5, minimum=1, maximum=10, step=1)
|
61 |
+
gallery = gr.Gallery(label="Generated images", elem_id="gallery", columns=5, object_fit="contain")
|
62 |
+
|
63 |
+
# Define your example prompts
|
64 |
+
examples = [
|
65 |
+
["The Eiffel Tower at sunset"],
|
66 |
+
["A futuristic city skyline"],
|
67 |
+
["A cat wearing a wizard hat"],
|
68 |
+
["A futuristic city at sunset"],
|
69 |
+
["A landscape with mountains and lakes"],
|
70 |
+
["A portrait of a robot in Renaissance style"],
|
71 |
+
]
|
72 |
+
|
73 |
+
# Add the Examples component linked to the prompt_input
|
74 |
+
gr.Examples(examples=examples, inputs=prompt_input, fn=generate, outputs=gallery)
|
75 |
+
|
76 |
+
generate_btn.click(fn=generate, inputs=[prompt_input, num_images_input, guidance_scale], outputs=gallery)
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
demo.launch(share=True)
|