Spaces:
Build error
Build error
Jiatao Gu
commited on
Commit
•
368dc9b
1
Parent(s):
22790a0
fix some errors. update code
Browse files- .gitignore +3 -0
- app.py +31 -35
- gradio_queue.db +0 -0
.gitignore
CHANGED
@@ -23,3 +23,6 @@ scripts/research/
|
|
23 |
.ipynb_checkpoints/
|
24 |
_screenshots/
|
25 |
flagged
|
|
|
|
|
|
|
|
23 |
.ipynb_checkpoints/
|
24 |
_screenshots/
|
25 |
flagged
|
26 |
+
|
27 |
+
*.db
|
28 |
+
gradio_queue.db
|
app.py
CHANGED
@@ -20,13 +20,20 @@ from huggingface_hub import hf_hub_download
|
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
def set_random_seed(seed):
|
25 |
torch.manual_seed(seed)
|
26 |
np.random.seed(seed)
|
27 |
|
28 |
|
29 |
-
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=
|
30 |
gen = model.synthesis
|
31 |
range_u, range_v = gen.C.range_u, gen.C.range_v
|
32 |
if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
|
@@ -41,22 +48,10 @@ def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512
|
|
41 |
return cam
|
42 |
|
43 |
|
44 |
-
def check_name(model_name
|
45 |
"""Gets model by name."""
|
46 |
-
if model_name
|
47 |
-
network_pkl = hf_hub_download(
|
48 |
-
|
49 |
-
# TODO: checkpoint to be updated!
|
50 |
-
# elif model_name == 'FFHQ512v2':
|
51 |
-
# network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
|
52 |
-
# elif model_name == 'AFHQ512':
|
53 |
-
# network_pkl = "./pretrained/afhq_512.pkl"
|
54 |
-
# elif model_name == 'MetFaces512':
|
55 |
-
# network_pkl = "./pretrained/metfaces_512.pkl"
|
56 |
-
# elif model_name == 'CompCars256':
|
57 |
-
# network_pkl = "./pretrained/cars_256.pkl"
|
58 |
-
# elif model_name == 'FFHQ1024':
|
59 |
-
# network_pkl = "./pretrained/ffhq_1024.pkl"
|
60 |
else:
|
61 |
if os.path.isdir(model_name):
|
62 |
network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
|
@@ -85,7 +80,7 @@ def get_model(network_pkl, render_option=None):
|
|
85 |
return G2, res, imgs
|
86 |
|
87 |
|
88 |
-
global_states = list(get_model(check_name()))
|
89 |
wss = [None, None]
|
90 |
|
91 |
def proc_seed(history, seed):
|
@@ -98,7 +93,8 @@ def proc_seed(history, seed):
|
|
98 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
99 |
history = history or {}
|
100 |
seeds = []
|
101 |
-
|
|
|
102 |
if model_find != "":
|
103 |
model_name = model_find
|
104 |
|
@@ -124,7 +120,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
124 |
set_random_seed(seed)
|
125 |
z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
|
126 |
ws = model.mapping(z=z, c=None, truncation_psi=trunc)
|
127 |
-
img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0), render_option=render_option)
|
128 |
ws = ws.detach().cpu().numpy()
|
129 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
130 |
|
@@ -178,26 +174,26 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
178 |
image = (image * 255).astype('uint8')
|
179 |
return image, history
|
180 |
|
181 |
-
model_name = gr.inputs.Dropdown(
|
182 |
-
model_find = gr.inputs.Textbox(label="
|
183 |
-
render_option = gr.inputs.Textbox(label="rendering options", default='steps:
|
184 |
-
trunc = gr.inputs.Slider(default=
|
185 |
-
seed1 = gr.inputs.Number(default=1, label="seed1")
|
186 |
-
seed2 = gr.inputs.Number(default=9, label="seed2")
|
187 |
-
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="
|
188 |
-
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="
|
189 |
-
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='
|
190 |
-
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="
|
191 |
-
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="
|
192 |
-
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="
|
193 |
-
fov = gr.inputs.Slider(minimum=
|
194 |
css = ".output-image, .input-image, .image-preview {height: 600px !important} "
|
195 |
|
196 |
gr.Interface(fn=f_synthesis,
|
197 |
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
|
198 |
-
title="
|
199 |
-
description="
|
200 |
outputs=["image", "state"],
|
201 |
layout='unaligned',
|
202 |
-
css=css, theme='dark-
|
203 |
live=True).launch(enable_queue=True)
|
|
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
|
22 |
|
23 |
+
model_lists = {
|
24 |
+
'ffhq-512x512-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl'),
|
25 |
+
'ffhq-256x256-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_256.pkl'),
|
26 |
+
'ffhq-1024x1024-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_1024.pkl'),
|
27 |
+
}
|
28 |
+
model_names = [name for name in model_lists]
|
29 |
+
|
30 |
|
31 |
def set_random_seed(seed):
|
32 |
torch.manual_seed(seed)
|
33 |
np.random.seed(seed)
|
34 |
|
35 |
|
36 |
+
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
|
37 |
gen = model.synthesis
|
38 |
range_u, range_v = gen.C.range_u, gen.C.range_v
|
39 |
if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
|
|
|
48 |
return cam
|
49 |
|
50 |
|
51 |
+
def check_name(model_name):
|
52 |
"""Gets model by name."""
|
53 |
+
if model_name in model_lists:
|
54 |
+
network_pkl = hf_hub_download(**model_lists[model_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
else:
|
56 |
if os.path.isdir(model_name):
|
57 |
network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
|
|
|
80 |
return G2, res, imgs
|
81 |
|
82 |
|
83 |
+
global_states = list(get_model(check_name(model_names[0])))
|
84 |
wss = [None, None]
|
85 |
|
86 |
def proc_seed(history, seed):
|
|
|
93 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
94 |
history = history or {}
|
95 |
seeds = []
|
96 |
+
trunc = trunc / 100
|
97 |
+
|
98 |
if model_find != "":
|
99 |
model_name = model_find
|
100 |
|
|
|
120 |
set_random_seed(seed)
|
121 |
z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
|
122 |
ws = model.mapping(z=z, c=None, truncation_psi=trunc)
|
123 |
+
img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name), render_option=render_option)
|
124 |
ws = ws.detach().cpu().numpy()
|
125 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
126 |
|
|
|
174 |
image = (image * 255).astype('uint8')
|
175 |
return image, history
|
176 |
|
177 |
+
model_name = gr.inputs.Dropdown(model_names)
|
178 |
+
model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
|
179 |
+
render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
|
180 |
+
trunc = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
|
181 |
+
seed1 = gr.inputs.Number(default=1, label="Random seed1")
|
182 |
+
seed2 = gr.inputs.Number(default=9, label="Random seed2")
|
183 |
+
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
|
184 |
+
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (apparence)")
|
185 |
+
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermedia output')
|
186 |
+
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
|
187 |
+
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
|
188 |
+
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional, not suggested for basic config)")
|
189 |
+
fov = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="Fov")
|
190 |
css = ".output-image, .input-image, .image-preview {height: 600px !important} "
|
191 |
|
192 |
gr.Interface(fn=f_synthesis,
|
193 |
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
|
194 |
+
title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
|
195 |
+
description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
|
196 |
outputs=["image", "state"],
|
197 |
layout='unaligned',
|
198 |
+
css=css, theme='dark-seafoam',
|
199 |
live=True).launch(enable_queue=True)
|
gradio_queue.db
CHANGED
Binary files a/gradio_queue.db and b/gradio_queue.db differ
|
|