raj999 commited on
Commit
aefe881
1 Parent(s): 51999bc

Create predict_vit

Browse files
Files changed (1) hide show
  1. predict_vit +66 -0
predict_vit ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ from PIL import Image
4
+ import numpy as np
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+
7
+ # Load the CLIP model
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model, preprocess = clip.load("ViT-B/32", device=device)
10
+
11
+ def extract_features_cp(pil_img: Image.Image) -> np.ndarray:
12
+ # Preprocess the PIL image using CLIP's preprocess function
13
+ img = preprocess(pil_img).unsqueeze(0).to(device)
14
+
15
+ # Extract features using CLIP
16
+ with torch.no_grad():
17
+ features = model.encode_image(img)
18
+
19
+ # Normalize the features
20
+ features = features / features.norm(dim=-1, keepdim=True)
21
+
22
+ # Convert to numpy array and return as a flattened array
23
+ return features.cpu().numpy().flatten()
24
+
25
+ def extract_features(img_path):
26
+ # Load and preprocess the image
27
+ img = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
28
+
29
+ # Extract features using CLIP
30
+ with torch.no_grad():
31
+ features = model.encode_image(img)
32
+
33
+ # Normalize the features
34
+ features = features / features.norm(dim=-1, keepdim=True)
35
+
36
+ # Convert to numpy array
37
+ return features.cpu().numpy().flatten()
38
+
39
+ def compare_features(features1, features2):
40
+ # Cosine similarity
41
+ cos_sim = cosine_similarity([features1], [features2])[0][0]
42
+
43
+ return cos_sim
44
+
45
+ def predict_similarity(features1, features2, threshold=0.5):
46
+ cos_sim = compare_features(features1, features2)
47
+ similarity_score = cos_sim
48
+
49
+ return similarity_score > threshold
50
+
51
+ if __name__ == '__main__':
52
+ # Example usage
53
+ img_path1 = 'result.jpg'
54
+ img_path2 = 'Vochysia.jpg'
55
+
56
+ # Extract features
57
+ features1 = extract_features(img_path1)
58
+ features2 = extract_features(img_path2)
59
+
60
+ # Compare features
61
+ cos_sim = compare_features(features1, features2)
62
+ print(f'Cosine Similarity: {cos_sim}')
63
+
64
+ # Predict similarity
65
+ is_similar = predict_similarity(features1, features2, threshold=0.8)
66
+ print(f'Are the images similar? {"Yes" if is_similar else "No"}')