radames committed on
Commit
94b2c21
1 Parent(s): 4564155

enable live pose conditioning (#6)


- enable live pose conditioning (60591c00398e95d73c7fa0c484177fcd18ac77e8)
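This change replaces the plain gr.Interface app with a gr.Blocks layout: a radio toggle lets users either upload an image file (whose pose is then estimated with pose_model) or condition on a live pose from a custom <pose-canvas> web component, whose current frame reaches Python as a base64 data URL through a hidden gr.JSON component and the _js=get_js_image bridge.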

Files changed (1)
  1. app.py +103 -35
app.py CHANGED
@@ -3,6 +3,31 @@ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 from diffusers import UniPCMultistepScheduler
 import gradio as gr
 import torch
+import base64
+from io import BytesIO
+from PIL import Image
+# live conditioning
+canvas_html = "<pose-canvas id='canvas-root' style='display:flex;max-width: 500px;margin: 0 auto;'></pose-canvas>"
+load_js = """
+async () => {
+    const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/pose-gradio.js"
+    fetch(url)
+        .then(res => res.text())
+        .then(text => {
+            const script = document.createElement('script');
+            script.type = "module"
+            script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
+            document.head.appendChild(script);
+        });
+}
+"""
+get_js_image = """
+async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
+    const canvasEl = document.getElementById("canvas-root");
+    const data = canvasEl ? canvasEl._data : null;
+    return [image_in_img, prompt, image_file_live_opt, data]
+}
+"""
 
 # Constants
 low_threshold = 100
@@ -28,41 +53,84 @@ pipe.enable_xformers_memory_efficient_attention()
 # Generator seed,
 generator = torch.manual_seed(0)
 
+
 def get_pose(image):
-    return pose_model(image)
-
-
-def generate_images(image, prompt):
-    pose = get_pose(image)
-    output = pipe(
-        prompt,
-        pose,
-        generator=generator,
-        num_images_per_prompt=3,
-        num_inference_steps=20,
-    )
-    all_outputs = []
-    all_outputs.append(pose)
-    for image in output.images:
-        all_outputs.append(image)
-    return all_outputs
-
-
-gr.Interface(
-    generate_images,
-    inputs=[
-        gr.Image(type="pil"),
-        gr.Textbox(
-            label="Enter your prompt",
-            max_lines=1,
-            placeholder="best quality, extremely detailed",
-        ),
-    ],
-    outputs=gr.Gallery().style(grid=[2], height="auto"),
-    title="Generate controlled outputs with ControlNet and Stable Diffusion. ",
-    description="This Space uses pose estimated lines as the additional conditioning.",
-    examples=[["yoga1.jpeg", "best quality, extremely detailed"]],
-    allow_flagging=False,
-).launch(enable_queue=True)
+    return pose_model(image)
+
+
+def generate_images(image, prompt, image_file_live_opt='file', live_conditioning=None):
+    if image is None and 'image' not in live_conditioning:
+        raise gr.Error("Please provide an image")
+    try:
+        if image_file_live_opt == 'file':
+            pose = get_pose(image)
+        elif image_file_live_opt == 'webcam':
+            base64_img = live_conditioning['image']
+            image_data = base64.b64decode(base64_img.split(',')[1])
+            pose = Image.open(BytesIO(image_data)).convert(
+                'RGB').resize((512, 512))
+        output = pipe(
+            prompt,
+            pose,
+            generator=generator,
+            num_images_per_prompt=3,
+            num_inference_steps=20,
+        )
+        all_outputs = []
+        all_outputs.append(pose)
+        for image in output.images:
+            all_outputs.append(image)
+        return all_outputs
+    except Exception as e:
+        raise gr.Error(str(e))
+
+
+def toggle(choice):
+    if choice == "file":
+        return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
+    elif choice == "webcam":
+        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
+
+
+with gr.Blocks() as blocks:
+    gr.Markdown("""
+    ## Generate Uncanny Faces with ControlNet Stable Diffusion
+    [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
+    """)
+    with gr.Row():
+        live_conditioning = gr.JSON(value={}, visible=False)
+        with gr.Column():
+            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
+                                           label="How would you like to upload your image?")
+            image_in_img = gr.Image(source="upload", visible=True, type="pil")
+            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
+
+            image_file_live_opt.change(fn=toggle,
+                                       inputs=[image_file_live_opt],
+                                       outputs=[image_in_img, canvas],
+                                       queue=False)
+            prompt = gr.Textbox(
+                label="Enter your prompt",
+                max_lines=1,
+                placeholder="best quality, extremely detailed",
+            )
+            run_button = gr.Button("Generate")
+        with gr.Column():
+            gallery = gr.Gallery().style(grid=[2], height="auto")
+    run_button.click(fn=generate_images,
+                     inputs=[image_in_img, prompt,
+                             image_file_live_opt, live_conditioning],
+                     outputs=[gallery],
+                     _js=get_js_image)
+    blocks.load(None, None, None, _js=load_js)
 
+    gr.Examples(fn=generate_images,
+                examples=[
+                    ["./yoga1.jpeg",
+                     "best quality, extremely detailed"]
+                ],
+                inputs=[image_in_img, prompt],
+                outputs=[gallery],
+                cache_examples=True)
 
+blocks.launch(debug=True)
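
Note on the webcam branch: the <pose-canvas> element exposes its current frame as a base64 data URL in canvasEl._data; get_js_image forwards that object into the hidden live_conditioning JSON input, and generate_images decodes it back into a 512x512 RGB PIL image for the ControlNet pipeline. A minimal standalone sketch of that decode step (the helper name decode_live_frame is illustrative, not part of the commit):

import base64
from io import BytesIO
from PIL import Image

def decode_live_frame(data_url: str) -> Image.Image:
    # A data URL looks like "data:image/png;base64,iVBORw0KG..."; the
    # base64 payload is everything after the first comma.
    b64_payload = data_url.split(",")[1]
    raw_bytes = base64.b64decode(b64_payload)
    # As in the commit: force RGB and resize to the 512x512 conditioning
    # resolution the pipeline expects.
    return Image.open(BytesIO(raw_bytes)).convert("RGB").resize((512, 512))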