Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -227,87 +227,89 @@ def marching_cube(b, text, global_info):
|
|
227 |
return path
|
228 |
|
229 |
def infer(prompt, samples, steps, scale, seed, global_info):
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
decode_res = model.decode_first_stage(sample)
|
248 |
-
|
249 |
-
big_video_list = []
|
250 |
-
|
251 |
-
global_info['decode_res'] = decode_res
|
252 |
-
|
253 |
-
for b in range(batch_size):
|
254 |
-
def render_img(v):
|
255 |
-
rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
|
256 |
-
decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
|
257 |
-
)
|
258 |
-
rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
|
259 |
-
rgb_sample = np.stack(
|
260 |
-
[rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
|
261 |
-
)
|
262 |
-
rgb_sample = add_text(rgb_sample, str(b))
|
263 |
-
return rgb_sample
|
264 |
-
|
265 |
-
view_num = len(batch_rays_list)
|
266 |
-
video_list = []
|
267 |
-
for v in tqdm.tqdm(range(view_num//8*3, view_num//8*5, 2)):
|
268 |
-
rgb_sample = render_img(v)
|
269 |
-
video_list.append(rgb_sample)
|
270 |
-
big_video_list.append(video_list)
|
271 |
-
# if batch_size == 2:
|
272 |
-
# cat_video_list = [
|
273 |
-
# np.concatenate([big_video_list[j][i] for j in range(len(big_video_list))], 1) \
|
274 |
-
# for i in range(len(big_video_list[0]))
|
275 |
-
# ]
|
276 |
-
# elif batch_size > 2:
|
277 |
-
# if batch_size == 3:
|
278 |
-
# big_video_list.append(
|
279 |
-
# [np.zeros_like(f) for f in big_video_list[0]]
|
280 |
-
# )
|
281 |
-
# cat_video_list = [
|
282 |
-
# np.concatenate([
|
283 |
-
# np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
|
284 |
-
# np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
|
285 |
-
# ], 0) \
|
286 |
-
# for i in range(len(big_video_list[0]))
|
287 |
-
# ]
|
288 |
-
# else:
|
289 |
-
# cat_video_list = big_video_list[0]
|
290 |
-
|
291 |
-
for _ in range(4 - batch_size):
|
292 |
-
big_video_list.append(
|
293 |
-
[np.zeros_like(f) + 255 for f in big_video_list[0]]
|
294 |
)
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
]
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
return global_info, path
|
307 |
|
308 |
def infer_stage2(prompt, selection, seed, global_info, iters):
|
309 |
prompt = prompt.replace('/', '')
|
310 |
-
|
|
|
311 |
mesh_name = mesh_path.split('/')[-1][:-4]
|
312 |
# if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
|
313 |
# print(if2_cmd)
|
|
|
227 |
return path
|
228 |
|
229 |
def infer(prompt, samples, steps, scale, seed, global_info):
|
230 |
+
with torch.cuda.device(1):
|
231 |
+
prompt = prompt.replace('/', '')
|
232 |
+
pl.seed_everything(seed)
|
233 |
+
batch_size = samples
|
234 |
+
with torch.no_grad():
|
235 |
+
noise = None
|
236 |
+
c = model.get_learned_conditioning([prompt])
|
237 |
+
unconditional_c = torch.zeros_like(c)
|
238 |
+
sample, _ = sampler.sample(
|
239 |
+
S=steps,
|
240 |
+
batch_size=batch_size,
|
241 |
+
shape=shape,
|
242 |
+
verbose=False,
|
243 |
+
x_T = noise,
|
244 |
+
conditioning = c.repeat(batch_size, 1, 1),
|
245 |
+
unconditional_guidance_scale=scale,
|
246 |
+
unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
)
|
248 |
+
decode_res = model.decode_first_stage(sample)
|
249 |
+
|
250 |
+
big_video_list = []
|
251 |
+
|
252 |
+
global_info['decode_res'] = decode_res
|
253 |
+
|
254 |
+
for b in range(batch_size):
|
255 |
+
def render_img(v):
|
256 |
+
rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
|
257 |
+
decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
|
258 |
+
)
|
259 |
+
rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
|
260 |
+
rgb_sample = np.stack(
|
261 |
+
[rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
|
262 |
+
)
|
263 |
+
rgb_sample = add_text(rgb_sample, str(b))
|
264 |
+
return rgb_sample
|
265 |
+
|
266 |
+
view_num = len(batch_rays_list)
|
267 |
+
video_list = []
|
268 |
+
for v in tqdm.tqdm(range(view_num//8*3, view_num//8*5, 2)):
|
269 |
+
rgb_sample = render_img(v)
|
270 |
+
video_list.append(rgb_sample)
|
271 |
+
big_video_list.append(video_list)
|
272 |
+
# if batch_size == 2:
|
273 |
+
# cat_video_list = [
|
274 |
+
# np.concatenate([big_video_list[j][i] for j in range(len(big_video_list))], 1) \
|
275 |
+
# for i in range(len(big_video_list[0]))
|
276 |
+
# ]
|
277 |
+
# elif batch_size > 2:
|
278 |
+
# if batch_size == 3:
|
279 |
+
# big_video_list.append(
|
280 |
+
# [np.zeros_like(f) for f in big_video_list[0]]
|
281 |
+
# )
|
282 |
+
# cat_video_list = [
|
283 |
+
# np.concatenate([
|
284 |
+
# np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
|
285 |
+
# np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
|
286 |
+
# ], 0) \
|
287 |
+
# for i in range(len(big_video_list[0]))
|
288 |
+
# ]
|
289 |
+
# else:
|
290 |
+
# cat_video_list = big_video_list[0]
|
291 |
+
|
292 |
+
for _ in range(4 - batch_size):
|
293 |
+
big_video_list.append(
|
294 |
+
[np.zeros_like(f) + 255 for f in big_video_list[0]]
|
295 |
+
)
|
296 |
+
cat_video_list = [
|
297 |
+
np.concatenate([
|
298 |
+
np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
|
299 |
+
np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
|
300 |
+
], 0) \
|
301 |
+
for i in range(len(big_video_list[0]))
|
302 |
+
]
|
303 |
+
|
304 |
+
path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
|
305 |
+
imageio.mimwrite(path, np.stack(cat_video_list, 0))
|
306 |
|
307 |
return global_info, path
|
308 |
|
309 |
def infer_stage2(prompt, selection, seed, global_info, iters):
|
310 |
prompt = prompt.replace('/', '')
|
311 |
+
with torch.cuda.device(1):
|
312 |
+
mesh_path = marching_cube(int(selection), prompt, global_info)
|
313 |
mesh_name = mesh_path.split('/')[-1][:-4]
|
314 |
# if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
|
315 |
# print(if2_cmd)
|