Ji4chenLi commited on
Commit
e13087b
1 Parent(s): a6dd76e

update app.py to link space and models

Browse files
app.py CHANGED
@@ -10,7 +10,9 @@ import torch
10
  import torchvision
11
  import gradio as gr
12
  import numpy as np
 
13
  from gradio.components import Textbox, Video
 
14
 
15
  from utils.common_utils import load_model_checkpoint
16
  from utils.utils import instantiate_from_config
@@ -144,7 +146,9 @@ if __name__ == "__main__":
144
  config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
145
  model_config = config.pop("model", OmegaConf.create())
146
  pretrained_t2v = instantiate_from_config(model_config)
147
- pretrained_t2v = load_model_checkpoint(pretrained_t2v, "checkpoints/VideoCrafter2_model.ckpt")
 
 
148
 
149
  unet_config = model_config["params"]["unet_config"]
150
  unet_config["params"]["use_checkpoint"] = False
@@ -153,7 +157,8 @@ if __name__ == "__main__":
153
 
154
  unet = instantiate_from_config(unet_config)
155
 
156
- unet.load_state_dict(torch.load("checkpoints/unet_mg.pt", map_location=device))
 
157
  unet.eval()
158
 
159
  pretrained_t2v.model.diffusion_model = unet
 
10
  import torchvision
11
  import gradio as gr
12
  import numpy as np
13
+
14
  from gradio.components import Textbox, Video
15
+ from huggingface_hub import hf_hub_download
16
 
17
  from utils.common_utils import load_model_checkpoint
18
  from utils.utils import instantiate_from_config
 
146
  config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
147
  model_config = config.pop("model", OmegaConf.create())
148
  pretrained_t2v = instantiate_from_config(model_config)
149
+
150
+ pretrained_path = hf_hub_download("VideoCrafter/VideoCrafter2", filename="model.ckpt")
151
+ pretrained_t2v = load_model_checkpoint(pretrained_t2v, pretrained_path)
152
 
153
  unet_config = model_config["params"]["unet_config"]
154
  unet_config["params"]["use_checkpoint"] = False
 
157
 
158
  unet = instantiate_from_config(unet_config)
159
 
160
+ unet_path = hf_hub_download(repo_id="jiachenli-ucsb/T2V-Turbo-v2", filename="unet_mg.pt")
161
+ unet.load_state_dict(torch.load(unet_path, map_location=device))
162
  unet.eval()
163
 
164
  pretrained_t2v.model.diffusion_model = unet
checkpoints/VideoCrafter2_model.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1edf769ece3308e977228943eeeed39286806aba9da17350449a3fbf4324ccfb
3
- size 7404653244
 
 
 
 
checkpoints/unet_mg.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:92c8767b40a5b2737dd3c69f5f13dae222ead5bd4befbbf894ca870231db13bc
3
- size 5655143958