# Source: Hugging Face Space by user hvaldez — commit c18a21e ("first commit", verified), 1.74 kB.
# (The lines above were web-page residue from the hub UI, preserved here as a comment.)
import gradio as gr
from demo import VideoCLSModel
# Sample egocentric clips from the Charades-Ego dataset. predict() in main()
# maps a selected video path back to its position in this list, so order matters.
sample_videos = [
    f"data/charades_ego/video/{clip}EGO.mp4"
    for clip in (
        "P9SOA", "6D5DH", "15AKP", "X2JTK", "184EH",
        "S8YZI", "PRODQ", "QLXEX", "CC0LB", "FLY2F",
    )
]
def main():
    """Launch a Gradio demo that runs SViTT-Ego action recognition on sample clips.

    Builds a Blocks UI with a video picker, two read-only text boxes (ground
    truth vs. model prediction), and a predict button, then starts the server.
    """
    # Model is loaded once at startup and captured by the predict() closure.
    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")

    def predict(video_str):
        """Map the chosen video path to its sample index and run the model.

        Args:
            video_str: filesystem path of the selected video (from gr.Video).

        Returns:
            (ground_truth_text, prediction_text) for the two output boxes.

        Raises:
            gr.Error: if the video does not match any entry in sample_videos.
        """
        video_file = video_str.split('/')[-1]
        # BUG FIX: the original loop left `idx` unbound (NameError) when no
        # sample matched the chosen filename; fail with a clear UI error instead.
        idx = next(
            (i for i, item in enumerate(sample_videos) if video_file in item),
            None,
        )
        if idx is None:
            raise gr.Error(f"Unknown sample video: {video_file}")
        ft_action, gt_action = svitt.predict(idx)
        return gt_action, ft_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Action Recognition
            Choose a sample video and click predict to view the results.
            """
        )
        with gr.Row():
            # NOTE(review): hidden Number component is never wired to any event;
            # kept for backward compatibility with the original layout.
            idx = gr.Number(label="Idx", visible=False)
            video = gr.Video(label='video', format='mp4', autoplay=True, height=256, width=256)
        with gr.Row():
            label = gr.Text(label="Ground Truth")
            ours = gr.Text(label="SViTT-Ego prediction")
        with gr.Row():
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict, inputs=[video], outputs=[label, ours])
        with gr.Column():
            # Idiom fix: the index from enumerate() was unused.
            gr.Examples(examples=[[v] for v in sample_videos], inputs=[video])
    demo.launch()
if __name__ == "__main__":
main()