RamAnanth1 commited on
Commit
1287f22
1 Parent(s): c6b395b

Update app.py

Browse files

Attempt at adding scribble checkpoint

Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -16,7 +16,7 @@ from huggingface_hub import hf_hub_url, cached_download
16
 
17
  REPO_ID = "lllyasviel/ControlNet"
18
  canny_checkpoint = "models/control_sd15_canny.pth"
19
- pose_checkpoint = "models/control_sd15_openpose.pth"
20
 
21
  canny_model = create_model('./models/cldm_v15.yaml')
22
  canny_model.load_state_dict(load_state_dict(cached_download(
@@ -26,14 +26,17 @@ canny_model = canny_model.cuda()
26
  ddim_sampler = DDIMSampler(canny_model)
27
 
28
 
29
- # pose_model = create_model('./models/cldm_v15.yaml')
30
- # pose_model.load_state_dict(load_state_dict(cached_download(
31
- # hf_hub_url(REPO_ID, pose_checkpoint)
32
- # ), location='cpu'))
33
- # ddim_sampler_pose = DDIMSampler(pose_model)
 
34
 
35
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
36
  # TODO: Add other control tasks
 
 
37
  return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
38
 
39
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
@@ -64,11 +67,40 @@ def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_re
64
  results = [x_samples[i] for i in range(num_samples)]
65
  return [255 - detected_map] + results
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  block = gr.Blocks().queue()
70
  control_task_list = [
71
- "Canny Edge Map"
 
72
  ]
73
  with block:
74
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
 
16
 
17
  REPO_ID = "lllyasviel/ControlNet"
18
  canny_checkpoint = "models/control_sd15_canny.pth"
19
+ scribble_checkpoint = "models/control_sd15_scribble.pth"
20
 
21
  canny_model = create_model('./models/cldm_v15.yaml')
22
  canny_model.load_state_dict(load_state_dict(cached_download(
 
26
  ddim_sampler = DDIMSampler(canny_model)
27
 
28
 
29
+ scribble_model = create_model('./models/cldm_v15.yaml')
30
+ sribbble_model.load_state_dict(load_state_dict(cached_download(
31
+ hf_hub_url(REPO_ID, scribble_checkpoint)
32
+ ), location='cpu'))
33
+ scribble_model = canny_model.cuda()
34
+ ddim_sampler_scribble = DDIMSampler(scribble_model)
35
 
36
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
37
  # TODO: Add other control tasks
38
+ if input_control == "Scribble":
39
+ return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
40
  return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
41
 
42
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
 
67
  results = [x_samples[i] for i in range(num_samples)]
68
  return [255 - detected_map] + results
69
 
70
+ def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
71
+ with torch.no_grad():
72
+ img = resize_image(HWC3(input_image), image_resolution)
73
+ H, W, C = img.shape
74
+
75
+ detected_map = np.zeros_like(img, dtype=np.uint8)
76
+ detected_map[np.min(img, axis=2) < 127] = 255
77
+
78
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
79
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
80
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
81
+
82
+ seed_everything(seed)
83
+
84
+ cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
85
+ un_cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([n_prompt] * num_samples)]}
86
+ shape = (4, H // 8, W // 8)
87
+
88
 
89
+ samples, intermediates = ddim_sampler_scribble.sample(ddim_steps, num_samples,
90
+ shape, cond, verbose=False, eta=eta,
91
+ unconditional_guidance_scale=scale,
92
+ unconditional_conditioning=un_cond)
93
+
94
+ x_samples = scribble_model.decode_first_stage(samples)
95
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
96
+
97
+ results = [x_samples[i] for i in range(num_samples)]
98
+ return [255 - detected_map] + results
99
 
100
  block = gr.Blocks().queue()
101
  control_task_list = [
102
+ "Canny Edge Map",
103
+ "Scribble"
104
  ]
105
  with block:
106
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")