import cv2
import torch
import streamlit as st
import mediapipe as mp
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Initialize MediaPipe Drawing
mp_drawing = mp.solutions.drawing_utils
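# The 0.5 detection and tracking confidence values above are MediaPipe's
# defaults; raising them reduces false positives at the cost of more
# dropped detections.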
# Load the Hugging Face model and image processor
# (this checkpoint is an image classifier, so it takes pixel values,
# not tokenized text)
model_name = "dima806/yoga_pose_image_classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
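# Note: Streamlit reruns this script on every interaction; a cached loader
# (a minimal sketch, assuming Streamlit >= 1.18 for st.cache_resource)
# avoids reloading the weights each time:
# @st.cache_resource
# def load_model():
#     return AutoModelForImageClassification.from_pretrained(model_name)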
# Yoga pose classification using the Hugging Face image model
def classify_pose(rgb_frame):
    # Resize and normalize the RGB frame for the model
    inputs = processor(images=rgb_frame, return_tensors="pt")
    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    # Map the predicted class index to its label via the model config
    return model.config.id2label[prediction]
def main():
    st.title("Live Yoga Pose Detection with Hugging Face")
    # Start video capture from the default webcam
    cap = cv2.VideoCapture(0)
    stframe = st.empty()
    # Unchecking the box triggers a Streamlit rerun, which restarts the
    # script with run == False and so ends the capture loop
    run = st.checkbox("Run camera", value=True)
    while run and cap.isOpened():
        success, image = cap.read()
        if not success:
            st.warning("Ignoring empty camera frame.")
            continue
        # Convert the BGR frame to RGB for MediaPipe and the classifier
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Process the frame and detect pose landmarks
        results = pose.process(image_rgb)
        # Draw the pose annotation on the original BGR frame
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
            # Classify the pose from the RGB frame
            pose_name = classify_pose(image_rgb)
            # Overlay the predicted pose name (blue in BGR order)
            cv2.putText(image, pose_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (255, 0, 0), 2, cv2.LINE_AA)
        # The frame is already BGR, so no conversion is needed before display
        stframe.image(image, channels="BGR")
    cap.release()
    pose.close()
if __name__ == "__main__":
    main()
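# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py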