RamAnanth1 committed on
Commit
3b1f734
1 Parent(s): efe3c52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -41
app.py CHANGED
@@ -23,22 +23,19 @@ canny_model = create_model('./models/cldm_v15.yaml')
23
  canny_model.load_state_dict(load_state_dict(cached_download(
24
  hf_hub_url(REPO_ID, canny_checkpoint)
25
  ), location='cpu'))
26
- ddim_sampler_canny = DDIMSampler(canny_model)
27
 
28
 
29
- pose_model = create_model('./models/cldm_v15.yaml')
30
- pose_model.load_state_dict(load_state_dict(cached_download(
31
- hf_hub_url(REPO_ID, pose_checkpoint)
32
- ), location='cpu'))
33
- ddim_sampler_pose = DDIMSampler(pose_model)
34
 
35
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
36
  # TODO: Add other control tasks
37
- if input_control == "Canny Edge Map":
38
- return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
39
- else:
40
- return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
41
-
42
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
43
  with torch.no_grad():
44
  img = resize_image(HWC3(input_image), image_resolution)
@@ -67,40 +64,11 @@ def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_re
67
  results = [x_samples[i] for i in range(num_samples)]
68
  return [255 - detected_map] + results
69
 
70
- def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
71
- with torch.no_grad():
72
- input_image = HWC3(input_image)
73
- detected_map, _ = apply_openpose(resize_image(input_image, image_resolution))
74
- detected_map = HWC3(detected_map)
75
- img = resize_image(input_image, image_resolution)
76
- H, W, C = img.shape
77
-
78
- detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
79
-
80
- control = torch.from_numpy(detected_map.copy()).float() / 255.0
81
- control = torch.stack([control for _ in range(num_samples)], dim=0)
82
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
83
 
84
- seed_everything(seed)
85
-
86
- cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
87
- un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
88
- shape = (4, H // 8, W // 8)
89
-
90
- samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
91
- shape, cond, verbose=False, eta=eta,
92
- unconditional_guidance_scale=scale,
93
- unconditional_conditioning=un_cond)
94
- x_samples = pose_model.decode_first_stage(samples)
95
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
96
-
97
- results = [x_samples[i] for i in range(num_samples)]
98
- return [detected_map] + results
99
 
100
  block = gr.Blocks().queue()
101
  control_task_list = [
102
- "Canny Edge Map",
103
- "Human Pose"
104
  ]
105
  with block:
106
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
 
23
  canny_model.load_state_dict(load_state_dict(cached_download(
24
  hf_hub_url(REPO_ID, canny_checkpoint)
25
  ), location='cpu'))
26
+ ddim_sampler = DDIMSampler(canny_model)
27
 
28
 
29
+ # pose_model = create_model('./models/cldm_v15.yaml')
30
+ # pose_model.load_state_dict(load_state_dict(cached_download(
31
+ # hf_hub_url(REPO_ID, pose_checkpoint)
32
+ # ), location='cpu'))
33
+ # ddim_sampler_pose = DDIMSampler(pose_model)
34
 
35
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
36
  # TODO: Add other control tasks
37
+ return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
38
+
 
 
 
39
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
40
  with torch.no_grad():
41
  img = resize_image(HWC3(input_image), image_resolution)
 
64
  results = [x_samples[i] for i in range(num_samples)]
65
  return [255 - detected_map] + results
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  block = gr.Blocks().queue()
70
  control_task_list = [
71
+ "Canny Edge Map"
 
72
  ]
73
  with block:
74
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")