radames committed
Commit c57f7ae
1 Parent(s): 46aba97

optional - Use retinaface for face detection

Uses https://pypi.org/project/retinaface-py/, by the same author as SPIGA.

Files changed (2):
  1. app.py +30 -19
  2. requirements.txt +1 -0
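For quick reference, here is the new detection path from the patch, pulled out as a standalone sketch. Every call below (`RetinaFaceDetector`, `set_input_shape`, `inference`, SPIGA's `cfg_retinasort` defaults) appears in the diff; only the input file name `face.jpg` is a hypothetical placeholder:

```python
import torch
import retinaface
import spiga.demo.analyze.track.retinasort.config as cfg
from PIL import Image

# Detector built from SPIGA's retinasort defaults, as in the diff below.
config = cfg.cfg_retinasort
detector = retinaface.RetinaFaceDetector(model=config['retina']['model_name'],
                                         device='cuda' if torch.cuda.is_available() else 'cpu',
                                         extra_features=config['retina']['extra_features'],
                                         cfg_postreat=config['retina']['postreat'])

image = Image.open("face.jpg")  # hypothetical input file
# The detector wants the input shape up front: (height, width).
detector.set_input_shape(image.size[1], image.size[0])
features = detector.inference(image)

if features is None or len(features['bbox']) == 0:
    raise Exception("No face detected")

# RetinaFace reports corner boxes (x1, y1, x2, y2); the app converts them
# to (x, y, w, h) so the downstream SPIGA code stays unchanged.
x1, y1, x2, y2 = features['bbox'][0][:4]
bbox_wh = [x1, y1, x2 - x1, y2 - y1]
```

The practical difference from the dlib detector is the box format: dlib rectangles expose width and height directly, while RetinaFace returns corner coordinates, hence the conversion at the end.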
app.py CHANGED
@@ -1,27 +1,30 @@
 import gradio as gr
 import torch
-import dlib
 import numpy as np
 import PIL
 import base64
 from io import BytesIO
 from PIL import Image
-# Only used to convert to gray, could do it differently and remove this big dependency
-import cv2
+# import for face detection
+import retinaface

 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 from diffusers import UniPCMultistepScheduler

 from spiga.inference.config import ModelConfig
 from spiga.inference.framework import SPIGAFramework
+import spiga.demo.analyze.track.retinasort.config as cfg

 import matplotlib.pyplot as plt
 from matplotlib.path import Path
 import matplotlib.patches as patches

 # Bounding boxes
-face_detector = dlib.get_frontal_face_detector()
-
+config = cfg.cfg_retinasort
+face_detector = retinaface.RetinaFaceDetector(model=config['retina']['model_name'],
+                                              device='cuda' if torch.cuda.is_available() else 'cpu',
+                                              extra_features=config['retina']['extra_features'],
+                                              cfg_postreat=config['retina']['postreat'])
 # Landmark extraction
 spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))

@@ -59,14 +62,19 @@ async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
 }
 """

+
 def get_bounding_box(image):
-    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-    faces = face_detector(gray)
-    if len(faces) == 0:
-        raise Exception("No face detected in image")
-    face = faces[0]
-    bbox = [face.left(), face.top(), face.width(), face.height()]
-    return bbox
+    pil_image = Image.fromarray(image)
+    face_detector.set_input_shape(pil_image.size[1], pil_image.size[0])
+    features = face_detector.inference(pil_image)
+
+    if features is None or len(features['bbox']) == 0:
+        raise Exception("No face detected")
+    # get the first face detected
+    bbox = features['bbox'][0]
+    x1, y1, x2, y2 = bbox[:4]
+    bbox_wh = [x1, y1, x2 - x1, y2 - y1]
+    return bbox_wh


 def get_landmarks(image, bbox):
@@ -143,9 +151,9 @@ def get_conditioning(image):


 def generate_images(image_in_img, prompt, image_file_live_opt='file', live_conditioning=None):
+    print(image_in_img, prompt, image_file_live_opt, live_conditioning)
     if image_in_img is None and 'image' not in live_conditioning:
         raise gr.Error("Please provide an image")
-
     try:
         if image_file_live_opt == 'file':
             conditioning = get_conditioning(image_in_img)
@@ -166,29 +174,31 @@ def generate_images(image_in_img, prompt, image_file_live_opt='file', live_condi
     except Exception as e:
         raise gr.Error(str(e))

+
 def toggle(choice):
     if choice == "file":
         return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
     elif choice == "webcam":
         return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)

+
 with gr.Blocks() as blocks:
     gr.Markdown("""
     ## Generate Uncanny Faces with ControlNet Stable Diffusion
     [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
     """)
     with gr.Row():
-        live_conditioning = gr.JSON(value={}, visible=False)
+        live_conditioning = gr.JSON(value={}, visible=False)
         with gr.Column():
             image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
-                label="How would you like to upload your image?")
+                                           label="How would you like to upload your image?")
             image_in_img = gr.Image(source="upload", visible=True, type="pil")
             canvas = gr.HTML(None, elem_id="canvas_html", visible=False)

             image_file_live_opt.change(fn=toggle,
-                inputs=[image_file_live_opt],
-                outputs=[image_in_img, canvas],
-                queue=False)
+                                       inputs=[image_file_live_opt],
+                                       outputs=[image_in_img, canvas],
+                                       queue=False)
             prompt = gr.Textbox(
                 label="Enter your prompt",
                 max_lines=1,
@@ -198,7 +208,8 @@ with gr.Blocks() as blocks:
         with gr.Column():
             gallery = gr.Gallery().style(grid=[2], height="auto")
     run_button.click(fn=generate_images,
-                     inputs=[image_in_img, prompt, image_file_live_opt, live_conditioning],
+                     inputs=[image_in_img, prompt,
+                             image_file_live_opt, live_conditioning],
                      outputs=[gallery],
                      _js=get_js_image)
     blocks.load(None, None, None, _js=load_js)
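For context on where the (x, y, w, h) box ends up: it is passed to SPIGA's landmark extractor via the get_landmarks function, which this commit leaves unchanged. A minimal sketch of that hand-off, assuming the SPIGAFramework.inference(image, bboxes) usage shown in SPIGA's own examples; the function name here is hypothetical:

```python
import numpy as np

def get_landmarks_sketch(image: np.ndarray, bbox_wh):
    # spiga_extractor is the SPIGAFramework instance created at the top of
    # app.py; the inference(image, bboxes) signature is assumed from
    # SPIGA's documented usage and returns a dict with a 'landmarks' key.
    features = spiga_extractor.inference(image, [bbox_wh])
    # 300W models predict 68 (x, y) landmark points per face
    return np.array(features['landmarks'][0])
```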
requirements.txt CHANGED
@@ -7,3 +7,4 @@ dlib
 opencv-python
 matplotlib
 Pillow
+retinaface-py>=0.0.2