angelahzyuan committed
Commit
6f3e35c
1 Parent(s): f48abf3

Add application file

Files changed (1)
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+ import cv2
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+ import torch
+ import random
+ import numpy as np
+
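+ # Checkpoint and dataset paths; PROMPTS and NUM are defined here but not used below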
+ # MODEL="/mnt/bn/ailab-yuningshen-psg/mlx/users/quanquan.gu/playground/trl/iter0_0_2.0e-5_linear200_new/checkpoints/checkpoint_0"  # earlier checkpoint, superseded below
+ MODEL="/mnt/bn/ailab-yuningshen-psg/mlx/users/quanquan.gu/playground/trl/iter2_12h_5.0e-8_beta5_rep_wtie/checkpoints/checkpoint_10"
+ PROMPTS="/mnt/bn/ailab-yuningshen-psg/mlx/users/quanquan.gu/playground/trl/dataset/pickapic/pickapic_v2/validation_unique-00007_filtered.parquet"
+ NUM=1
+
+
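+ # Seed Python, NumPy, and Torch RNGs so generations are reproducible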
+ def set_seed(seed=5775709):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+
+ set_seed()
+
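+ # Build the Stable Diffusion v1.5 pipeline and swap in the fine-tuned UNet from MODEL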
+ def get_pipeline(device='cuda'):
+     model_id = "runwayml/stable-diffusion-v1-5"
+     #pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker = None, requires_safety_checker = False)
+     pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+
+     # load finetuned model
+     unet_id = MODEL
+     unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
+     pipe.unet = unet
+     pipe = pipe.to(device)
+     return pipe
+
+
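+ # Generate images for a prompt with a fixed seed; get_pipeline() reloads weights on every call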
+ def generate(prompt: str, num_images: int=5, guidance_scale=7.5):
+     pipe = get_pipeline()
+     generator = torch.Generator(pipe.device).manual_seed(5775709)
+     # Ensure num_images is an integer
+     num_images = int(num_images)
+     images = pipe(prompt, generator=generator, guidance_scale=guidance_scale, num_inference_steps=50, num_images_per_prompt=num_images).images
+     images = [x.resize((512, 512)) for x in images]
+     return images
+
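+ # Unused helper: converts a BGR image to grayscale with OpenCV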
+ def gen_image(args, image):
+     output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     return output
+
+
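+ # Gradio UI: prompt box and controls in a row, with the output gallery below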
+ with gr.Blocks() as demo:
+     gr.Markdown("# SPIN-Diffusion 1.0 Demo")
+
+     with gr.Row():
+         prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type something...", lines=2)
+         generate_btn = gr.Button("Generate images")
+         guidance_scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=50, value=9, step=0.1)
+         num_images_input = gr.Number(label="Number of images", value=5, minimum=1, maximum=10, step=1)
+     gallery = gr.Gallery(label="Generated images", elem_id="gallery", columns=5, object_fit="contain")
+
+     # Define your example prompts
+     examples = [
+         ["The Eiffel Tower at sunset"],
+         ["A futuristic city skyline"],
+         ["A cat wearing a wizard hat"],
+         ["A futuristic city at sunset"],
+         ["A landscape with mountains and lakes"],
+         ["A portrait of a robot in Renaissance style"],
+     ]
+
+     # Add the Examples component linked to the prompt_input
+     gr.Examples(examples=examples, inputs=prompt_input, fn=generate, outputs=gallery)
+
+     generate_btn.click(fn=generate, inputs=[prompt_input, num_images_input, guidance_scale], outputs=gallery)
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
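
Review note: generate calls get_pipeline() on every button click, so the full SD-1.5 weights are re-read from disk per request. A minimal sketch of caching the pipeline at module level (a hypothetical refactor, not part of this commit; it assumes it sits in app.py next to the get_pipeline defined above):

    # Hypothetical refactor (not in this commit): load the pipeline once and reuse it.
    _PIPE = None

    def get_pipeline_cached(device='cuda'):
        global _PIPE
        if _PIPE is None:
            _PIPE = get_pipeline(device)  # reuses the loader defined above
        return _PIPE

With this, generate would call get_pipeline_cached() instead, and repeated clicks would reuse the already-loaded weights rather than paying the multi-gigabyte load each time.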