RamAnanth1 commited on
Commit
d07326d
1 Parent(s): bb7ef0f

Update app.py

Browse files

Add new tab for interactive sketch

Files changed (1) hide show
  1. app.py +58 -4
app.py CHANGED
@@ -36,7 +36,9 @@ ddim_sampler_scribble = DDIMSampler(scribble_model)
36
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
37
  # TODO: Add other control tasks
38
  if input_control == "Scribble":
39
- return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
 
 
40
  return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
41
 
42
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
@@ -96,11 +98,51 @@ def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image
96
 
97
  results = [x_samples[i] for i in range(num_samples)]
98
  return [255 - detected_map] + results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  block = gr.Blocks().queue()
101
  control_task_list = [
102
  "Canny Edge Map",
103
- "Scribble"
 
104
  ]
105
  with block:
106
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
@@ -111,10 +153,22 @@ with block:
111
  ''')
112
  with gr.Row():
113
  with gr.Column():
114
- input_image = gr.Image(source='upload', type="numpy")
115
- input_control = gr.Dropdown(control_task_list, value="Canny Edge Map", label="Control Task")
 
 
 
 
 
 
 
 
 
 
 
116
  prompt = gr.Textbox(label="Prompt")
117
  run_button = gr.Button(label="Run")
 
118
  with gr.Accordion("Advanced options", open=False):
119
  num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
120
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
 
36
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
37
  # TODO: Add other control tasks
38
  if input_control == "Scribble":
39
+ return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
40
+ elif input_control == "Interactive Scribble":
41
+ return process_scribble_interactive(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
42
  return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
43
 
44
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
 
98
 
99
  results = [x_samples[i] for i in range(num_samples)]
100
  return [255 - detected_map] + results
101
+
102
+ def process_scribble_interactive(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
103
+ with torch.no_grad():
104
+ img = resize_image(HWC3(input_image['mask'][:, :, 0]), image_resolution)
105
+ H, W, C = img.shape
106
+
107
+ detected_map = np.zeros_like(img, dtype=np.uint8)
108
+ detected_map[np.min(img, axis=2) > 127] = 255
109
+
110
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
111
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
112
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
113
+
114
+ if seed == -1:
115
+ seed = random.randint(0, 65535)
116
+ seed_everything(seed)
117
+
118
+
119
+ cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
120
+ un_cond = {"c_concat": [control], "c_crossattn": [scribble.get_learned_conditioning([n_prompt] * num_samples)]}
121
+ shape = (4, H // 8, W // 8)
122
+
123
+
124
+ samples, intermediates = ddim_sampler_scribble.sample(ddim_steps, num_samples,
125
+ shape, cond, verbose=False, eta=eta,
126
+ unconditional_guidance_scale=scale,
127
+ unconditional_conditioning=un_cond)
128
+
129
+ x_samples = scribble_model.decode_first_stage(samples)
130
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
131
+
132
+ results = [x_samples[i] for i in range(num_samples)]
133
+ return [255 - detected_map] + results
134
+
135
+
136
+ def create_canvas(w, h):
137
+ new_control_options = ["Interactive Scribble"]
138
+ return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
139
+
140
 
141
  block = gr.Blocks().queue()
142
  control_task_list = [
143
  "Canny Edge Map",
144
+ "Scribble",
145
+ "Interactive Scribble"
146
  ]
147
  with block:
148
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
 
153
  ''')
154
  with gr.Row():
155
  with gr.Column():
156
+ with gr.Tab("Upload"):
157
+ input_image = gr.Image(source='upload', type="numpy")
158
+
159
+ with gr.Tab("Interactive Scribble"):
160
+ canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=1)
161
+ canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=1)
162
+ create_button = gr.Button(label="Start", value='Open drawing canvas!')
163
+ input_image = gr.Image(source='upload', type='numpy', tool='sketch')
164
+ gr.Markdown(value='Do not forget to change your brush width to make it thinner. (Gradio do not allow developers to set brush width so you need to do it manually.) '
165
+ 'Just click on the small pencil icon in the upper right corner of the above block.')
166
+ create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
167
+
168
+ input_control = gr.Dropdown(control_task_list, value="Scribble", label="Control Task")
169
  prompt = gr.Textbox(label="Prompt")
170
  run_button = gr.Button(label="Run")
171
+
172
  with gr.Accordion("Advanced options", open=False):
173
  num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
174
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)