[Minor] Use a generator function to generate a list
app.py
CHANGED
@@ -273,7 +273,6 @@ def generate(
             m_img.astype('float') / 2.0 * red).astype('uint8'))
 
 
-
     mask_video_path = "mask.mp4"
     fps = 30
     with imageio.get_writer(mask_video_path, fps=fps) as video:
@@ -282,7 +281,45 @@ def generate(
 
     return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
 
-
+
+def single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width):
+    model.cuda()
+    with torch.no_grad(), autocast("cuda"), model.ema_scope():
+        cond = {}
+        input_image_torch = 2 * torch.tensor(np.array(input_image_copy.to(model.device))).float() / 255 - 1
+        input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
+        cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
+        cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
+
+        uncond = {}
+        uncond["c_crossattn"] = [null_token.to(model.device)]
+        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
+
+        sigmas = model_wrap.get_sigmas(steps).to(model.device)
+
+        extra_args = {
+            "cond": cond,
+            "uncond": uncond,
+            "text_cfg_scale": text_cfg_scale,
+            "image_cfg_scale": image_cfg_scale,
+        }
+        torch.manual_seed(seed)
+        z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
+        z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
+
+        z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
+
+        x_0 = model.decode_first_stage(z_0)
+
+        x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
+        x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
+
+        x_1_mean = torch.sum(x_1).item()/x_1.numel()
+
+        return x_0, x_1, x_1_mean
+
+
+@spaces.GPU(duration=150)
 def generate_list(
     input_image: Image.Image,
     generate_list: str,
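The new `single_generation` helper bundles one complete edit attempt: build the text/image conditioning, run `sample_euler_ancestral`, decode the edited image `x_0`, and threshold the predicted mask `x_1`. The `@spaces.GPU(duration=150)` decorator requests a ZeroGPU worker for up to 150 seconds per call on Hugging Face Spaces. A minimal sketch of how the helper would slot into the retry loop; the wiring mirrors the commented-out call in the next hunk, but is illustrative rather than the committed code:

    x_0, x_1, x_1_mean = single_generation(
        model_wrap_cfg, input_image_copy, instruction,
        steps, seed, text_cfg_scale, image_cfg_scale, height, width,
    )
    if x_1_mean < -0.99:  # mask is (almost) all -1, i.e. empty: retry with a new seed
        seed += 1
        retry_number += 1

One caveat: `input_image_copy.to(model.device)` inside the helper only works if `input_image_copy` is already a tensor; a `PIL.Image` has no `.to` method, which is presumably why the inline version kept in `generate_list` drops that call.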
@@ -322,9 +359,11 @@ def generate_list(
     while generate_index < len(generate_list):
         print(f'generate_index: {str(generate_index)}')
         instruction = generate_list[generate_index]
+
+        # x_0, x_1, x_1_mean = single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width)
         with torch.no_grad(), autocast("cuda"), model.ema_scope():
             cond = {}
-            input_image_torch = 2 * torch.tensor(np.array(input_image_copy.to(model.device))).float() / 255 - 1
+            input_image_torch = 2 * torch.tensor(np.array(input_image_copy)).float() / 255 - 1
             input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
             cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
             cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
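The changed line rescales the uint8 RGB array from [0, 255] to the [-1, 1] range the first-stage encoder expects, and drops the `.to(model.device)` call, which would fail on a PIL image. A quick, illustrative check of the mapping:

    import numpy as np
    import torch

    px = torch.tensor(np.array([0, 128, 255])).float()
    print(2 * px / 255 - 1)  # tensor([-1.0000,  0.0039,  1.0000])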
@@ -351,8 +390,10 @@ def generate_list(
 
             x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
             x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
+
+            x_1_mean = torch.sum(x_1).item()/x_1.numel()
 
-            if torch.sum(x_1).item()/x_1.numel() < -0.99:
+            if x_1_mean < -0.99:
                 seed += 1
                 retry_number +=1
                 if retry_number > max_retry:
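Because `x_1` contains only +1/-1 after thresholding, its mean equals 2p - 1 for a fraction p of mask-positive pixels, so `x_1_mean < -0.99` fires only when under roughly 0.5% of pixels are selected, i.e. the predicted mask is essentially empty and the attempt should be retried with a new seed. A toy illustration (not from the commit):

    import torch

    x_1 = -torch.ones(100)  # fully empty mask: mean is exactly -1.0
    x_1[:2] = 1             # two positive pixels out of 100
    print(torch.sum(x_1).item() / x_1.numel())  # -0.96, above the -0.99 cutoff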
@@ -384,20 +425,22 @@ def generate_list(
 
         image_video.append((mix_image_np * 255).astype(np.uint8))
         mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
-
-
-
-
-
+
+        mix_result_with_red_mask = None
+        mask_video_path = None
+        image_video_path = None
+        edited_mask_copy = None
+
+        if generate_index == len(generate_list):
+            image_video_path = "image.mp4"
+            fps = 2
+            with imageio.get_writer(image_video_path, fps=fps) as video:
+                for image in image_video:
+                    video.append_data(image)
 
-
-        fps = 2
-        with imageio.get_writer(image_video_path, fps=fps) as video:
-            for image in image_video:
-                video.append_data(image)
+        yield [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
 
-
-        return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
+        input_image_copy = mix_image
 
 
 def reset():
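Swapping `return` for `yield` is the point of the commit: `generate_list` becomes a generator, so the UI receives the intermediate result after every instruction instead of a single payload at the end, and `input_image_copy = mix_image` chains each edit into the next iteration. Gradio streams every yielded value to the output components when the queue is enabled, as in this minimal sketch (assumed names, not taken from app.py):

    import time
    import gradio as gr

    def stepwise_edit(n_steps: float):
        # Stand-in for the per-instruction editing loop: yield a progress
        # update after each step instead of returning once at the end.
        for i in range(int(n_steps)):
            time.sleep(0.5)
            yield f"finished step {i + 1} of {int(n_steps)}"

    with gr.Blocks() as demo:
        steps = gr.Number(value=3, label="steps")
        status = gr.Textbox(label="progress")
        gr.Button("run").click(stepwise_edit, inputs=steps, outputs=status)

    demo.queue().launch()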
@@ -553,4 +596,5 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
 # demo.launch(share=True)
 
 
+# demo.queue().launch(enable_queue=True)
 demo.queue().launch()
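Generator handlers only stream through the queue, which `demo.queue()` already enables; the commented-out `enable_queue=True` variant looks like a leftover from older Gradio releases, where queueing was toggled via a `launch()` keyword argument that newer versions have removed.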