merve (HF staff) committed
Commit
7aef3af
1 Parent(s): 16a6079

Create app.py

Files changed (1)
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
+ from huggingface_hub import snapshot_download
+ import gradio as gr
+ import numpy as np
+ from tinysam import sam_model_registry, SamPredictor
+
+ # Fetch the TinySAM checkpoint from the Hub into ./tinysam
+ snapshot_download("merve/tinysam", local_dir="tinysam")
+
+ model_type = "vit_t"
+ sam = sam_model_registry[model_type](checkpoint="./tinysam/tinysam.pth")
+
+ predictor = SamPredictor(sam)
+
+ def infer(img):
+     # ImageEditor payload: "background" is the original image, "layers" holds
+     # the user's drawing (the point prompt), "composite" is the flattened view.
+     image = img["background"].convert("RGB")
+     point_prompt = img["layers"][0]
+     predictor.set_image(np.array(image))
+
+     # Use the centroid of the drawn pixels as a single point prompt
+     img_arr = np.array(point_prompt)
+     nonzero_indices = np.nonzero(img_arr)
+     center_x = int(np.mean(nonzero_indices[1]))
+     center_y = int(np.mean(nonzero_indices[0]))
+     input_point = np.array([[center_x, center_y]])
+
+     input_label = np.array([1])  # 1 marks a foreground (positive) point
+     masks, scores, logits = predictor.predict(
+         point_coords=input_point,
+         point_labels=input_label,
+     )
+
+     result_label = [(masks[0, :, :], "mask")]
+     return image, result_label
+
+ with gr.Blocks() as demo:
+     im = gr.ImageEditor(type="pil")
+     submit_btn = gr.Button()
+     submit_btn.click(infer, inputs=im, outputs=gr.AnnotatedImage())
+
+ demo.launch(debug=True)
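
For a quick local check, infer() can be exercised without the UI (run it in the same session before demo.launch, or with that line commented out). Below is a minimal smoke-test sketch, assuming Gradio's ImageEditor payload shape of "background"/"layers"/"composite" PIL images; the file path example.jpg and the click coordinates are hypothetical.

# Smoke-test sketch for infer(); not part of the Space itself.
from PIL import Image, ImageDraw

background = Image.open("example.jpg").convert("RGBA")  # hypothetical test image
layer = Image.new("RGBA", background.size, (0, 0, 0, 0))
# Simulate the user's click: a small opaque dot on an otherwise transparent layer
ImageDraw.Draw(layer).ellipse((100, 100, 110, 110), fill=(255, 0, 0, 255))

payload = {
    "background": background,
    "layers": [layer],
    "composite": Image.alpha_composite(background, layer),
}
image, result_label = infer(payload)
mask = result_label[0][0]
print(mask.shape)  # (H, W) boolean mask from TinySAM

Because the editor is created with type="pil", every key arrives as a PIL image; the prompt layer is transparent except where the user drew, which is why the centroid of its nonzero pixels works as a single positive SAM point.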