hysts HF staff commited on
Commit
59bef24
1 Parent(s): 3744a88
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import functools
6
  import os
7
  import random
8
  import shlex
@@ -65,26 +64,25 @@ def load_model(device: torch.device) -> nn.Module:
65
  return model
66
 
67
 
68
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
69
- return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).to(device).float()
 
 
 
 
70
 
71
 
72
  @torch.inference_mode()
73
- def generate_image(
74
- seed: int, truncation_psi: float, randomize_noise: bool, model: nn.Module, device: torch.device
75
- ) -> np.ndarray:
76
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
77
 
78
- z = generate_z(model.style_dim, seed, device)
 
79
  out, _ = model([z], truncation=truncation_psi, truncation_latent=model.latent_avg, randomize_noise=randomize_noise)
80
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
81
  return out[0].cpu().numpy()
82
 
83
 
84
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85
- model = load_model(device)
86
- fn = functools.partial(generate_image, model=model, device=device)
87
-
88
  with gr.Blocks(css="style.css") as demo:
89
  gr.Markdown(DESCRIPTION)
90
  with gr.Row():
@@ -93,7 +91,7 @@ with gr.Blocks(css="style.css") as demo:
93
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
  psi = gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7)
95
  randomize_noise = gr.Checkbox(label="Randomize Noise", value=False)
96
- run_button = gr.Button("Run")
97
  with gr.Column():
98
  result = gr.Image(label="Output")
99
  gr.Markdown(ARTICLE)
@@ -105,9 +103,11 @@ with gr.Blocks(css="style.css") as demo:
105
  queue=False,
106
  api_name=False,
107
  ).then(
108
- fn=fn,
109
  inputs=[seed, psi, randomize_noise],
110
  outputs=result,
111
  api_name="run",
112
  )
113
- demo.queue(max_size=10).launch()
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import os
6
  import random
7
  import shlex
 
64
  return model
65
 
66
 
67
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68
+ model = load_model(device)
69
+
70
+
71
+ def generate_z(z_dim: int, seed: int) -> torch.Tensor:
72
+ return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).float()
73
 
74
 
75
  @torch.inference_mode()
76
+ def generate_image(seed: int, truncation_psi: float, randomize_noise: bool) -> np.ndarray:
 
 
77
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
78
 
79
+ z = generate_z(model.style_dim, seed)
80
+ z = z.to(device)
81
  out, _ = model([z], truncation=truncation_psi, truncation_latent=model.latent_avg, randomize_noise=randomize_noise)
82
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
83
  return out[0].cpu().numpy()
84
 
85
 
 
 
 
 
86
  with gr.Blocks(css="style.css") as demo:
87
  gr.Markdown(DESCRIPTION)
88
  with gr.Row():
 
91
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
92
  psi = gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7)
93
  randomize_noise = gr.Checkbox(label="Randomize Noise", value=False)
94
+ run_button = gr.Button()
95
  with gr.Column():
96
  result = gr.Image(label="Output")
97
  gr.Markdown(ARTICLE)
 
103
  queue=False,
104
  api_name=False,
105
  ).then(
106
+ fn=generate_image,
107
  inputs=[seed, psi, randomize_noise],
108
  outputs=result,
109
  api_name="run",
110
  )
111
+
112
+ if __name__ == "__main__":
113
+ demo.queue(max_size=10).launch()