Fix keypoint visualization code (#5)
Browse files- Fix keypoint visualization code (6ddc5897875a79401ea9fb590c59dec578398192)
Co-authored-by: Merve Noyan <[email protected]>
README.md
CHANGED
@@ -78,23 +78,47 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup
|
|
78 |
|
79 |
inputs = processor(images, return_tensors="pt")
|
80 |
outputs = model(**inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
for i in range(len(images)):
|
|
|
|
|
|
|
83 |
image_mask = outputs.mask[i]
|
84 |
image_indices = torch.nonzero(image_mask).squeeze()
|
85 |
-
image_keypoints = outputs.keypoints[i][image_indices]
|
86 |
-
image_scores = outputs.scores[i][image_indices]
|
87 |
-
image_descriptors = outputs.descriptors[i][image_indices]
|
88 |
-
```
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
```
|
99 |
|
100 |
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
|
|
78 |
|
79 |
inputs = processor(images, return_tensors="pt")
|
80 |
outputs = model(**inputs)
|
81 |
+
```
|
82 |
+
|
83 |
+
We can now visualize the keypoints.
|
84 |
+
|
85 |
+
```
|
86 |
+
import matplotlib.pyplot as plt
|
87 |
+
import torch
|
88 |
|
89 |
for i in range(len(images)):
|
90 |
+
image = images[i]
|
91 |
+
image_width, image_height = image.size
|
92 |
+
|
93 |
image_mask = outputs.mask[i]
|
94 |
image_indices = torch.nonzero(image_mask).squeeze()
|
|
|
|
|
|
|
|
|
95 |
|
96 |
+
image_scores = outputs.scores[i][image_indices]
|
97 |
+
image_keypoints = outputs.keypoints[i][image_indices]
|
98 |
+
|
99 |
+
keypoints = image_keypoints.detach().numpy()
|
100 |
+
scores = image_scores.detach().numpy()
|
101 |
+
|
102 |
+
valid_keypoints = [
|
103 |
+
(kp, score) for kp, score in zip(keypoints, scores)
|
104 |
+
if 0 <= kp[0] < image_width and 0 <= kp[1] < image_height
|
105 |
+
]
|
106 |
+
|
107 |
+
valid_keypoints, valid_scores = zip(*valid_keypoints)
|
108 |
+
valid_keypoints = torch.tensor(valid_keypoints)
|
109 |
+
valid_scores = torch.tensor(valid_scores)
|
110 |
+
|
111 |
+
print(valid_keypoints.shape)
|
112 |
+
|
113 |
+
plt.axis('off')
|
114 |
+
plt.imshow(image)
|
115 |
+
plt.scatter(
|
116 |
+
valid_keypoints[:, 0],
|
117 |
+
valid_keypoints[:, 1],
|
118 |
+
s=valid_scores * 100,
|
119 |
+
c='red'
|
120 |
+
)
|
121 |
+
plt.show()
|
122 |
```
|
123 |
|
124 |
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|