import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import streamlit as st
import numpy as np
import mediapipe as mp
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Initialize MediaPipe Drawing
mp_drawing = mp.solutions.drawing_utils

# Load the Hugging Face model. This checkpoint is an image classifier,
# so it is paired with an image processor rather than a text tokenizer.
model_name = "dima806/yoga_pose_image_classification"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
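# Aside: Streamlit reruns this script top-to-bottom on every interaction,
# so a cached loader would avoid reloading the model each time. A minimal
# sketch using Streamlit's st.cache_resource (the name load_model is
# illustrative, not part of the original script):
#
#   @st.cache_resource
#   def load_model(name):
#       return (AutoImageProcessor.from_pretrained(name),
#               AutoModelForImageClassification.from_pretrained(name))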
# Yoga pose classification using the Hugging Face model. The checkpoint
# classifies images, so it is fed the RGB frame directly; the MediaPipe
# landmarks are used only for the skeleton overlay.
def classify_pose(image_rgb):
    # Prepare pixel inputs for the model
    inputs = processor(images=image_rgb, return_tensors="pt")
    # Get model predictions without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    # Map the predicted class index to its pose name via the model config
    return model.config.id2label[prediction]
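# Note: classifying from the MediaPipe landmarks themselves (what this
# function's original signature suggested) would require a model trained on
# keypoint coordinates, not a text tokenizer. A minimal sketch, assuming a
# hypothetical small torch MLP over the 33 pose landmarks' (x, y, z) values:
#
#   import torch.nn as nn
#   num_poses = 4  # illustrative class count
#   landmark_mlp = nn.Sequential(
#       nn.Linear(33 * 3, 128), nn.ReLU(), nn.Linear(128, num_poses))
#   coords = torch.tensor([[v for lm in landmarks
#                           for v in (lm.x, lm.y, lm.z)]])
#   logits = landmark_mlp(coords)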
def main():
    st.title("Live Yoga Pose Detection with Hugging Face")
    # Start video capture
    cap = cv2.VideoCapture(0)
    stframe = st.empty()
    while cap.isOpened():
        success, image = cap.read()
        if not success:
            st.error("Ignoring empty camera frame.")
            continue
        # Convert the BGR frame to RGB for MediaPipe and the classifier
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Process the image and detect the pose
        results = pose.process(image_rgb)
        # Draw the pose annotation on the image
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
            # Classify the detected pose from the RGB frame
            pose_name = classify_pose(image_rgb)
            # Display the classification result on the image
            cv2.putText(image, pose_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (255, 0, 0), 2, cv2.LINE_AA)
        # Display the frame in Streamlit (still BGR; channels='BGR'
        # tells Streamlit to swap, so no extra cvtColor is needed)
        stframe.image(image, channels='BGR')
        # 'q'-to-quit only works with an OpenCV window; it is a no-op
        # under Streamlit but kept for running the loop standalone
        if cv2.waitKey(5) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()
if __name__ == "__main__":
    main()
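# To run this app locally (assuming the usual PyPI package names):
#   pip install streamlit opencv-python torch transformers mediapipe numpy
#   streamlit run app.py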