RamAnanth1 commited on
Commit
a5994ff
1 Parent(s): e82f9dc

Attempt using safetensors for lightweight memory

Browse files
Files changed (1) hide show
  1. app.py +40 -8
app.py CHANGED
@@ -15,10 +15,15 @@ from cldm.model import create_model, load_state_dict
15
 
16
  from huggingface_hub import hf_hub_url, cached_download
17
 
18
- REPO_ID = "lllyasviel/ControlNet"
19
- canny_checkpoint = "models/control_sd15_canny.pth"
20
- scribble_checkpoint = "models/control_sd15_scribble.pth"
21
- pose_checkpoint = "models/control_sd15_openpose.pth"
 
 
 
 
 
22
 
23
  canny_model = create_model('./models/cldm_v15.yaml').cpu()
24
  canny_model.load_state_dict(load_state_dict(cached_download(
@@ -30,7 +35,7 @@ ddim_sampler = DDIMSampler(canny_model)
30
  pose_model = create_model('./models/cldm_v15.yaml').cpu()
31
  pose_model.load_state_dict(load_state_dict(cached_download(
32
  hf_hub_url(REPO_ID, pose_checkpoint)
33
- ), location='cuda'))
34
  pose_model = pose_model.cuda()
35
  ddim_sampler_pose = DDIMSampler(pose_model)
36
 
@@ -41,6 +46,8 @@ scribble_model.load_state_dict(load_state_dict(cached_download(
41
  scribble_model = canny_model.cuda()
42
  ddim_sampler_scribble = DDIMSampler(scribble_model)
43
 
 
 
44
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
45
  # TODO: Add other control tasks
46
  if input_control == "Scribble":
@@ -64,14 +71,24 @@ def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_re
64
 
65
  seed_everything(seed)
66
 
 
 
 
67
  cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
68
  un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
69
  shape = (4, H // 8, W // 8)
70
 
 
 
 
71
  samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
72
  shape, cond, verbose=False, eta=eta,
73
  unconditional_guidance_scale=scale,
74
  unconditional_conditioning=un_cond)
 
 
 
 
75
  x_samples = canny_model.decode_first_stage(samples)
76
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
77
 
@@ -92,16 +109,24 @@ def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image
92
 
93
  seed_everything(seed)
94
 
 
 
 
95
  cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
96
  un_cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([n_prompt] * num_samples)]}
97
  shape = (4, H // 8, W // 8)
98
 
99
-
 
 
100
  samples, intermediates = ddim_sampler_scribble.sample(ddim_steps, num_samples,
101
  shape, cond, verbose=False, eta=eta,
102
  unconditional_guidance_scale=scale,
103
  unconditional_conditioning=un_cond)
104
 
 
 
 
105
  x_samples = scribble_model.decode_first_stage(samples)
106
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
107
 
@@ -126,18 +151,25 @@ def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_res
126
  seed = random.randint(0, 65535)
127
  seed_everything(seed)
128
 
 
 
129
 
 
130
  cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
131
  un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
132
  shape = (4, H // 8, W // 8)
133
 
134
-
 
 
135
  samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
136
  shape, cond, verbose=False, eta=eta,
137
  unconditional_guidance_scale=scale,
138
  unconditional_conditioning=un_cond)
139
 
140
-
 
 
141
  x_samples = pose_model.decode_first_stage(samples)
142
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
143
 
 
15
 
16
  from huggingface_hub import hf_hub_url, cached_download
17
 
18
+ # REPO_ID = "lllyasviel/ControlNet"
19
+ # canny_checkpoint = "models/control_sd15_canny.pth"
20
+ # scribble_checkpoint = "models/control_sd15_scribble.pth"
21
+ # pose_checkpoint = "models/control_sd15_openpose.pth"
22
+
23
+ REPO_ID = "webui/ControlNet-modules-safetensors"
24
+ canny_checkpoint = " control_canny-fp16.safetensors"
25
+ scribble_checkpoint = "control_scribble-fp16.safetensors"
26
+ pose_checkpoint = "control_openpose-fp16.safetensors"
27
 
28
  canny_model = create_model('./models/cldm_v15.yaml').cpu()
29
  canny_model.load_state_dict(load_state_dict(cached_download(
 
35
  pose_model = create_model('./models/cldm_v15.yaml').cpu()
36
  pose_model.load_state_dict(load_state_dict(cached_download(
37
  hf_hub_url(REPO_ID, pose_checkpoint)
38
+ ), location='cpu'))
39
  pose_model = pose_model.cuda()
40
  ddim_sampler_pose = DDIMSampler(pose_model)
41
 
 
46
  scribble_model = canny_model.cuda()
47
  ddim_sampler_scribble = DDIMSampler(scribble_model)
48
 
49
+ save_memory = False
50
+
51
  def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
52
  # TODO: Add other control tasks
53
  if input_control == "Scribble":
 
71
 
72
  seed_everything(seed)
73
 
74
+ if save_memory:
75
+ canny_model.low_vram_shift(is_diffusing=False)
76
+
77
  cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
78
  un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
79
  shape = (4, H // 8, W // 8)
80
 
81
+ if save_memory:
82
+ canny_model.low_vram_shift(is_diffusing=False)
83
+
84
  samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
85
  shape, cond, verbose=False, eta=eta,
86
  unconditional_guidance_scale=scale,
87
  unconditional_conditioning=un_cond)
88
+
89
+ if save_memory:
90
+ canny_model.low_vram_shift(is_diffusing=False)
91
+
92
  x_samples = canny_model.decode_first_stage(samples)
93
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
94
 
 
109
 
110
  seed_everything(seed)
111
 
112
+ if save_memory:
113
+ scribble_model.low_vram_shift(is_diffusing=False)
114
+
115
  cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
116
  un_cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([n_prompt] * num_samples)]}
117
  shape = (4, H // 8, W // 8)
118
 
119
+ if save_memory:
120
+ scribble_model.low_vram_shift(is_diffusing=False)
121
+
122
  samples, intermediates = ddim_sampler_scribble.sample(ddim_steps, num_samples,
123
  shape, cond, verbose=False, eta=eta,
124
  unconditional_guidance_scale=scale,
125
  unconditional_conditioning=un_cond)
126
 
127
+ if save_memory:
128
+ scribble_model.low_vram_shift(is_diffusing=False)
129
+
130
  x_samples = scribble_model.decode_first_stage(samples)
131
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
132
 
 
151
  seed = random.randint(0, 65535)
152
  seed_everything(seed)
153
 
154
+ if save_memory:
155
+ pose_model.low_vram_shift(is_diffusing=False)
156
 
157
+
158
  cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
159
  un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
160
  shape = (4, H // 8, W // 8)
161
 
162
+ if save_memory:
163
+ pose_model.low_vram_shift(is_diffusing=False)
164
+
165
  samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
166
  shape, cond, verbose=False, eta=eta,
167
  unconditional_guidance_scale=scale,
168
  unconditional_conditioning=un_cond)
169
 
170
+ if save_memory:
171
+ pose_model.low_vram_shift(is_diffusing=False)
172
+
173
  x_samples = pose_model.decode_first_stage(samples)
174
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
175