H-Liu1997 committed on
Commit
320051c
1 Parent(s): c3bd687

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -38
app.py CHANGED
@@ -467,13 +467,39 @@ def save_first_20_seconds(video_path, output_path="./save_video.mp4"):
467
 
468
 
469
  character_name_to_yaml = {
470
- "speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4": "./configs/gradio_speaker8.yaml",
471
- "speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4": "./configs/gradio_speaker7.yaml",
472
- "speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4": "./configs/gradio_speaker9.yaml",
473
- "1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4": "./configs/gradio_speaker1.yaml",
474
- "101099-00_18_09-00_18_19.mp4": "./configs/gradio.yaml",
475
  }
476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  @spaces.GPU(duration=1000)
478
  def tango(audio_path, character_name, create_graph=False, video_folder_path=None):
479
  saved_audio_path = "./saved_audio.wav"
@@ -488,10 +514,10 @@ def tango(audio_path, character_name, create_graph=False, video_folder_path=None
488
  sf.write(saved_audio_path, resampled_audio, 16000)
489
  audio_path = saved_audio_path
490
 
491
- yaml_name = character_name_to_yaml.get(character_name.split("/")[-1], "./configs/gradio.yaml")
 
492
  print(yaml_name, character_name.split("/")[-1])
493
- cfg = prepare_all(yaml_name)
494
-
495
  if character_name.split("/")[-1] not in character_name_to_yaml.keys():
496
  create_graph=True
497
  # load video, and save it to "./save_video.mp4 for the first 20s of the video."
@@ -510,34 +536,11 @@ def tango(audio_path, character_name, create_graph=False, video_folder_path=None
510
  local_rank = 0
511
  torch.cuda.set_device(local_rank)
512
  device = torch.device("cuda", local_rank)
513
- seed_everything(cfg.seed)
514
-
515
- experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name)
516
- smplx_model = smplx.create(
517
- "./emage/smplx_models/",
518
- model_type='smplx',
519
- gender='NEUTRAL_2020',
520
- use_face_contour=False,
521
- num_betas=300,
522
- num_expression_coeffs=100,
523
- ext='npz',
524
- use_pca=False,
525
- ).to(device).eval()
526
- model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg).to(device)
527
- for param in model.parameters():
528
- param.requires_grad = True
529
- # freeze wav2vec2
530
- for param in model.audio_encoder.parameters():
531
- param.requires_grad = False
532
- model.smplx_model = smplx_model
533
- model.get_motion_reps = get_motion_reps_tensor
534
-
535
- checkpoint_path = "./datasets/cached_ckpts/ckpt.pth"
536
- checkpoint = torch.load(checkpoint_path)
537
- state_dict = checkpoint['model_state_dict']
538
- new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
539
- model.load_state_dict(new_state_dict, strict=False)
540
-
541
  test_path = os.path.join(experiment_ckpt_dir, f"test_{0}")
542
  os.makedirs(test_path, exist_ok=True)
543
  result = test_fn(model, device, 0, cfg.data.test_meta_paths, test_path, cfg, audio_path)
@@ -632,7 +635,7 @@ def make_demo():
632
  inputs=[audio_input],
633
  outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
634
  label="Select existing Audio examples",
635
- cache_examples=False
636
  )
637
  with gr.Column(scale=1):
638
  video_input = gr.Video(label="Your Character", elem_classes="video")
@@ -642,7 +645,7 @@ def make_demo():
642
  inputs=[video_input], # Correctly refer to video input
643
  outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
644
  label="Character Examples",
645
- cache_examples=False
646
  )
647
 
648
  # Fourth row: Generate video button
 
467
 
468
 
469
  character_name_to_yaml = {
470
+ "speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4": "./datasets/data_json/youtube_test/speaker8.json",
471
+ "speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4": "./datasets/data_json/youtube_test/speaker7.json",
472
+ "speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4": "./datasets/data_json/youtube_test/speaker9.json",
473
+ "1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4": "./datasets/data_json/youtube_test/speaker1.json",
474
+ "101099-00_18_09-00_18_19.mp4": "./datasets/data_json/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json",
475
  }
476
 
477
+ cfg = prepare_all("./configs/gradio.yaml")
478
+ seed_everything(cfg.seed)
479
+ experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name)
480
+
481
+ smplx_model = smplx.create(
482
+ "./emage/smplx_models/",
483
+ model_type='smplx',
484
+ gender='NEUTRAL_2020',
485
+ use_face_contour=False,
486
+ num_betas=300,
487
+ num_expression_coeffs=100,
488
+ ext='npz',
489
+ use_pca=False,
490
+ )
491
+ model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg)
492
+ for param in model.parameters():
493
+ param.requires_grad = False
494
+ model.smplx_model = smplx_model
495
+ model.get_motion_reps = get_motion_reps_tensor
496
+
497
+ checkpoint_path = "./datasets/cached_ckpts/ckpt.pth"
498
+ checkpoint = torch.load(checkpoint_path)
499
+ state_dict = checkpoint['model_state_dict']
500
+ # new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
501
+ model.load_state_dict(state_dict, strict=False)
502
+
503
  @spaces.GPU(duration=1000)
504
  def tango(audio_path, character_name, create_graph=False, video_folder_path=None):
505
  saved_audio_path = "./saved_audio.wav"
 
514
  sf.write(saved_audio_path, resampled_audio, 16000)
515
  audio_path = saved_audio_path
516
 
517
+ yaml_name = character_name_to_yaml.get(character_name.split("/")[-1], "./datasets/data_json/youtube_test/speaker1.json")
518
+ cfg.data.test_meta_paths = yaml_name
519
  print(yaml_name, character_name.split("/")[-1])
520
+
 
521
  if character_name.split("/")[-1] not in character_name_to_yaml.keys():
522
  create_graph=True
523
  # load video, and save it to "./save_video.mp4 for the first 20s of the video."
 
536
  local_rank = 0
537
  torch.cuda.set_device(local_rank)
538
  device = torch.device("cuda", local_rank)
539
+
540
+ smplx_model = smplx_model.to(device).eval()
541
+ model = model.to(device)
542
+ model.smplx_model = model.smplx_model.to(device)
543
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  test_path = os.path.join(experiment_ckpt_dir, f"test_{0}")
545
  os.makedirs(test_path, exist_ok=True)
546
  result = test_fn(model, device, 0, cfg.data.test_meta_paths, test_path, cfg, audio_path)
 
635
  inputs=[audio_input],
636
  outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
637
  label="Select existing Audio examples",
638
+ cache_examples=True
639
  )
640
  with gr.Column(scale=1):
641
  video_input = gr.Video(label="Your Character", elem_classes="video")
 
645
  inputs=[video_input], # Correctly refer to video input
646
  outputs=[video_output_1, video_output_2, file_output_1, file_output_2],
647
  label="Character Examples",
648
+ cache_examples=True
649
  )
650
 
651
  # Fourth row: Generate video button