nielsr HF staff merve HF staff commited on
Commit
cf91a30
1 Parent(s): c73fb46

Fix keypoint visualization code (#5)

Browse files

- Fix keypoint visualization code (6ddc5897875a79401ea9fb590c59dec578398192)


Co-authored-by: Merve Noyan <[email protected]>

Files changed (1) hide show
  1. README.md +36 -12
README.md CHANGED
@@ -78,23 +78,47 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup
78
 
79
  inputs = processor(images, return_tensors="pt")
80
  outputs = model(**inputs)
 
 
 
 
 
 
 
81
 
82
  for i in range(len(images)):
 
 
 
83
  image_mask = outputs.mask[i]
84
  image_indices = torch.nonzero(image_mask).squeeze()
85
- image_keypoints = outputs.keypoints[i][image_indices]
86
- image_scores = outputs.scores[i][image_indices]
87
- image_descriptors = outputs.descriptors[i][image_indices]
88
- ```
89
 
90
- You can then print the keypoints on the image to visualize the result :
91
- ```python
92
- import cv2
93
- for keypoint, score in zip(image_keypoints, image_scores):
94
- keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item())
95
- color = tuple([score.item() * 255] * 3)
96
- image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color)
97
- cv2.imwrite("output_image.png", image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  ```
99
 
100
  This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
 
78
 
79
  inputs = processor(images, return_tensors="pt")
80
  outputs = model(**inputs)
81
+ ```
82
+
83
+ We can now visualize the keypoints.
84
+
85
+ ```
86
+ import matplotlib.pyplot as plt
87
+ import torch
88
 
89
  for i in range(len(images)):
90
+ image = images[i]
91
+ image_width, image_height = image.size
92
+
93
  image_mask = outputs.mask[i]
94
  image_indices = torch.nonzero(image_mask).squeeze()
 
 
 
 
95
 
96
+ image_scores = outputs.scores[i][image_indices]
97
+ image_keypoints = outputs.keypoints[i][image_indices]
98
+
99
+ keypoints = image_keypoints.detach().numpy()
100
+ scores = image_scores.detach().numpy()
101
+
102
+ valid_keypoints = [
103
+ (kp, score) for kp, score in zip(keypoints, scores)
104
+ if 0 <= kp[0] < image_width and 0 <= kp[1] < image_height
105
+ ]
106
+
107
+ valid_keypoints, valid_scores = zip(*valid_keypoints)
108
+ valid_keypoints = torch.tensor(valid_keypoints)
109
+ valid_scores = torch.tensor(valid_scores)
110
+
111
+ print(valid_keypoints.shape)
112
+
113
+ plt.axis('off')
114
+ plt.imshow(image)
115
+ plt.scatter(
116
+ valid_keypoints[:, 0],
117
+ valid_keypoints[:, 1],
118
+ s=valid_scores * 100,
119
+ c='red'
120
+ )
121
+ plt.show()
122
  ```
123
 
124
  This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).