Spaces:
Runtime error
Runtime error
File size: 6,122 Bytes
70b3e07 dc332e5 e00e6d6 1def0a4 1e87f84 9e3c23c 1e87f84 1def0a4 9e3c23c 254ea49 1e87f84 9e3c23c 1def0a4 9e3c23c ce95383 9e3c23c 254ea49 70b3e07 a3844a2 e5b9cea 9e3c23c 1def0a4 254ea49 9e3c23c 1def0a4 9e3c23c 70b3e07 e00e6d6 dc332e5 e00e6d6 dc332e5 e00e6d6 dc332e5 70b3e07 dd7849d 70b3e07 254ea49 1def0a4 1e87f84 1def0a4 9e3c23c 1def0a4 9e3c23c 1def0a4 9e3c23c 1def0a4 6dc33a4 1def0a4 e00e6d6 e5b9cea e00e6d6 9e3c23c 1def0a4 6dc33a4 1e87f84 6dc33a4 9e3c23c 6dc33a4 9e3c23c 1def0a4 9e3c23c 1def0a4 e5b9cea 9e3c23c 1def0a4 9e3c23c 1def0a4 d8653f1 1def0a4 d8653f1 1def0a4 9e3c23c 1def0a4 6b89aad 1def0a4 6b89aad 1def0a4 6b89aad 1def0a4 9e3c23c 1def0a4 d8653f1 1def0a4 1e87f84 1def0a4 9e3c23c 1def0a4 9e3c23c 1def0a4 6b89aad dc332e5 1def0a4 6b89aad 9e3c23c 1def0a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import json
from datetime import datetime
from time import time
from typing import List, Optional, Tuple
import cv2
import pandas as pd
import torch
from tap import Tap
from torch import Tensor
from transformers import (
AutoFeatureExtractor,
TimesformerForVideoClassification,
VideoMAEFeatureExtractor,
)
from utils.img_container import ImgContainer
class ArgParser(Tap):
    # Command-line options for the activity-tracking script, declared as typed
    # class attributes on a `Tap` subclass.
    # NOTE(review): Tap appears to use same-line trailing comments as help
    # text, so the trailing comments below are left exactly as they were.

    # Initial recording state; passed to ImgContainer in main().
    is_recording: Optional[bool] = False

    # Checkpoint names kept here for reference (candidate values for model_name):
    # "facebook/timesformer-base-finetuned-k400"
    # "facebook/timesformer-base-finetuned-k600",
    # "facebook/timesformer-base-finetuned-ssv2",
    # "facebook/timesformer-hr-finetuned-k600",
    # "facebook/timesformer-hr-finetuned-k400",
    # "facebook/timesformer-hr-finetuned-ssv2",
    # "fcakyon/timesformer-large-finetuned-k400",
    # "fcakyon/timesformer-large-finetuned-k600",
    # Checkpoint handed to TimesformerForVideoClassification.from_pretrained.
    model_name: Optional[str] = "facebook/timesformer-base-finetuned-k400"

    # main() feeds one captured frame out of every `num_skip_frames` to the model.
    num_skip_frames: Optional[int] = 2

    # Not read by the code in this file (the top-k reporting is commented out).
    top_k: Optional[int] = 5

    # Path to a JSON file mapping class index -> label; it replaces the model's
    # own id2label table when not None.
    id2label: Optional[str] = "labels/kinetics_400.json"

    # Minimum top-class logit for a prediction to be shown and logged;
    # None disables the filter.
    threshold: Optional[float] = 10.0 # 10.0

    # Not read by the code in this file.
    max_confidence: Optional[float] = 20.0 # Set None if not scale
class ActivityModel:
    """Video activity classifier that logs confident predictions to a diary.

    Loads a TimeSformer checkpoint together with a feature extractor,
    optionally replaces the model's ``id2label`` table from a JSON file, and
    appends every prediction that clears ``args.threshold`` to ``self.diary``.
    """

    def __init__(self, args: ArgParser):
        """Load the model named by ``args.model_name`` and start an empty diary.

        Args:
            args: Parsed options; ``model_name``, ``id2label`` and
                ``threshold`` are read by this class.
        """
        self.feature_extractor, self.model = self.load_model(args.model_name)
        self.args = args
        self.frames_per_video = self.get_frames_per_video(args.model_name)
        print(f"Frames per video: {self.frames_per_video}")
        self.load_json()
        # One row per accepted prediction: (time, timestamp, activity, confidence).
        self.diary: List[Tuple[str, int, str, float]] = []

    def save_diary(self):
        """Write the diary to ``diary.csv`` and ``diary.xlsx`` in the working directory."""
        df = pd.DataFrame(
            self.diary, columns=["time", "timestamp", "activity", "confidence"]
        )
        df.to_csv("diary.csv")
        df.to_excel("diary.xlsx")

    def load_json(self):
        """Replace ``model.config.id2label`` with the mapping in ``args.id2label``.

        The JSON file maps class indices (as strings) to labels; keys are
        converted to ``int`` so they match the indices produced by ``argmax``.
        Does nothing when ``args.id2label`` is ``None``.
        """
        # Use the options stored on the instance rather than a module-level
        # global, so the class also works when imported from another module.
        label_path = self.args.id2label
        if label_path is None:
            return
        with open(label_path, encoding="utf-8") as f:
            raw_labels = json.load(f)
        self.model.config.id2label = {
            int(key): label for key, label in raw_labels.items()
        }

    def load_model(
        self, model_name: str
    ) -> Tuple[VideoMAEFeatureExtractor, TimesformerForVideoClassification]:
        """Return ``(feature_extractor, model)`` for the checkpoint ``model_name``.

        The base k400/k600 checkpoints take their feature extractor from
        ``MCG-NJU/videomae-base-finetuned-kinetics``; every other checkpoint
        loads the extractor from its own repository.
        """
        if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
            extractor_source = "MCG-NJU/videomae-base-finetuned-kinetics"
        else:
            extractor_source = model_name
        feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_source)
        model = TimesformerForVideoClassification.from_pretrained(model_name)
        return feature_extractor, model

    def inference(self, img_container: ImgContainer):
        """Classify the buffered frames and record the top prediction.

        Returns immediately when ``img_container`` is not ready or when the
        predicted index has no entry in ``id2label``. Otherwise, if the top
        logit clears ``args.threshold`` (or the threshold is ``None``), the
        on-screen label is updated and a row is appended to ``self.diary``.

        Args:
            img_container: Frame buffer; its ``ready``, ``imgs`` and
                ``frame_rate.label`` members are used here.
        """
        if not img_container.ready:
            return

        inputs = self.feature_extractor(list(img_container.imgs), return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits: Tensor = outputs.logits

        max_index = logits.argmax(-1).item()
        if max_index not in self.model.config.id2label:
            return
        predicted_label = self.model.config.id2label[max_index]
        # NOTE: this is the raw logit of the winning class, not a probability.
        confidence = logits[0][max_index].item()

        threshold = self.args.threshold
        if threshold is None or confidence >= threshold:
            img_container.frame_rate.label = f"{predicted_label}_{confidence:.2f}%"
            self.diary.append(
                (str(datetime.now()), int(time()), predicted_label, confidence)
            )
        # NOTE: only the top-1 class is reported; ``args.top_k`` is not used here.

    def get_frames_per_video(self, model_name: str) -> int:
        """Return the clip length for ``model_name``.

        8 for ``base-finetuned`` checkpoints, 16 for ``hr-finetuned`` ones,
        and 96 for anything else.
        """
        if "base-finetuned" in model_name:
            return 8
        if "hr-finetuned" in model_name:
            return 16
        return 96
def main(args: ArgParser):
    """Run the webcam loop: capture, classify, display and optionally record.

    Frames are read from camera 0. One frame out of every
    ``args.num_skip_frames`` is added to the model's frame buffer, after which
    inference is attempted. Press ``q`` to quit and ``r`` to toggle recording
    to ``activities.mp4``. The diary is saved and all OpenCV resources are
    released when the loop ends, including when it ends with an exception.

    Args:
        args: Parsed command-line options.
    """
    activity_model = ActivityModel(args)
    img_container = ImgContainer(activity_model.frames_per_video, args.is_recording)

    # Guard against 0/None, which would make the modulo below raise.
    skip_every = max(1, args.num_skip_frames or 1)
    num_skips = 0

    camera = cv2.VideoCapture(0)
    frame_width = int(camera.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_output = cv2.VideoWriter(
        "activities.mp4",
        cv2.VideoWriter_fourcc(*"MP4V"),
        10,
        (frame_width, frame_height),
    )

    try:
        if not camera.isOpened():
            print("Error reading video file")

        while camera.isOpened():
            ret, frame = camera.read()
            if not ret:
                # No frame was delivered (device unplugged or stream ended);
                # stop instead of passing None into the pipeline.
                break

            num_skips = (num_skips + 1) % skip_every
            img_container.img = frame
            img_container.frame_rate.count()
            if num_skips == 0:
                img_container.add_frame(frame)
                activity_model.inference(img_container)

            rs = img_container.frame_rate.show_fps(frame, img_container.is_recording)
            cv2.imshow("ActivityTracking", rs)
            if img_container.is_recording:
                video_output.write(rs)

            # Mask to the low byte so key codes compare equal across platforms.
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            if key == ord("r"):
                img_container.toggle_recording()
    finally:
        try:
            activity_model.save_diary()
        finally:
            # Release devices and windows even if saving the diary fails.
            camera.release()
            video_output.release()
            cv2.destroyAllWindows()
# Script entry point: parse the command-line options into the module-level
# name `args`, then start the capture loop.
if __name__ == "__main__":
    args = ArgParser().parse_args()
    main(args)
|