Commit 2ca44ee
Yan committed
Parent(s): 0227876

added test script and data for local handler testing, fixed syntax error in handler script

Files changed:
- .gitattributes +1 -0
- handler.py +54 -28
- test.png +3 -0
- test.py +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
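The added rule routes *.png files through Git LFS; it is the line that `git lfs track "*.png"` would append, and it is needed because this commit checks in test.png as binary test data.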
handler.py CHANGED
@@ -4,12 +4,10 @@ from PIL import Image
 from io import BytesIO
 import numpy as np
 import os
-import requests
 import torch
 import torchvision.transforms as T
 from transformers import AutoProcessor, AutoModelForVision2Seq
 import cv2
-import ast
 
 # set device
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -18,15 +16,43 @@ if device.type != 'cuda':
 # set mixed precision dtype
 dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
 
+colors = [
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (255, 0, 255),
+    (0, 255, 255),
+    (114, 128, 250),
+    (0, 165, 255),
+    (0, 128, 0),
+    (144, 238, 144),
+    (238, 238, 175),
+    (255, 191, 0),
+    (0, 128, 0),
+    (226, 43, 138),
+    (255, 0, 255),
+    (0, 215, 255),
+    (255, 0, 0),
+]
+
+color_map = {
+    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for color_id, color in enumerate(colors)
+}
+
 
 class EndpointHandler():
     def __init__(self, path=""):
         self.ckpt_id = "ydshieh/kosmos-2-patch14-224"
 
-        self.model = AutoModelForVision2Seq.from_pretrained(ckpt_id, trust_remote_code=True).to("cuda")
-        self.processor = AutoProcessor.from_pretrained(
+        self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt_id, trust_remote_code=True).to("cuda")
+        self.processor = AutoProcessor.from_pretrained(self.ckpt_id, trust_remote_code=True)
+
+    def is_overlapping(self, rect1, rect2):
+        x1, y1, x2, y2 = rect1
+        x3, y3, x4, y4 = rect2
+        return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
 
-    def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, entity_index=-1):
+    def draw_entity_boxes_on_image(self, image, entities, show=False, save_path=None, entity_index=-1):
         """_summary_
         Args:
             image (_type_): image or image path
@@ -56,17 +82,17 @@ class EndpointHandler():
             image = np.array(pil_img)[:, :, [2, 1, 0]]
         else:
             raise ValueError(f"invaild image format, {type(image)} for {image}")
-
+
         if len(entities) == 0:
             return image
-
+
         indices = list(range(len(entities)))
         if entity_index >= 0:
            indices = [entity_index]
-
+
         # Not to show too many bboxes
         entities = entities[:len(color_map)]
-
+
         new_image = image.copy()
         previous_bboxes = []
         # size of text
@@ -78,10 +104,10 @@ class EndpointHandler():
         base_height = int(text_height * 0.675)
         text_offset_original = text_height - base_height
         text_spaces = 3
-
+
         # num_bboxes = sum(len(x[-1]) for x in entities)
         used_colors = colors # random.sample(colors, k=num_bboxes)
-
+
         color_id = -1
         for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
             color_id += 1
@@ -91,37 +117,37 @@ class EndpointHandler():
                 # if start is None and bbox_id > 0:
                 #     color_id += 1
                 orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
-
+
                 # draw bbox
                 # random color
                 color = used_colors[color_id] # tuple(np.random.randint(0, 255, size=3).tolist())
                 new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
-
+
                 l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
-
+
                 x1 = orig_x1 - l_o
                 y1 = orig_y1 - l_o
-
+
                 if y1 < text_height + text_offset_original + 2 * text_spaces:
                     y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                     x1 = orig_x1 + r_o
-
+
                 # add text background
                 (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
                 text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
-
+
                 for prev_bbox in previous_bboxes:
-                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
+                    while self.is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
                         text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                         text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                         y1 += (text_height + text_offset_original + 2 * text_spaces)
-
+
                         if text_bg_y2 >= image_h:
                             text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                             text_bg_y2 = image_h
                             y1 = image_h
                             break
-
+
                 alpha = 0.5
                 for i in range(text_bg_y1, text_bg_y2):
                     for j in range(text_bg_x1, text_bg_x2):
@@ -133,19 +159,19 @@ class EndpointHandler():
                                 # white
                                 bg_color = [255, 255, 255]
                             new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
-
+
                 cv2.putText(
                     new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                 )
                 # previous_locations.append((x1, y1))
                 previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
-
+
         pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
         if save_path:
             pil_image.save(save_path)
         if show:
             pil_image.show()
-
+
         return pil_image
 
 
@@ -161,13 +187,13 @@ class EndpointHandler():
         # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
         user_image_path = "/tmp/user_input_test_image.jpg"
         image_input.save(user_image_path)
-
+
         # This might give different results from the original argument `image_input`
         image_input = Image.open(user_image_path)
         text_input = "<grounding>Describe this image in detail:"
         #text_input = f"<grounding>{text_input}"
 
-        inputs = processor(text=text_input, images=image_input, return_tensors="pt")
+        inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
 
         generated_ids = self.model.generate(
             pixel_values=inputs["pixel_values"].to("cuda"),
@@ -181,7 +207,7 @@ class EndpointHandler():
         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
         # By default, the generated text is cleanup and the entities are extracted.
-        processed_text, entities = processor.post_process_generation(generated_text)
+        processed_text, entities = self.processor.post_process_generation(generated_text)
 
         annotated_image = self.draw_entity_boxes_on_image(image_input, entities, show=False)
 
@@ -213,10 +239,10 @@ class EndpointHandler():
         colored_text.append((processed_text[end:len(processed_text)], None))
 
         return annotated_image, colored_text, str(filtered_entities)
-
+
     # helper to decode input image
     def decode_base64_image(self, image_string):
         base64_image = base64.b64decode(image_string)
         buffer = BytesIO(base64_image)
         image = Image.open(buffer)
-        return image
+        return image
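For context, the drawing loop above unpacks each entity as (entity_name, (start, end), bboxes) and scales the box coordinates by image_w/image_h, i.e. it expects the normalized-box format that Kosmos-2's post_process_generation returns. Below is a minimal sketch of calling the fixed draw_entity_boxes_on_image with a hand-built entity list; the entity names and coordinates are made up for illustration, and constructing EndpointHandler still downloads the checkpoint and expects a CUDA device, as in __init__ above.

from PIL import Image
from handler import EndpointHandler

# Loading the handler pulls the Kosmos-2 checkpoint and moves it to "cuda".
handler = EndpointHandler(path=".")

# Hypothetical entities in the (name, (start, end), [(x1, y1, x2, y2), ...]) layout
# consumed by the loop above; box coordinates are normalized to [0, 1].
entities = [
    ("a snowman", (0, 9), [(0.10, 0.15, 0.55, 0.90)]),
    ("a fire", (10, 16), [(0.60, 0.65, 0.85, 0.95)]),
]

image = Image.open("test.png")  # the test image added in this commit
annotated = handler.draw_entity_boxes_on_image(image, entities, save_path="annotated.png")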
test.png ADDED
Git LFS Details
test.py ADDED
@@ -0,0 +1,12 @@
+from handler import EndpointHandler
+from PIL import Image
+import base64
+
+# init handler
+my_handler = EndpointHandler(path=".")
+
+# prepare sample payload
+image = Image.open("test.png")
+payload = {"image": base64.b64encode(image)}
+
+pred=my_handler(payload)
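As committed, test.py passes a PIL Image object straight to base64.b64encode, which requires a bytes-like object, so the script raises a TypeError before it ever reaches the handler. Below is a sketch of a variant that encodes the PNG bytes first, assuming (as the script implies) that the handler's __call__ accepts a payload dict whose "image" value is the base64 string that decode_base64_image reverses.

import base64
from io import BytesIO

from PIL import Image
from handler import EndpointHandler

# init handler (downloads the checkpoint and expects a CUDA device, as in handler.py)
my_handler = EndpointHandler(path=".")

# serialize the test image to PNG bytes, then base64-encode those bytes
buffer = BytesIO()
Image.open("test.png").save(buffer, format="PNG")
payload = {"image": base64.b64encode(buffer.getvalue()).decode("utf-8")}

pred = my_handler(payload)
print(pred)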