gaparmar committed on
Commit a5f38fd
1 Parent(s): 13ed5cd
Files changed (3)
  1. app.py +1 -1
  2. src/model.py +46 -1
  3. src/pix2pix_turbo.py +3 -46
app.py CHANGED
@@ -238,7 +238,7 @@ with gr.Blocks(css="style.css") as demo:
     prompt_temp = gr.Textbox(label="Prompt Style Template", value=styles[DEFAULT_STYLE_NAME], scale=2, max_lines=1)

     with gr.Row():
-        val_r = gr.Slider(label="sketch guidance r: ", show_label=True, minimum=0, maximum=1, value=0.4, step=0.01, scale=3)
+        val_r = gr.Slider(label="Sketch guidance gamma: ", show_label=True, minimum=0, maximum=1, value=0.4, step=0.01, scale=3)
         seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
         randomize_seed = gr.Button("Random", scale=1, min_width=50)
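The only change here is the user-facing label: the slider still binds to val_r and its value feeds the model unchanged; the label now matches the gamma naming introduced in the decoder below. As a rough sketch of how a Gradio callback would typically forward that slider value (the run function, its signature, and the event wiring are illustrative assumptions, not part of this diff):

def run(image, prompt, val_r):
    # val_r arrives from the gr.Slider as a float in [0, 1] and is
    # forwarded as the sketch-guidance strength `r` (assumed signature).
    with torch.no_grad():
        return model(image, prompt, r=val_r)

# Hypothetical wiring:
# run_button.click(fn=run, inputs=[image, prompt, val_r], outputs=[result])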
 
src/model.py CHANGED
@@ -10,4 +10,49 @@ def make_1step_sched():
     noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
     noise_scheduler_1step.set_timesteps(1, device="cuda")
     noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
-    return noise_scheduler_1step
+    return noise_scheduler_1step
+
+
+"""The forward method of the `Encoder` class."""
+def my_vae_encoder_fwd(self, sample):
+    sample = self.conv_in(sample)
+    l_blocks = []
+    # down
+    for down_block in self.down_blocks:
+        l_blocks.append(sample)
+        sample = down_block(sample)
+    # middle
+    sample = self.mid_block(sample)
+    sample = self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    self.current_down_blocks = l_blocks
+    return sample
+
+
+"""The forward method of the `Decoder` class."""
+def my_vae_decoder_fwd(self, sample, latent_embeds=None):
+    sample = self.conv_in(sample)
+    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+    # middle
+    sample = self.mid_block(sample, latent_embeds)
+    sample = sample.to(upscale_dtype)
+    if not self.ignore_skip:
+        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
+        # up
+        for idx, up_block in enumerate(self.up_blocks):
+            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
+            # add skip
+            sample = sample + skip_in
+            sample = up_block(sample, latent_embeds)
+    else:
+        for idx, up_block in enumerate(self.up_blocks):
+            sample = up_block(sample, latent_embeds)
+    # post-process
+    if latent_embeds is None:
+        sample = self.conv_norm_out(sample)
+    else:
+        sample = self.conv_norm_out(sample, latent_embeds)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    return sample
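These two functions are written as unbound forwards (note the explicit self) so they can be grafted onto an existing diffusers AutoencoderKL. A minimal sketch of that attachment, assuming the standard SD-Turbo VAE layout; the __get__ binding, the 1x1 skip_conv_* projections, and the ignore_skip/gamma defaults are illustrative assumptions matching what the decoder forward expects, not part of this diff:

import torch
from diffusers import AutoencoderKL
from model import my_vae_encoder_fwd, my_vae_decoder_fwd

vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")

# Rebind forward on the instances so `self` resolves to the encoder/decoder.
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)

# Attributes the patched decoder reads. Channel sizes assume the SD VAE's
# (128, 256, 512, 512) down blocks; each 1x1 conv projects a cached encoder
# activation onto the matching up-block width.
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=1, bias=False)
vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=1, bias=False)
vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=1, bias=False)
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=1, bias=False)
vae.decoder.ignore_skip = False
vae.decoder.gamma = 1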
src/pix2pix_turbo.py CHANGED
@@ -11,52 +11,7 @@ from diffusers.utils.peft_utils import set_weights_and_activate_adapters
 from peft import LoraConfig
 p = "src/"
 sys.path.append(p)
-from model import make_1step_sched
-
-
-"""The forward method of the `Encoder` class."""
-def my_vae_encoder_fwd(self, sample):
-    sample = self.conv_in(sample)
-    l_blocks = []
-    # down
-    for down_block in self.down_blocks:
-        l_blocks.append(sample)
-        sample = down_block(sample)
-    # middle
-    sample = self.mid_block(sample)
-    sample = self.conv_norm_out(sample)
-    sample = self.conv_act(sample)
-    sample = self.conv_out(sample)
-    self.current_down_blocks = l_blocks
-    return sample
-
-
-"""The forward method of the `Decoder` class."""
-def my_vae_decoder_fwd(self, sample, latent_embeds=None):
-    sample = self.conv_in(sample)
-    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-    # middle
-    sample = self.mid_block(sample, latent_embeds)
-    sample = sample.to(upscale_dtype)
-    if not self.ignore_skip:
-        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
-        # up
-        for idx, up_block in enumerate(self.up_blocks):
-            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx])
-            # add skip
-            sample = sample + skip_in
-            sample = up_block(sample, latent_embeds)
-    else:
-        for idx, up_block in enumerate(self.up_blocks):
-            sample = up_block(sample, latent_embeds)
-    # post-process
-    if latent_embeds is None:
-        sample = self.conv_norm_out(sample)
-    else:
-        sample = self.conv_norm_out(sample, latent_embeds)
-    sample = self.conv_act(sample)
-    sample = self.conv_out(sample)
-    return sample
+from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd


 class TwinConv(torch.nn.Module):
@@ -151,6 +106,7 @@ class Pix2Pix_Turbo(torch.nn.Module):
         unet.eval()
         vae.eval()
         self.unet, self.vae = unet, vae
+        self.vae.decoder.gamma = 1
         self.timesteps = torch.tensor([999], device="cuda").long()

@@ -177,5 +133,6 @@ class Pix2Pix_Turbo(torch.nn.Module):
         self.unet.conv_in.r = None
         x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
         self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+        self.vae.decoder.gamma = r
         output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
         return output_image
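Net effect of the pix2pix_turbo.py changes: the shared VAE forwards move into src/model.py, __init__ gives the patched decoder a default of gamma = 1, and forward overwrites it with the caller's r right before decoding, so the encoder activations cached in current_down_blocks re-enter the decoder scaled by the slider value. A self-contained toy illustration of that scaling (shapes and names here are made up for the demo, not taken from the diff):

import torch

# Toy stand-ins for one decoder stage: a skip projection and a feature map.
skip_conv = torch.nn.Conv2d(8, 8, kernel_size=1, bias=False)
sample = torch.randn(1, 8, 16, 16)    # decoder feature map
skip_act = torch.randn(1, 8, 16, 16)  # cached encoder activation

for gamma in (0.0, 0.4, 1.0):
    # gamma scales the cached activation before the skip projection,
    # exactly as in my_vae_decoder_fwd above.
    fused = sample + skip_conv(skip_act * gamma)
    print(gamma, fused.sub(sample).abs().mean().item())

With gamma = 0 the bias-free skip projection outputs zeros, so the decoder ignores the sketch-derived features entirely; gamma = 1 injects them at full strength, and the slider default of 0.4 sits in between.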