{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import PIL\n",
"from PIL import Image\n",
"from IPython.display import HTML\n",
"from pyramid_dit import PyramidDiTForVideoGeneration\n",
"from IPython.display import Image as ipython_image\n",
"from diffusers.utils import load_image, export_to_video, export_to_gif"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model variant to load; keep exactly one of the two lines below active.\n",
"variant = 'diffusion_transformer_768p'    # For high resolution\n",
"# variant = 'diffusion_transformer_384p'  # For low resolution\n",
"\n",
"# The downloaded checkpoint dir. Export PYRAMID_FLOW_MODEL_PATH to point at your own\n",
"# copy without editing the notebook; the original path is kept as the fallback.\n",
"model_path = os.environ.get(\"PYRAMID_FLOW_MODEL_PATH\", \"/home/jinyang06/models/pyramid-flow\")\n",
"model_dtype = 'bf16'  # 'bf16' or 'fp16'; any other value maps to torch.float32 below\n",
"\n",
"device_id = 0\n",
"torch.cuda.set_device(device_id)\n",
"\n",
"model = PyramidDiTForVideoGeneration(\n",
"    model_path,\n",
"    model_dtype,\n",
"    model_variant=variant,\n",
")\n",
"\n",
"# Move every sub-module used by the generation cells onto the current CUDA device.\n",
"model.vae.to(\"cuda\")\n",
"model.dit.to(\"cuda\")\n",
"model.text_encoder.to(\"cuda\")\n",
"\n",
"# torch dtype handed to autocast by the generation cells.\n",
"torch_dtype = {\"bf16\": torch.bfloat16, \"fp16\": torch.float16}.get(model_dtype, torch.float32)\n",
"\n",
"\n",
"def show_video(ori_path, rec_path, width=\"100%\"):\n",
"    \"\"\"Build an inline HTML5 player for one or two mp4 files.\n",
"\n",
"    Args:\n",
"        ori_path: Path or URL of an optional first video; skipped when None.\n",
"        rec_path: Path or URL of the video that is always shown.\n",
"        width: Value written into the width attribute of each <video> element.\n",
"\n",
"    Returns:\n",
"        An IPython.display.HTML object. Leave it as the last expression of a\n",
"        cell (or pass it to display()) to render the player(s).\n",
"    \"\"\"\n",
"    from html import escape  # local import keeps this helper self-contained\n",
"\n",
"    def video_tag(src):\n",
"        # Escape attribute values so quotes or '&' in a path cannot break the markup.\n",
"        return (\n",
"            f'<video controls width=\"{escape(str(width))}\">'\n",
"            f'<source src=\"{escape(str(src))}\" type=\"video/mp4\">'\n",
"            '</video>'\n",
"        )\n",
"\n",
"    sources = [rec_path] if ori_path is None else [ori_path, rec_path]\n",
"    return HTML(\"\".join(video_tag(src) for src in sources))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Text-to-Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
"\n",
"# Output resolution, per the model variant loaded above:\n",
"#   384p variant -> width = 640,  height = 384\n",
"#   768p variant -> width = 1280, height = 768\n",
"width, height = 1280, 768\n",
"\n",
"temp = 16 # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
"\n",
"# NOTE(review): presumably enables tiled VAE processing to cap peak GPU memory -- confirm in the VAE code.\n",
"model.vae.enable_tiling()\n",
"\n",
"# Autocast is switched off only when the model was configured as 'fp32'.\n",
"use_autocast = model_dtype != 'fp32'\n",
"\n",
"# torch.autocast(\"cuda\", ...) replaces the deprecated torch.cuda.amp.autocast(...) spelling.\n",
"with torch.no_grad(), torch.autocast(\"cuda\", enabled=use_autocast, dtype=torch_dtype):\n",
"    frames = model.generate(\n",
"        prompt=prompt,\n",
"        num_inference_steps=[20, 20, 20],\n",
"        video_num_inference_steps=[10, 10, 10],\n",
"        height=height,\n",
"        width=width,\n",
"        temp=temp,\n",
"        guidance_scale=9.0,          # The guidance for the first frame\n",
"        video_guidance_scale=5.0,    # The guidance for the other video latent\n",
"        output_type=\"pil\",\n",
"    )\n",
"\n",
"# Write the frames to disk, then end the cell with the player so it renders inline.\n",
"export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
"show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Image-to-Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_path = 'assets/the_great_wall.jpg'\n",
"\n",
"# Target resolution and temporal length for the conditioned generation.\n",
"width, height = 1280, 768\n",
"temp = 16\n",
"\n",
"# Load the conditioning image as RGB and resize it to the target resolution\n",
"# (plain resize: the aspect ratio of the source image is not preserved).\n",
"image = Image.open(image_path).convert(\"RGB\").resize((width, height))\n",
"\n",
"display(image)\n",
"\n",
"prompt = \"FPV flying over the Great Wall\"\n",
"\n",
"# Autocast is switched off only when the model was configured as 'fp32'.\n",
"use_autocast = model_dtype != 'fp32'\n",
"\n",
"# torch.autocast(\"cuda\", ...) replaces the deprecated torch.cuda.amp.autocast(...) spelling.\n",
"with torch.no_grad(), torch.autocast(\"cuda\", enabled=use_autocast, dtype=torch_dtype):\n",
"    frames = model.generate_i2v(\n",
"        prompt=prompt,\n",
"        input_image=image,\n",
"        num_inference_steps=[10, 10, 10],\n",
"        temp=temp,\n",
"        guidance_scale=7.0,\n",
"        video_guidance_scale=4.0,\n",
"        output_type=\"pil\",\n",
"    )\n",
"\n",
"# Write the frames to disk, then end the cell with the player so it renders inline.\n",
"export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
"show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}