import multiprocessing
import os
import time
from typing import List, Tuple

import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch import Tensor
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification

# Fixed seed so the random clip window chosen by sample_frame_indices is reproducible.
np.random.seed(0)

st.set_page_config(
    page_title="TimeSformer",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": "https://www.extremelycoolapp.com/help",
        "Report a bug": "https://www.extremelycoolapp.com/bug",
        "About": "# This is a header. This is an *extremely* cool app!",
    },
)


def sample_frame_indices(
    clip_len: int, frame_sample_rate: float, seg_len: int
) -> np.ndarray:
    """Pick `clip_len` frame indices spread evenly over a random window.

    The window spans `clip_len * frame_sample_rate` frames of a video that has
    `seg_len` frames in total.
    """
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len <= converted_len:
        # The video is shorter than the window (np.random.randint would raise
        # ValueError), so spread the samples over the whole video instead.
        start_idx, end_idx = 0, seg_len
    else:
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


# st.cache_resource supersedes the deprecated st.experimental_singleton and
# keeps one loaded model per model_name across reruns.
@st.cache_resource
def load_model(model_name: str):
    if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
        # The base Kinetics checkpoints are paired with the VideoMAE Kinetics
        # preprocessing config.
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            "MCG-NJU/videomae-base-finetuned-kinetics"
        )
    else:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model


def read_video(file_path: str, frames_per_video: int = 8) -> np.ndarray:
    """Decode `frames_per_video` RGB frames from `file_path` with OpenCV."""
    cap = cv2.VideoCapture(file_path)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print("Number of frames", length)
    indices = sample_frame_indices(
        clip_len=frames_per_video, frame_sample_rate=4, seg_len=length
    )
    frames: List[np.ndarray] = []
    for i in indices:
        # Seek to frame i before reading it.
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ret, frame = cap.read()
        if not ret:
            # Unreadable frame: skip it (the clip may end up with fewer
            # than frames_per_video frames).
            continue
        # OpenCV decodes frames as BGR; the feature extractor expects RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    return np.array(frames)


def read_video_decord(file_path: str, frames_per_video: int = 8) -> np.ndarray:
    """Alternative reader based on decord (optional dependency, not used by the app)."""
    from decord import VideoReader, cpu

    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
    videoreader.seek(0)
    indices = sample_frame_indices(
        clip_len=frames_per_video, frame_sample_rate=4, seg_len=len(videoreader)
    )
    video = videoreader.get_batch(indices).asnumpy()  # e.g. (8, 720, 1280, 3)
    return video


def inference(
    file_path: str,
    feature_extractor,
    model,
    frames_per_video: int = 8,
    top_k: int = 12,
) -> pd.DataFrame:
    """Classify the video at `file_path` and return the top-k labels."""
    video = read_video(file_path, frames_per_video)
    inputs = feature_extractor(list(video), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits: Tensor = outputs.logits

    # The label set depends on the checkpoint (Kinetics-400, Kinetics-600 or SSv2).
    predicted_label = logits.argmax(-1).item()
    print(model.config.id2label[predicted_label])

    # Convert logits to probabilities so the "Confidence" column lies in [0, 1].
    probs = torch.softmax(logits, dim=-1).squeeze(0).numpy()
    indices = np.argsort(probs)[::-1][:top_k]
    results: List[Tuple[str, float]] = [
        (model.config.id2label[int(index)], float(probs[index])) for index in indices
    ]
    return pd.DataFrame(results, columns=("Label", "Confidence"))


def get_frames_per_video(model_name: str) -> int:
    """Number of input frames each checkpoint family expects."""
    if "base-finetuned" in model_name:
        return 8
    elif "hr-finetuned" in model_name:
        return 16
    else:
        return 96


st.title("TimeSformer")

with st.expander("INTRODUCTION"):
    st.text(
        f"""Streamlit demo for TimeSformer.
Number of CPU(s): {multiprocessing.cpu_count()}
"""
    )

model_name = st.selectbox(
    "model_name",
    (
        "facebook/timesformer-base-finetuned-k400",
        "facebook/timesformer-base-finetuned-k600",
        "facebook/timesformer-base-finetuned-ssv2",
        "facebook/timesformer-hr-finetuned-k600",
        "facebook/timesformer-hr-finetuned-k400",
        "facebook/timesformer-hr-finetuned-ssv2",
        "fcakyon/timesformer-large-finetuned-k400",
        "fcakyon/timesformer-large-finetuned-k600",
    ),
)

feature_extractor, model = load_model(model_name)
frames_per_video = get_frames_per_video(model_name)
st.info(f"Frames per video: {frames_per_video}")

VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")
# Create the upload directory if it does not exist yet.
os.makedirs(os.path.dirname(VIDEO_TMP_PATH), exist_ok=True)

uploadedfile = st.file_uploader("Upload file", type=["mp4"])

if uploadedfile is not None:
    with st.spinner("Saving upload..."):
        with open(VIDEO_TMP_PATH, "wb") as f:
            f.write(uploadedfile.getbuffer())

    start_time = time.time()
    with st.spinner("Processing..."):
        df = inference(VIDEO_TMP_PATH, feature_extractor, model, frames_per_video)
    end_time = time.time()
    st.info(f"{end_time - start_time:.2f} seconds")

    st.dataframe(df)
    st.video(VIDEO_TMP_PATH)
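

# ---------------------------------------------------------------------------
# Hypothetical helper: not part of the original app and not called above.
# This is a minimal sketch of deterministic ("uniform") frame sampling. It is
# offered as an alternative to sample_frame_indices.
#
# sample_frame_indices picks a random window of
# clip_len * frame_sample_rate frames. This helper instead splits the whole
# video into clip_len equal segments and takes the centre frame of each.
#
# Properties of this approach:
#   - Repeated runs on the same file always see the same frames.
#   - Videos shorter than clip_len repeat frames, so the result always has
#     exactly clip_len entries.
#
# The name uniform_frame_indices is an illustrative assumption, not an
# existing API.
#
# Possible use inside read_video, in place of the sample_frame_indices call:
#     indices = uniform_frame_indices(frames_per_video, length)
# ---------------------------------------------------------------------------
def uniform_frame_indices(clip_len: int, seg_len: int) -> list[int]:
    """Return clip_len evenly spaced frame indices in [0, seg_len - 1]."""
    if clip_len <= 0 or seg_len <= 0:
        raise ValueError("clip_len and seg_len must be positive")
    step = seg_len / clip_len  # width of each segment, in frames
    # The centre of segment i is i * step + step / 2. min() guards against
    # floating-point rounding pushing the last index past the final frame.
    return [min(int(i * step + step / 2), seg_len - 1) for i in range(clip_len)]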