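# Streamlit app: live yoga pose detection.
# MediaPipe Pose draws a body-landmark skeleton on each webcam frame, and a
# Hugging Face image-classification model labels the yoga pose being performed.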
import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import streamlit as st
import numpy as np
import mediapipe as mp
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
# Initialize MediaPipe Drawing
mp_drawing = mp.solutions.drawing_utils
# Load the Hugging Face model and its preprocessor.
# The checkpoint classifies images, so it needs an image processor rather than a text tokenizer.
model_name = "dima806/yoga_pose_image_classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
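# Note: Streamlit reruns this whole script on every interaction, so the two
# from_pretrained calls above can be wrapped in a function decorated with
# @st.cache_resource to keep the weights in memory across reruns.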
# Yoga Pose Classification Function using the Hugging Face model.
# The checkpoint is an image classifier, so it takes the RGB frame directly;
# the MediaPipe landmarks are only used to draw the skeleton overlay.
def classify_pose(image_rgb):
    # Preprocess the frame (resize, normalize) into model inputs
    inputs = processor(images=image_rgb, return_tensors="pt")
    # Get model predictions without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_id = torch.argmax(outputs.logits, dim=1).item()
    # Map the class id to its label using the mapping shipped with the checkpoint
    return model.config.id2label[predicted_id]
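# Example usage (assuming a local test image at an illustrative path):
#   frame = cv2.cvtColor(cv2.imread("pose.jpg"), cv2.COLOR_BGR2RGB)
#   print(classify_pose(frame))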
def main():
    st.title("Live Yoga Pose Detection with Hugging Face")
    # Start video capture from the default webcam
    cap = cv2.VideoCapture(0)
    stframe = st.empty()
    # Streamlit has no OpenCV window, so cv2.waitKey cannot read key presses;
    # a checkbox lets the user stop the stream instead.
    run = st.checkbox("Run", value=True)
    while run and cap.isOpened():
        success, image = cap.read()
        if not success:
            st.error("Ignoring empty camera frame.")
            continue
        # Convert the BGR image to RGB for MediaPipe and the classifier
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Process the image and detect the pose
        results = pose.process(image_rgb)
        # Draw the pose annotation on the BGR frame
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
            # Classify the detected pose from the RGB frame
            pose_name = classify_pose(image_rgb)
            # Display the classification result on the frame
            cv2.putText(image, pose_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (255, 0, 0), 2, cv2.LINE_AA)
        # The frame is still BGR; tell Streamlit so it renders colors correctly
        stframe.image(image, channels='BGR')
    cap.release()
if __name__ == "__main__":
    main()