RamAnanth1 committed on
Commit
efe3c52
1 Parent(s): ad64d06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -14,19 +14,30 @@ from annotator.openpose import apply_openpose
14
  from cldm.model import create_model, load_state_dict
15
 
16
  from huggingface_hub import hf_hub_url, cached_download
 
17
  REPO_ID = "lllyasviel/ControlNet"
18
- FILENAME = "models/control_sd15_canny.pth"
 
19
 
20
- model = create_model('./models/cldm_v15.yaml')
21
- model.load_state_dict(load_state_dict(cached_download(
22
- hf_hub_url(REPO_ID, FILENAME)
23
  ), location='cpu'))
24
- ddim_sampler = DDIMSampler(model)
 
 
 
 
 
 
 
25
 
26
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
27
  # TODO: Add other control tasks
28
- return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
29
-
 
 
30
 
31
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
32
  with torch.no_grad():
@@ -42,24 +53,24 @@ def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_re
42
 
43
  seed_everything(seed)
44
 
45
- cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
46
- un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
47
  shape = (4, H // 8, W // 8)
48
 
49
- samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
50
  shape, cond, verbose=False, eta=eta,
51
  unconditional_guidance_scale=scale,
52
  unconditional_conditioning=un_cond)
53
- x_samples = model.decode_first_stage(samples)
54
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
55
 
56
  results = [x_samples[i] for i in range(num_samples)]
57
  return [255 - detected_map] + results
58
 
59
- def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta):
60
  with torch.no_grad():
61
  input_image = HWC3(input_image)
62
- detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
63
  detected_map = HWC3(detected_map)
64
  img = resize_image(input_image, image_resolution)
65
  H, W, C = img.shape
@@ -72,15 +83,15 @@ def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_res
72
 
73
  seed_everything(seed)
74
 
75
- cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
76
- un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
77
  shape = (4, H // 8, W // 8)
78
 
79
- samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
80
  shape, cond, verbose=False, eta=eta,
81
  unconditional_guidance_scale=scale,
82
  unconditional_conditioning=un_cond)
83
- x_samples = model.decode_first_stage(samples)
84
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
85
 
86
  results = [x_samples[i] for i in range(num_samples)]
 
14
  from cldm.model import create_model, load_state_dict
15
 
16
  from huggingface_hub import hf_hub_url, cached_download
17
+
18
  REPO_ID = "lllyasviel/ControlNet"
19
+ canny_checkpoint = "models/control_sd15_canny.pth"
20
+ pose_checkpoint = "models/control_sd15_openpose.pth"
21
 
22
+ canny_model = create_model('./models/cldm_v15.yaml')
23
+ canny_model.load_state_dict(load_state_dict(cached_download(
24
+ hf_hub_url(REPO_ID, canny_checkpoint)
25
  ), location='cpu'))
26
+ ddim_sampler_canny = DDIMSampler(canny_model)
27
+
28
+
29
+ pose_model = create_model('./models/cldm_v15.yaml')
30
+ pose_model.load_state_dict(load_state_dict(cached_download(
31
+ hf_hub_url(REPO_ID, pose_checkpoint)
32
+ ), location='cpu'))
33
+ ddim_sampler_pose = DDIMSampler(pose_model)
34
 
35
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
36
  # TODO: Add other control tasks
37
+ if input_control == "Canny Edge Map":
38
+ return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
39
+ else:
40
+ return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
41
 
42
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
43
  with torch.no_grad():
 
53
 
54
  seed_everything(seed)
55
 
56
+ cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
57
+ un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
58
  shape = (4, H // 8, W // 8)
59
 
60
+ samples, intermediates = ddim_sampler_canny.sample(ddim_steps, num_samples,
61
  shape, cond, verbose=False, eta=eta,
62
  unconditional_guidance_scale=scale,
63
  unconditional_conditioning=un_cond)
64
+ x_samples = canny_model.decode_first_stage(samples)
65
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
66
 
67
  results = [x_samples[i] for i in range(num_samples)]
68
  return [255 - detected_map] + results
69
 
70
+ def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
71
  with torch.no_grad():
72
  input_image = HWC3(input_image)
73
+ detected_map, _ = apply_openpose(resize_image(input_image, image_resolution))
74
  detected_map = HWC3(detected_map)
75
  img = resize_image(input_image, image_resolution)
76
  H, W, C = img.shape
 
83
 
84
  seed_everything(seed)
85
 
86
+ cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
87
+ un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
88
  shape = (4, H // 8, W // 8)
89
 
90
+ samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
91
  shape, cond, verbose=False, eta=eta,
92
  unconditional_guidance_scale=scale,
93
  unconditional_conditioning=un_cond)
94
+ x_samples = pose_model.decode_first_stage(samples)
95
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
96
 
97
  results = [x_samples[i] for i in range(num_samples)]