CogVideoX / app.py
oahzxl's picture
update ui
09f1eaa
raw
history blame contribute delete
No virus
5.97 kB
import os
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import uuid
import gradio as gr
import spaces
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
def load_model(model_name, enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
config = CogVideoXConfig(model_name, enable_pab=enable_video_sys, pab_config=pab_config)
engine = VideoSysEngine(config)
return engine
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
unique_filename = f"{uuid.uuid4().hex}.mp4"
output_path = os.path.join("./.tmp_outputs", unique_filename)
engine.save_video(video, output_path)
return output_path
@spaces.GPU(duration=400)
def generate_vs(
model_name,
prompt,
num_inference_steps,
guidance_scale,
threshold_start,
threshold_end,
gap,
progress=gr.Progress(track_tqdm=True),
):
threshold = [int(threshold_end), int(threshold_start)]
gap = int(gap)
engine = load_model(model_name, enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
return video_path
css = """
body {
font-family: Arial, sans-serif;
line-height: 1.6;
color: #333;
margin: 0 auto;
padding: 20px;
}
.container {
display: flex;
flex-direction: column;
gap: 10px;
}
.row {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
.column {
flex: 1;
min-width: 0;
}
.video-output {
width: 100%;
max-width: 720px;
height: auto;
margin: 0 auto;
}
.server-status {
margin-top: 5px;
padding: 5px;
font-size: 0.8em;
}
.server-status h4 {
margin: 0 0 3px 0;
font-size: 0.9em;
}
.server-status .row {
margin-bottom: 2px;
}
.server-status .textbox {
min-height: unset !important;
}
.server-status .textbox input {
padding: 1px 5px !important;
height: 20px !important;
font-size: 0.9em !important;
}
.server-status .textbox label {
margin-bottom: 0 !important;
font-size: 0.9em !important;
line-height: 1.2 !important;
}
.server-status .textbox {
gap: 0 !important;
}
.server-status .textbox input {
margin-top: -2px !important;
}
@media (max-width: 768px) {
.row {
flex-direction: column;
}
.column {
width: 100%;
}
}
.video-output {
width: 100%;
height: auto;
}
}
"""
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
VideoSys for CogVideoX🤗
</div>
<div style="text-align: center; font-size: 15px;">
🌐 Github: <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">https://github.com/NUS-HPC-AI-Lab/VideoSys</a><br>
⚠️ This demo is for academic research and experiential use only.
Users should strictly adhere to local laws and ethics.<br>
💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.<br><br>
</div>
</div>
"""
)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=2)
with gr.Column():
gr.Markdown("**Generation Parameters**<br>")
with gr.Row():
model_name = gr.Radio(
["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], label="Model Type", value="THUDM/CogVideoX-2b"
)
with gr.Row():
num_inference_steps = gr.Slider(label="Inference Steps", maximum=50, value=50)
guidance_scale = gr.Slider(label="Guidance Scale", value=6.0, maximum=15.0)
gr.Markdown("**Pyramid Attention Broadcast Parameters**<br>")
with gr.Row():
pab_range = gr.Slider(
label="Broadcast Range",
value=2,
step=1,
minimum=1,
maximum=4,
info="Attention broadcast range.",
)
pab_threshold_start = gr.Slider(
label="Start Timestep",
minimum=500,
maximum=1000,
value=850,
step=1,
info="Broadcast start timestep (1000 is the fisrt).",
)
pab_threshold_end = gr.Slider(
label="End Timestep",
minimum=0,
maximum=500,
step=1,
value=100,
info="Broadcast end timestep (0 is the last).",
)
with gr.Row():
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys")
with gr.Column():
with gr.Row():
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
generate_button_vs.click(
generate_vs,
inputs=[
model_name,
prompt,
num_inference_steps,
guidance_scale,
pab_threshold_start,
pab_threshold_end,
pab_range,
],
outputs=[video_output_vs],
concurrency_id="gen",
concurrency_limit=1,
)
if __name__ == "__main__":
demo.queue(max_size=10, default_concurrency_limit=1)
demo.launch()