Karumoon committed on
Commit
9b20d14
1 Parent(s): 36a452f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -114
app.py CHANGED
@@ -16,55 +16,20 @@ from cldm.ddim_hacked import DDIMSampler
16
  import dlib
17
  from PIL import Image, ImageDraw
18
 
19
- if torch.cuda.is_available():
20
- device = torch.device("cuda")
21
- else:
22
- device = torch.device("cpu")
23
-
24
  model = create_model('./models/cldm_v15.yaml').cpu()
25
- model.load_state_dict(load_state_dict(
26
- './models/control_sd15_landmarks.pth', location='cpu'))
27
- model = model.to(device)
28
  ddim_sampler = DDIMSampler(model)
29
 
30
  detector = dlib.get_frontal_face_detector()
31
  predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
32
 
33
-
34
- canvas_html = "<face-canvas id='canvas-root' data-mode='points' style='display:flex;max-width: 500px;margin: 0 auto;'></face-canvas>"
35
- load_js = """
36
- async () => {
37
- const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
38
- fetch(url)
39
- .then(res => res.text())
40
- .then(text => {
41
- const script = document.createElement('script');
42
- script.type = "module"
43
- script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
44
- document.head.appendChild(script);
45
- });
46
- }
47
- """
48
- get_js_image = """
49
- async (input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta, image_file_live_opt) => {
50
- const canvasEl = document.getElementById("canvas-root");
51
- const imageData = canvasEl? canvasEl._data : null;
52
- if(image_file_live_opt === 'webcam'){
53
- input_image = imageData['image']
54
- landmark_direct_mode = true
55
- }
56
- return [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta, image_file_live_opt]
57
- }
58
- """
59
-
60
-
61
  def draw_landmarks(image, landmarks, color="white", radius=2.5):
62
  draw = ImageDraw.Draw(image)
63
  for dot in landmarks:
64
  x, y = dot
65
  draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color)
66
 
67
-
68
  def get_68landmarks_img(img):
69
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
70
  faces = detector(gray)
@@ -80,14 +45,7 @@ def get_68landmarks_img(img):
80
  con_img = np.array(con_img)
81
  return con_img
82
 
83
-
84
- def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta, image_file_live_opt="file"):
85
- input_image = input_image.convert('RGB')
86
- input_image = np.array(input_image)
87
- input_image = np.flip(input_image, axis=2)
88
- print('input_image.shape', input_image.shape)
89
- # Limit the number of samples to 2 for Spaces only
90
- num_samples = min(num_samples, 2)
91
  with torch.no_grad():
92
  img = resize_image(HWC3(input_image), image_resolution)
93
  H, W, C = img.shape
@@ -98,29 +56,25 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
98
  detected_map = get_68landmarks_img(img)
99
  detected_map = HWC3(detected_map)
100
 
101
- control = torch.from_numpy(
102
- detected_map.copy()).float().to(device) / 255.0
103
  control = torch.stack([control for _ in range(num_samples)], dim=0)
104
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
105
 
106
  if seed == -1:
107
- seed = random.randint(0, 2**32 - 1)
108
  seed_everything(seed)
109
 
110
  if config.save_memory:
111
  model.low_vram_shift(is_diffusing=False)
112
 
113
- cond = {"c_concat": [control], "c_crossattn": [
114
- model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
115
- un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
116
- model.get_learned_conditioning([n_prompt] * num_samples)]}
117
  shape = (4, H // 8, W // 8)
118
 
119
  if config.save_memory:
120
  model.low_vram_shift(is_diffusing=True)
121
 
122
- model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
123
- [strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
124
  samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
125
  shape, cond, verbose=False, eta=eta,
126
  unconditional_guidance_scale=scale,
@@ -130,82 +84,38 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
130
  model.low_vram_shift(is_diffusing=False)
131
 
132
  x_samples = model.decode_first_stage(samples)
133
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
134
- * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
135
 
136
  results = [x_samples[i] for i in range(num_samples)]
137
-
138
  return [255 - detected_map] + results
139
 
140
 
141
- def toggle(choice):
142
- if choice == "file":
143
- return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
144
- elif choice == "webcam":
145
- return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
146
-
147
-
148
  block = gr.Blocks().queue()
149
  with block:
150
- live_conditioning = gr.JSON(value={}, visible=False)
151
  with gr.Row():
152
  gr.Markdown("## Control Stable Diffusion with Face Landmarks")
153
  with gr.Row():
154
  with gr.Column():
155
- image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
156
- label="How would you like to upload your image?")
157
- input_image = gr.Image(source="upload", visible=True, type="pil")
158
- canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
159
-
160
- image_file_live_opt.change(fn=toggle,
161
- inputs=[image_file_live_opt],
162
- outputs=[input_image, canvas],
163
- queue=False)
164
-
165
  prompt = gr.Textbox(label="Prompt")
166
  run_button = gr.Button(label="Run")
167
  with gr.Accordion("Advanced options", open=False):
168
- num_samples = gr.Slider(
169
- label="Images", minimum=1, maximum=2, value=1, step=1)
170
- image_resolution = gr.Slider(
171
- label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
172
- strength = gr.Slider(
173
- label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
174
  guess_mode = gr.Checkbox(label='Guess Mode', value=False)
175
- landmark_direct_mode = gr.Checkbox(
176
- label='Input Landmark Directly', value=False)
177
- ddim_steps = gr.Slider(
178
- label="Steps", minimum=1, maximum=100, value=20, step=1)
179
- scale = gr.Slider(label="Guidance Scale",
180
- minimum=0.1, maximum=30.0, value=9.0, step=0.1)
181
- seed = gr.Slider(label="Seed", minimum=-1,
182
- maximum=2147483647, step=1, randomize=True)
183
  eta = gr.Number(label="eta (DDIM)", value=0.0)
184
- a_prompt = gr.Textbox(
185
- label="Added Prompt", value='best quality, extremely detailed')
186
  n_prompt = gr.Textbox(label="Negative Prompt",
187
  value='cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
188
  with gr.Column():
189
- result_gallery = gr.Gallery(
190
- label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
191
- ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution,
192
- ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta]
193
-
194
- gr.Examples(fn=process, examples=[
195
- ["examples/image0.jpg", "a silly clown face", "best quality, extremely detailed",
196
- "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
197
- ["examples/image1.png", "a photo of a woman wearing glasses", "best quality, extremely detailed",
198
- "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
199
- ["examples/image2.png", "a silly portrait of man with head tilted and a beautiful hair on the side", "best quality, extremely detailed",
200
- "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
201
- ["examples/image3.png", "portrait handsome men", "best quality, extremely detailed",
202
- "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
203
- ["examples/image4.jpg", "a beautiful woman looking at the sky", "best quality, extremely detailed",
204
- "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
205
- ], inputs=ips, outputs=[result_gallery], cache_examples=True)
206
- run_button.click(fn=process, inputs=ips + [image_file_live_opt],
207
- outputs=[result_gallery], _js=get_js_image)
208
- block.load(None, None, None, _js=load_js)
209
-
210
-
211
- block.launch()
 
16
  import dlib
17
  from PIL import Image, ImageDraw
18
 
 
 
 
 
 
19
  model = create_model('./models/cldm_v15.yaml').cpu()
20
+ model.load_state_dict(load_state_dict('./models/control_sd15_landmarks.pth', location='cuda'))
21
+ model = model.cuda()
 
22
  ddim_sampler = DDIMSampler(model)
23
 
24
  detector = dlib.get_frontal_face_detector()
25
  predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def draw_landmarks(image, landmarks, color="white", radius=2.5):
28
  draw = ImageDraw.Draw(image)
29
  for dot in landmarks:
30
  x, y = dot
31
  draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color)
32
 
 
33
  def get_68landmarks_img(img):
34
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
35
  faces = detector(gray)
 
45
  con_img = np.array(con_img)
46
  return con_img
47
 
48
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta):
 
 
 
 
 
 
 
49
  with torch.no_grad():
50
  img = resize_image(HWC3(input_image), image_resolution)
51
  H, W, C = img.shape
 
56
  detected_map = get_68landmarks_img(img)
57
  detected_map = HWC3(detected_map)
58
 
59
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
 
60
  control = torch.stack([control for _ in range(num_samples)], dim=0)
61
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
62
 
63
  if seed == -1:
64
+ seed = random.randint(0, 65535)
65
  seed_everything(seed)
66
 
67
  if config.save_memory:
68
  model.low_vram_shift(is_diffusing=False)
69
 
70
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
71
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
 
 
72
  shape = (4, H // 8, W // 8)
73
 
74
  if config.save_memory:
75
  model.low_vram_shift(is_diffusing=True)
76
 
77
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
 
78
  samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
79
  shape, cond, verbose=False, eta=eta,
80
  unconditional_guidance_scale=scale,
 
84
  model.low_vram_shift(is_diffusing=False)
85
 
86
  x_samples = model.decode_first_stage(samples)
87
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
 
88
 
89
  results = [x_samples[i] for i in range(num_samples)]
 
90
  return [255 - detected_map] + results
91
 
92
 
 
 
 
 
 
 
 
93
  block = gr.Blocks().queue()
94
  with block:
 
95
  with gr.Row():
96
  gr.Markdown("## Control Stable Diffusion with Face Landmarks")
97
  with gr.Row():
98
  with gr.Column():
99
+ input_image = gr.Image(source='upload', type="numpy")
 
 
 
 
 
 
 
 
 
100
  prompt = gr.Textbox(label="Prompt")
101
  run_button = gr.Button(label="Run")
102
  with gr.Accordion("Advanced options", open=False):
103
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
104
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
105
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
 
 
 
106
  guess_mode = gr.Checkbox(label='Guess Mode', value=False)
107
+ landmark_direct_mode = gr.Checkbox(label='Input Landmark Directly', value=False)
108
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
109
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
110
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
 
 
 
 
111
  eta = gr.Number(label="eta (DDIM)", value=0.0)
112
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
 
113
  n_prompt = gr.Textbox(label="Negative Prompt",
114
  value='cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
115
  with gr.Column():
116
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
117
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta]
118
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
119
+
120
+
121
+ block.launch(server_name='0.0.0.0')