HongFangzhou commited on
Commit
8ee45cc
β€’
1 Parent(s): 6fcfbfd

3DTopia test

Browse files
Files changed (2) hide show
  1. app.py +145 -92
  2. requirements.txt +4 -1
app.py CHANGED
@@ -2,7 +2,10 @@ import os
2
  import sys
3
  import cv2
4
  import time
 
5
  import json
 
 
6
  import torch
7
  import mcubes
8
  import trimesh
@@ -11,7 +14,6 @@ import argparse
11
  import subprocess
12
  import numpy as np
13
  import gradio as gr
14
- from tqdm import tqdm
15
  import imageio.v2 as imageio
16
  import pytorch_lightning as pl
17
  from omegaconf import OmegaConf
@@ -28,10 +30,90 @@ from utility.initialize import instantiate_from_config, get_obj_from_str
28
  from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
29
  from utility.triplane_renderer.renderer import get_rays, to8b
30
 
 
 
 
31
  import warnings
32
  warnings.filterwarnings("ignore", category=UserWarning)
33
  warnings.filterwarnings("ignore", category=DeprecationWarning)
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def add_text(rgb, caption):
36
  font = cv2.FONT_HERSHEY_SIMPLEX
37
  # org
@@ -51,76 +133,6 @@ def add_text(rgb, caption):
51
  cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
52
  return rgb
53
 
54
- config = "3DTopia/configs/default.yaml"
55
- # local_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
56
- local_ckpt = "/data/3DTopia_all/3DTopia_code/checkpoints/model.safetensors"
57
- if os.path.exists(local_ckpt):
58
- ckpt = local_ckpt
59
- else:
60
- ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
61
- configs = OmegaConf.load(config)
62
- os.makedirs("tmp", exist_ok=True)
63
-
64
- import sys
65
- import traceback
66
-
67
- try:
68
- if ckpt.endswith(".ckpt"):
69
- model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
70
- elif ckpt.endswith(".safetensors"):
71
- model = get_obj_from_str(configs.model["target"])(**configs.model.params)
72
- print("download finish")
73
- model_ckpt = load_file(ckpt)
74
- print("download finish")
75
- model.load_state_dict(model_ckpt)
76
- print("download finish")
77
- else:
78
- raise NotImplementedError
79
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
80
- model = model.to(device)
81
- print("download finish")
82
- sampler = DDIMSampler(model)
83
-
84
- img_size = configs.model.params.unet_config.params.image_size
85
- channels = configs.model.params.unet_config.params.in_channels
86
- shape = [channels, img_size, img_size * 3]
87
-
88
- pose_folder = '3DTopia/assets/sample_data/pose'
89
- poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
90
- batch_rays_list = []
91
- H = 128
92
- ratio = 512 // H
93
- for p in poses_fname:
94
- c2w = np.loadtxt(p).reshape(4, 4)
95
- c2w[:3, 3] *= 2.2
96
- c2w = np.array([
97
- [1, 0, 0, 0],
98
- [0, 0, -1, 0],
99
- [0, 1, 0, 0],
100
- [0, 0, 0, 1]
101
- ]) @ c2w
102
-
103
- k = np.array([
104
- [560 / ratio, 0, H * 0.5],
105
- [0, 560 / ratio, H * 0.5],
106
- [0, 0, 1]
107
- ])
108
-
109
- rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
110
- coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
111
- coords = torch.reshape(coords, [-1,2]).long()
112
- rays_o = rays_o[coords[:, 0], coords[:, 1]]
113
- rays_d = rays_d[coords[:, 0], coords[:, 1]]
114
- batch_rays = torch.stack([rays_o, rays_d], 0)
115
- batch_rays_list.append(batch_rays)
116
- batch_rays_list = torch.stack(batch_rays_list, 0)
117
- except Exception as e:
118
- print(e)
119
- print(traceback.format_exc())
120
- print(sys.exc_info()[2])
121
-
122
-
123
- print("download finish")
124
  def marching_cube(b, text, global_info):
125
  # prepare volumn for marching cube
126
  res = 128
@@ -169,7 +181,7 @@ def marching_cube(b, text, global_info):
169
  ]
170
  rgb_final = None
171
  diff_final = None
172
- for rays_o in tqdm(rays_o_list):
173
  rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
174
  rays_d = pt_vertices.reshape(-1, 3) - rays_o
175
  rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
@@ -246,7 +258,7 @@ def infer(prompt, samples, steps, scale, seed, global_info):
246
 
247
  view_num = len(batch_rays_list)
248
  video_list = []
249
- for v in tqdm(range(view_num//8*3, view_num//8*5, 2)):
250
  rgb_sample = render_img(v)
251
  video_list.append(rgb_sample)
252
  big_video_list.append(video_list)
@@ -287,25 +299,62 @@ def infer(prompt, samples, steps, scale, seed, global_info):
287
 
288
  return global_info, path
289
 
290
- def infer_stage2(prompt, selection, seed, global_info):
291
  prompt = prompt.replace('/', '')
292
  mesh_path = marching_cube(int(selection), prompt, global_info)
293
  mesh_name = mesh_path.split('/')[-1][:-4]
294
-
295
- if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
296
- print(if2_cmd)
297
- # os.system(if2_cmd)
298
- subprocess.Popen(if2_cmd, shell=True).wait()
299
- torch.cuda.empty_cache()
300
-
301
  video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
302
- render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
303
- print(render_cmd)
304
- # os.system(render_cmd)
305
- subprocess.Popen(render_cmd, shell=True).wait()
 
 
306
  torch.cuda.empty_cache()
307
 
308
- return video_path, os.path.join('tmp', mesh_name + '_if2.glb')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  markdown=f'''
311
  # 3DTopia
@@ -315,7 +364,7 @@ markdown=f'''
315
  First enter prompt for a 3D object, hit "Generate 3D". Then choose one candidate from the dropdown options for the second stage refinement and hit "Start Refinement". The final mesh can be downloaded from the bottom right box.
316
 
317
  ### Runtime:
318
- The first stage takes 30s if generating 4 samples. The second stage takes roughly 3 min.
319
 
320
  ### Useful links:
321
  [Github Repo](https://github.com/3DTopia/3DTopia)
@@ -337,7 +386,7 @@ with block:
337
  )
338
  btn = gr.Button("Generate 3D")
339
  gallery = gr.Video(height=512)
340
- # advanced_button = gr.Button("Advanced Options", elem_id="advanced-btn")
341
  with gr.Row(elem_id="advanced-options"):
342
  with gr.Tab("Advanced options"):
343
  samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
@@ -361,11 +410,15 @@ with block:
361
  with gr.Column():
362
  with gr.Row():
363
  dropdown = gr.Dropdown(
364
- ['0', '1', '2', '3'], label="Choose a candidate for stage2", value='0'
365
  )
366
  btn_stage2 = gr.Button("Start Refinement")
367
  gallery = gr.Video(height=512)
368
- download = gr.File(label="Download mesh", file_count="single", height=100)
369
- gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info], outputs=[gallery, download])
 
 
 
 
370
 
371
- block.launch(share=True, debug=True)
 
2
  import sys
3
  import cv2
4
  import time
5
+ import tyro
6
  import json
7
+ import kiui
8
+ import tqdm
9
  import torch
10
  import mcubes
11
  import trimesh
 
14
  import subprocess
15
  import numpy as np
16
  import gradio as gr
 
17
  import imageio.v2 as imageio
18
  import pytorch_lightning as pl
19
  from omegaconf import OmegaConf
 
30
  from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
31
  from utility.triplane_renderer.renderer import get_rays, to8b
32
 
33
+ from threefiner.gui import GUI
34
+ from threefiner.opt import config_defaults, config_doc, check_options, Options
35
+
36
  import warnings
37
  warnings.filterwarnings("ignore", category=UserWarning)
38
  warnings.filterwarnings("ignore", category=DeprecationWarning)
39
 
40
+ ###################################### INIT STAGE 1 #########################################
41
+ config = "3DTopia/configs/default.yaml"
42
+ download_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
43
+ if not os.path.exists(download_ckpt):
44
+ ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
45
+ else:
46
+ ckpt = download_ckpt
47
+ configs = OmegaConf.load(config)
48
+ os.makedirs("tmp", exist_ok=True)
49
+
50
+ if ckpt.endswith(".ckpt"):
51
+ model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
52
+ elif ckpt.endswith(".safetensors"):
53
+ model = get_obj_from_str(configs.model["target"])(**configs.model.params)
54
+ model_ckpt = load_file(ckpt)
55
+ model.load_state_dict(model_ckpt)
56
+ else:
57
+ raise NotImplementedError
58
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
59
+ model = model.to(device)
60
+ sampler = DDIMSampler(model)
61
+
62
+ img_size = configs.model.params.unet_config.params.image_size
63
+ channels = configs.model.params.unet_config.params.in_channels
64
+ shape = [channels, img_size, img_size * 3]
65
+
66
+ pose_folder = '3DTopia/assets/sample_data/pose'
67
+ poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
68
+ batch_rays_list = []
69
+ H = 128
70
+ ratio = 512 // H
71
+ for p in poses_fname:
72
+ c2w = np.loadtxt(p).reshape(4, 4)
73
+ c2w[:3, 3] *= 2.2
74
+ c2w = np.array([
75
+ [1, 0, 0, 0],
76
+ [0, 0, -1, 0],
77
+ [0, 1, 0, 0],
78
+ [0, 0, 0, 1]
79
+ ]) @ c2w
80
+
81
+ k = np.array([
82
+ [560 / ratio, 0, H * 0.5],
83
+ [0, 560 / ratio, H * 0.5],
84
+ [0, 0, 1]
85
+ ])
86
+
87
+ rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
88
+ coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
89
+ coords = torch.reshape(coords, [-1,2]).long()
90
+ rays_o = rays_o[coords[:, 0], coords[:, 1]]
91
+ rays_d = rays_d[coords[:, 0], coords[:, 1]]
92
+ batch_rays = torch.stack([rays_o, rays_d], 0)
93
+ batch_rays_list.append(batch_rays)
94
+ batch_rays_list = torch.stack(batch_rays_list, 0)
95
+ ###################################### INIT STAGE 1 #########################################
96
+
97
+ ###################################### INIT STAGE 2 #########################################
98
+ GRADIO_SAVE_PATH_MESH = 'gradio_output.glb'
99
+ GRADIO_SAVE_PATH_VIDEO = 'gradio_output.mp4'
100
+
101
+ # opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))
102
+ opt = Options(
103
+ mode='IF2',
104
+ iters=400,
105
+ )
106
+
107
+ # hacks for not loading mesh at initialization
108
+ # opt.mesh = 'tmp/_2024-01-25_19:33:03.110191_if2.glb'
109
+ opt.save = GRADIO_SAVE_PATH_MESH
110
+ opt.prompt = ''
111
+ opt.text_dir = True
112
+ opt.front_dir = '+z'
113
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
114
+ gui = GUI(opt)
115
+ ###################################### INIT STAGE 2 #########################################
116
+
117
  def add_text(rgb, caption):
118
  font = cv2.FONT_HERSHEY_SIMPLEX
119
  # org
 
133
  cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
134
  return rgb
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def marching_cube(b, text, global_info):
137
  # prepare volumn for marching cube
138
  res = 128
 
181
  ]
182
  rgb_final = None
183
  diff_final = None
184
+ for rays_o in tqdm.tqdm(rays_o_list):
185
  rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
186
  rays_d = pt_vertices.reshape(-1, 3) - rays_o
187
  rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
 
258
 
259
  view_num = len(batch_rays_list)
260
  video_list = []
261
+ for v in tqdm.tqdm(range(view_num//8*3, view_num//8*5, 2)):
262
  rgb_sample = render_img(v)
263
  video_list.append(rgb_sample)
264
  big_video_list.append(video_list)
 
299
 
300
  return global_info, path
301
 
302
+ def infer_stage2(prompt, selection, seed, global_info, iters):
303
  prompt = prompt.replace('/', '')
304
  mesh_path = marching_cube(int(selection), prompt, global_info)
305
  mesh_name = mesh_path.split('/')[-1][:-4]
306
+ # if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
307
+ # print(if2_cmd)
308
+ # subprocess.Popen(if2_cmd, shell=True).wait()
309
+ # torch.cuda.empty_cache()
 
 
 
310
  video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
311
+ # render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
312
+ # print(render_cmd)
313
+ # subprocess.Popen(render_cmd, shell=True).wait()
314
+ # torch.cuda.empty_cache()
315
+
316
+ process_stage2(mesh_path, prompt, "down", iters, f'tmp/{mesh_name}_if2.glb', video_path)
317
  torch.cuda.empty_cache()
318
 
319
+ return video_path, f'tmp/{mesh_name}_if2.glb'
320
+
321
+ def process_stage2(input_model, input_text, input_dir, iters, output_model, output_video):
322
+ # set front facing direction (map from gradio model3D's mysterious coordinate system to OpenGL...)
323
+ opt.text_dir = True
324
+ if input_dir == 'front':
325
+ opt.front_dir = '-z'
326
+ elif input_dir == 'back':
327
+ opt.front_dir = '+z'
328
+ elif input_dir == 'left':
329
+ opt.front_dir = '+x'
330
+ elif input_dir == 'right':
331
+ opt.front_dir = '-x'
332
+ elif input_dir == 'up':
333
+ opt.front_dir = '+y'
334
+ elif input_dir == 'down':
335
+ opt.front_dir = '-y'
336
+ else:
337
+ # turn off text_dir
338
+ opt.text_dir = False
339
+ opt.front_dir = '+z'
340
+
341
+ # set mesh path
342
+ opt.mesh = input_model
343
+
344
+ # load mesh!
345
+ gui.renderer = gui.renderer_class(opt, device).to(device)
346
+
347
+ # set prompt
348
+ gui.prompt = opt.positive_prompt + ', ' + input_text
349
+
350
+ # train
351
+ gui.prepare_train() # update optimizer and prompt embeddings
352
+ for i in tqdm.trange(iters):
353
+ gui.train_step()
354
+
355
+ # save mesh & video
356
+ gui.save_model(output_model)
357
+ gui.save_model(output_video)
358
 
359
  markdown=f'''
360
  # 3DTopia
 
364
  First enter prompt for a 3D object, hit "Generate 3D". Then choose one candidate from the dropdown options for the second stage refinement and hit "Start Refinement". The final mesh can be downloaded from the bottom right box.
365
 
366
  ### Runtime:
367
+ The first stage takes 30s if generating 4 samples. The second stage takes roughly 1m30s.
368
 
369
  ### Useful links:
370
  [Github Repo](https://github.com/3DTopia/3DTopia)
 
386
  )
387
  btn = gr.Button("Generate 3D")
388
  gallery = gr.Video(height=512)
389
+ # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
390
  with gr.Row(elem_id="advanced-options"):
391
  with gr.Tab("Advanced options"):
392
  samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
 
410
  with gr.Column():
411
  with gr.Row():
412
  dropdown = gr.Dropdown(
413
+ ['0', '1', '2', '3'], label="Choose a Candidate For Stage2", value='0'
414
  )
415
  btn_stage2 = gr.Button("Start Refinement")
416
  gallery = gr.Video(height=512)
417
+ with gr.Row(elem_id="advanced-options"):
418
+ with gr.Tab("Advanced options"):
419
+ # input_dir = gr.Radio(['front', 'back', 'left', 'right', 'up', 'down'], value='down', label="front-facing direction")
420
+ iters = gr.Slider(minimum=100, maximum=1000, step=100, value=400, label="Refine iterations")
421
+ download = gr.File(label="Download Mesh", file_count="single", height=100)
422
+ gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info, iters], outputs=[gallery, download])
423
 
424
+ block.launch(share=True)
requirements.txt CHANGED
@@ -54,4 +54,7 @@ trimesh
54
  vit-pytorch
55
  wandb
56
  wcwidth
57
- zipp
 
 
 
 
54
  vit-pytorch
55
  wandb
56
  wcwidth
57
+ zipp
58
+ git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
59
+ git+https://github.com/NVlabs/nvdiffrast
60
+ git+https://github.com/3DTopia/threefiner