zR committed
Commit ed512fd
1 Parent(s): 0922892
Files changed (5):
  1. .gitattributes +1 -0
  2. README.md +5 -2
  3. app.py +190 -32
  4. requirements.txt +11 -13
  5. rife_model.py +1 -5
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 models/RealESRGAN_x4.pth filter=lfs diff=lfs merge=lfs -text
 models/flownet.pkl filter=lfs diff=lfs merge=lfs -text
+horse.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,7 +5,8 @@ colorFrom: yellow
 colorTo: blue
 sdk: gradio
 sdk_version: 4.42.0
-suggested_hardware: l4x1
+suggested_hardware: a10g-large
+suggested_storage: large
 app_port: 7860
 app_file: app.py
 models:
@@ -41,4 +42,6 @@ pip install -r requirements.txt

 ```bash
 python gradio_web_demo.py
-```
+```
+
+
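The hardware hint moves from an L4 to an A10G-large with large storage, presumably because the updated app.py (below) keeps three bf16 CogVideoX pipelines resident on a single GPU and downloads additional model weights. A small, hypothetical pre-flight check (not part of the commit) for anyone reproducing the Space locally:

```python
# Hypothetical helper, not part of this commit: verify the local GPU is roughly
# in the class the Space now requests (an A10G has ~24 GiB and supports bf16).
import torch

if not torch.cuda.is_available():
    raise SystemExit("A CUDA GPU is required to run this demo.")

props = torch.cuda.get_device_properties(0)
total_gib = props.total_memory / 1024**3
print(f"{props.name}: {total_gib:.1f} GiB, bf16 supported: {torch.cuda.is_bf16_supported()}")
if total_gib < 24:
    print("Warning: less memory than an A10G; expect OOM without CPU offloading.")
```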
app.py CHANGED
@@ -1,12 +1,31 @@
+"""
+This is the main file for the Gradio web demo. It uses the CogVideoX-5B model to generate videos.
+Set the environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
+
+Usage:
+    OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
+"""
+
 import math
 import os
 import random
 import threading
 import time

+import cv2
+import tempfile
+import imageio_ffmpeg
 import gradio as gr
 import torch
-from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler,CogVideoXDPMScheduler
+from PIL import Image
+from diffusers import (
+    CogVideoXPipeline,
+    CogVideoXDPMScheduler,
+    CogVideoXVideoToVideoPipeline,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta

 from diffusers.image_processor import VaeImageProcessor
@@ -23,9 +42,33 @@ snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=pipe.transformer,
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)
+
+pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=CogVideoXTransformer3DModel.from_pretrained(
+        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
+    ),
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)
+

 pipe.transformer.to(memory_format=torch.channels_last)
 pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe_image.transformer.to(memory_format=torch.channels_last)
+pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)

 os.makedirs("./output", exist_ok=True)
 os.makedirs("./gradio_tmp", exist_ok=True)
@@ -47,6 +90,80 @@ Video descriptions must have the same num of words as examples below. Extra word
 """


+def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
+    width, height = get_video_dimensions(input_video)
+
+    if width == 720 and height == 480:
+        processed_video = input_video
+    else:
+        processed_video = center_crop_resize(input_video)
+    return processed_video
+
+
+def get_video_dimensions(input_video_path):
+    reader = imageio_ffmpeg.read_frames(input_video_path)
+    metadata = next(reader)
+    return metadata["size"]
+
+
+def center_crop_resize(input_video_path, target_width=720, target_height=480):
+    cap = cv2.VideoCapture(input_video_path)
+
+    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    orig_fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    width_factor = target_width / orig_width
+    height_factor = target_height / orig_height
+    resize_factor = max(width_factor, height_factor)
+
+    inter_width = int(orig_width * resize_factor)
+    inter_height = int(orig_height * resize_factor)
+
+    target_fps = 8
+    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
+    skip = min(5, ideal_skip)  # Cap at 5
+
+    while (total_frames / (skip + 1)) < 49 and skip > 0:
+        skip -= 1
+
+    processed_frames = []
+    frame_count = 0
+    total_read = 0
+
+    while frame_count < 49 and total_read < total_frames:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        if total_read % (skip + 1) == 0:
+            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
+
+            start_x = (inter_width - target_width) // 2
+            start_y = (inter_height - target_height) // 2
+            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
+
+            processed_frames.append(cropped)
+            frame_count += 1
+
+        total_read += 1
+
+    cap.release()
+
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+        temp_video_path = temp_file.name
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
+
+        for frame in processed_frames:
+            out.write(frame)
+
+        out.release()
+
+    return temp_video_path
+
+
 def convert_prompt(prompt: str, retry_times: int = 3) -> str:
     if not os.environ.get("OPENAI_API_KEY"):
         return prompt
@@ -86,7 +203,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
                 },
             ],
-            model="glm-4-0520",
+            model="glm-4-plus",
             temperature=0.01,
             top_p=0.7,
             stream=False,
@@ -98,24 +215,55 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:


 def infer(
-    prompt: str,
-    num_inference_steps: int,
-    guidance_scale: float,
-    seed: int = -1,
-    #progress=gr.Progress(track_tqdm=True),
+    prompt: str,
+    image_input: str,
+    video_input: str,
+    video_strength: float,
+    num_inference_steps: int,
+    guidance_scale: float,
+    seed: int = -1,
+    progress=gr.Progress(track_tqdm=True),
 ):
     if seed == -1:
-        seed = random.randint(0, 2 ** 8 - 1)
-    video_pt = pipe(
-        prompt=prompt,
-        num_videos_per_prompt=1,
-        num_inference_steps=num_inference_steps,
-        num_frames=49,
-        use_dynamic_cfg=True,
-        output_type="pt",
-        guidance_scale=guidance_scale,
-        generator=torch.Generator(device="cpu").manual_seed(seed),
-    ).frames
+        seed = random.randint(0, 2**8 - 1)
+
+    if video_input is not None:
+        video = load_video(video_input)[:49]  # Limit to 49 frames
+        video_pt = pipe_video(
+            video=video,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            strength=video_strength,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    elif image_input is not None:
+        image_input = Image.fromarray(image_input).resize(size=(720, 480))  # Convert to PIL
+        image = load_image(image_input)
+        video_pt = pipe_image(
+            image=image,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    else:
+        video_pt = pipe(
+            prompt=prompt,
+            num_videos_per_prompt=1,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames

     return (video_pt, seed)

@@ -146,6 +294,7 @@ def delete_old_files():


 threading.Thread(target=delete_old_files, daemon=True).start()
+examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]

 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -166,10 +315,15 @@ with gr.Blocks() as demo:
         <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
             ⚠️ This demo is for academic research and experiential use only.
         </div>
-
     """)
     with gr.Row():
         with gr.Column():
+            with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
+                image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
+            with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
+                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
+                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
+                examples_component = gr.Examples(examples, inputs=[video_input], cache_examples=False)
             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)

             with gr.Row():
@@ -184,7 +338,7 @@ with gr.Blocks() as demo:
                     label="Inference Seed (Enter a positive number, -1 for random)", value=-1
                 )
             with gr.Row():
-                enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 1440 × 960)", value=False)
+                enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
                 enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
             gr.Markdown(
                 "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br>&nbsp;&nbsp;&nbsp;&nbsp;The entire process is based on open-source solutions."
@@ -263,20 +417,25 @@ with gr.Blocks() as demo:
         </table>
         """)

-
-    def generate(prompt,
-                 seed_value,
-                 scale_status,
-                 rife_status,
-                 progress=gr.Progress(track_tqdm=True)
-                 ):
-
+    def generate(
+        prompt,
+        image_input,
+        video_input,
+        video_strength,
+        seed_value,
+        scale_status,
+        rife_status,
+        progress=gr.Progress(track_tqdm=True)
+    ):
         latents, seed = infer(
             prompt,
+            image_input,
+            video_input,
+            video_strength,
             num_inference_steps=50,  # NOT Changed
             guidance_scale=7.0,  # NOT Changed
             seed=seed_value,
-            #progress=progress,
+            progress=progress,
         )
         if scale_status:
             latents = utils.upscale_batch_and_concatenate(upscale_model, latents, device)
@@ -301,18 +460,17 @@ with gr.Blocks() as demo:

         return video_path, video_update, gif_update, seed_update

-
     def enhance_prompt_func(prompt):
         return convert_prompt(prompt, retry_times=1)

-
     generate_button.click(
         generate,
-        inputs=[prompt, seed_param, enable_scale, enable_rife],
+        inputs=[prompt, image_input, video_input, strength, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )

     enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+    video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])

 if __name__ == "__main__":
     demo.queue(max_size=15)
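Within the app.py changes above, the least obvious piece is how the new `center_crop_resize` helper picks a frame-sampling stride so that an arbitrary upload still yields at least 49 frames at roughly 8 fps. A minimal, dependency-free sketch of just that arithmetic (the `choose_frame_skip` name is hypothetical, not part of the commit):

```python
import math


def choose_frame_skip(orig_fps: float, total_frames: int, target_fps: int = 8, needed_frames: int = 49) -> int:
    """Mirrors the stride logic in center_crop_resize: keep every (skip + 1)-th frame."""
    # Stride that downsamples orig_fps to roughly target_fps, capped at 5 as in the commit.
    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
    skip = min(5, ideal_skip)

    # Relax the stride until at least `needed_frames` frames would survive (or skip reaches 0).
    while (total_frames / (skip + 1)) < needed_frames and skip > 0:
        skip -= 1
    return skip


if __name__ == "__main__":
    # 30 fps, 4 s clip (120 frames): skip=1, i.e. keep every 2nd frame -> 60 candidates >= 49.
    print(choose_frame_skip(orig_fps=30, total_frames=120))
    # 24 fps, 10 s clip (240 frames): skip=2, i.e. keep every 3rd frame -> 80 candidates.
    print(choose_frame_skip(orig_fps=24, total_frames=240))
```

The helper in the commit then reads frames sequentially, keeps only every (skip + 1)-th one until 49 are collected, and center-crops each to 720 × 480 before re-encoding at 8 fps.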
requirements.txt CHANGED
@@ -1,21 +1,19 @@
-spaces==0.29.3
-safetensors>=0.4.4
-spandrel>=0.3.4
+spaces>=0.29.3
+safetensors>=0.4.5
+spandrel>=0.4.0
 tqdm>=4.66.5
-opencv-python>=4.10.0.84
 scikit-video>=1.1.11
-diffusers>=0.30.1
+git+https://github.com/huggingface/diffusers.git@main
 transformers>=4.44.0
-accelerate>=0.33.0
+accelerate>=0.34.2
+opencv-python>=4.10.0.84
 sentencepiece>=0.2.0
-SwissArmyTransformer>=0.4.12
 numpy==1.26.0
 torch>=2.4.0
 torchvision>=0.19.0
-gradio>=4.42.0
-streamlit>=1.37.1
-imageio==2.34.2
-imageio-ffmpeg==0.5.1
-openai>=1.42.0
-moviepy==1.0.3
+gradio>=4.44.0
+imageio>=2.34.2
+imageio-ffmpeg>=0.5.1
+openai>=1.45.0
+moviepy>=1.0.3
 pillow==9.5.0
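requirements.txt now pulls diffusers from the main branch rather than a released version, presumably because the `CogVideoXVideoToVideoPipeline` and `CogVideoXImageToVideoPipeline` classes imported by the updated app.py had not yet shipped in a stable release. A small, hypothetical sanity check (not part of the commit) that the installed build exposes everything app.py imports:

```python
# Hypothetical check, not part of this commit: confirm the installed diffusers
# build exposes the classes the updated app.py imports.
import diffusers

required = (
    "CogVideoXPipeline",
    "CogVideoXDPMScheduler",
    "CogVideoXVideoToVideoPipeline",
    "CogVideoXImageToVideoPipeline",
    "CogVideoXTransformer3DModel",
)
missing = [name for name in required if not hasattr(diffusers, name)]
if missing:
    raise ImportError(
        f"diffusers {diffusers.__version__} is missing {missing}; "
        "install it from git main as pinned in requirements.txt"
    )
print(f"diffusers {diffusers.__version__} exposes all required CogVideoX classes")
```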
rife_model.py CHANGED
@@ -10,7 +10,6 @@ import skvideo.io
 from rife.RIFE_HDv3 import Model

 logger = logging.getLogger(__name__)
-
 device = "cuda" if torch.cuda.is_available() else "cpu"


@@ -37,8 +36,7 @@ def make_inference(model, I0, I1, upscale_amount, n):

 @torch.inference_mode()
 def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
-    print(f"samples dtype:{samples.dtype}")
-    print(f"samples shape:{samples.shape}")
+
     output = []
     # [f, c, h, w]
     for b in range(samples.shape[0]):
@@ -119,13 +117,11 @@ def rife_inference_with_path(model, video_path):


 def rife_inference_with_latents(model, latents):
-    pbar = utils.ProgressBar(latents.shape[1], desc="RIFE inference")
     rife_results = []
     latents = latents.to(device)
     for i in range(latents.size(0)):
         # [f, c, w, h]
         latent = latents[i]
-
         frames = ssim_interpolation_rife(model, latent)
         pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
         rife_results.append(pt_image)