ttengwang commited on
Commit
9a84ec8
β€’
1 Parent(s): 863eac9

clean up code, add langchain for chatbox

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +8 -0
  2. DejaVuSansCondensed-Bold.ttf +0 -0
  3. Image/demo1.svg +0 -0
  4. Image/demo2.svg +0 -0
  5. Image/title.svg +0 -1
  6. app.py +441 -311
  7. app_huggingface.py +0 -268
  8. app_old.py +0 -261
  9. app_wo_langchain.py +588 -0
  10. caas.py +0 -114
  11. caption_anything/__init__.py +0 -0
  12. {captioner β†’ caption_anything/captioner}/README.md +0 -0
  13. {captioner β†’ caption_anything/captioner}/__init__.py +0 -0
  14. {captioner β†’ caption_anything/captioner}/base_captioner.py +1 -1
  15. {captioner β†’ caption_anything/captioner}/blip.py +5 -5
  16. {captioner β†’ caption_anything/captioner}/blip2.py +4 -7
  17. {captioner β†’ caption_anything/captioner}/git.py +1 -1
  18. {captioner β†’ caption_anything/captioner}/modeling_blip.py +0 -0
  19. {captioner β†’ caption_anything/captioner}/modeling_git.py +0 -0
  20. {captioner β†’ caption_anything/captioner}/vit_pixel_masks_utils.py +0 -0
  21. caption_anything.py β†’ caption_anything/model.py +78 -63
  22. caption_anything/segmenter/__init__.py +5 -0
  23. {segmenter β†’ caption_anything/segmenter}/base_segmenter.py +66 -31
  24. {segmenter β†’ caption_anything/segmenter}/readme.md +0 -0
  25. {text_refiner β†’ caption_anything/text_refiner}/README.md +0 -0
  26. {text_refiner β†’ caption_anything/text_refiner}/__init__.py +1 -1
  27. {text_refiner β†’ caption_anything/text_refiner}/text_refiner.py +0 -0
  28. caption_anything/utils/chatbot.py +236 -0
  29. image_editing_utils.py β†’ caption_anything/utils/image_editing_utils.py +23 -11
  30. caption_anything/utils/parser.py +29 -0
  31. caption_anything/utils/utils.py +419 -0
  32. env.sh +0 -6
  33. segmenter/__init__.py +0 -5
  34. segmenter/images/truck.jpg +0 -0
  35. segmenter/sam_vit_h_4b8939.pth +0 -3
  36. test_img/img0.png +0 -0
  37. test_img/img1.jpg +0 -0
  38. test_img/img1.jpg.raw_mask.png +0 -0
  39. test_img/img10.jpg +0 -0
  40. test_img/img10.jpg.raw_mask.png +0 -0
  41. test_img/img11.jpg +0 -0
  42. test_img/img12.jpg +0 -0
  43. test_img/img12.jpg.raw_mask.png +0 -0
  44. test_img/img13.jpg +0 -0
  45. test_img/img13.jpg.raw_mask.png +0 -0
  46. test_img/img14.jpg +0 -0
  47. test_img/img14.jpg.raw_mask.png +0 -0
  48. test_img/img15.jpg +0 -0
  49. test_img/img15.jpg.raw_mask.png +0 -0
  50. test_img/img16.jpg +0 -0
.gitignore CHANGED
@@ -2,6 +2,14 @@ result/
2
  model_cache/
3
  *.pth
4
  teng_grad_start.sh
 
 
 
 
 
 
 
 
5
 
6
  # Byte-compiled / optimized / DLL files
7
  __pycache__/
 
2
  model_cache/
3
  *.pth
4
  teng_grad_start.sh
5
+ *.jpg
6
+ *.jpeg
7
+ *.png
8
+ *.svg
9
+ *.gif
10
+ *.tiff
11
+ *.webp
12
+
13
 
14
  # Byte-compiled / optimized / DLL files
15
  __pycache__/
DejaVuSansCondensed-Bold.ttf DELETED
Binary file (632 kB)
 
Image/demo1.svg DELETED
Image/demo2.svg DELETED
Image/title.svg DELETED
app.py CHANGED
@@ -1,85 +1,63 @@
1
- from io import BytesIO
2
- import string
3
- import gradio as gr
4
- import requests
5
- from caption_anything import CaptionAnything
6
- import torch
7
  import json
8
- import sys
9
- import argparse
10
- from caption_anything import parse_augment
11
  import numpy as np
12
- import PIL.ImageDraw as ImageDraw
13
- from image_editing_utils import create_bubble_frame
14
- import copy
15
- from tools import mask_painter
16
- from PIL import Image
17
- import os
18
- from captioner import build_captioner
 
 
 
 
 
 
19
  from segment_anything import sam_model_registry
20
- from text_refiner import build_text_refiner
21
- from segmenter import build_segmenter
22
-
23
-
24
- def download_checkpoint(url, folder, filename):
25
- os.makedirs(folder, exist_ok=True)
26
- filepath = os.path.join(folder, filename)
27
-
28
- if not os.path.exists(filepath):
29
- response = requests.get(url, stream=True)
30
- with open(filepath, "wb") as f:
31
- for chunk in response.iter_content(chunk_size=8192):
32
- if chunk:
33
- f.write(chunk)
34
-
35
- return filepath
36
-
37
-
38
- title = """<p><h1 align="center">Caption-Anything</h1></p>
39
- """
40
- description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
41
-
42
- examples = [
43
- ["test_img/img35.webp"],
44
- ["test_img/img2.jpg"],
45
- ["test_img/img5.jpg"],
46
- ["test_img/img12.jpg"],
47
- ["test_img/img14.jpg"],
48
- ["test_img/img0.png"],
49
- ["test_img/img1.jpg"],
50
- ]
51
-
52
- seg_model_map = {
53
- 'base': 'vit_b',
54
- 'large': 'vit_l',
55
- 'huge': 'vit_h'
56
- }
57
- ckpt_url_map = {
58
- 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
59
- 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
60
- 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
61
- }
62
- os.makedirs('result', exist_ok=True)
63
  args = parse_augment()
 
 
 
 
 
 
 
 
 
64
 
65
- checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
66
- folder = "segmenter"
67
- filename = os.path.basename(checkpoint_url)
68
- args.segmenter_checkpoint = os.path.join(folder, filename)
69
 
70
- download_checkpoint(checkpoint_url, folder, filename)
 
 
 
71
 
72
- # args.device = 'cuda:5'
73
- # args.disable_gpt = True
74
- # args.enable_reduce_tokens = False
75
- # args.port=20322
76
- # args.captioner = 'blip'
77
- # args.regular_box = True
78
- shared_captioner = build_captioner(args.captioner, args.device, args)
79
- shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=args.segmenter_checkpoint).to(args.device)
 
 
 
 
 
 
80
 
 
81
 
82
- def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
 
 
 
 
83
  segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
84
  captioner = captioner
85
  if session_id is not None:
@@ -89,17 +67,22 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
89
 
90
  def init_openai_api_key(api_key=""):
91
  text_refiner = None
 
92
  if api_key and len(api_key) > 30:
93
  try:
94
  text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
95
- text_refiner.llm('hi') # test
 
96
  except:
97
  text_refiner = None
 
98
  openai_available = text_refiner is not None
99
- return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
 
 
100
 
101
 
102
- def get_prompt(chat_input, click_state, click_mode):
103
  inputs = json.loads(chat_input)
104
  if click_mode == 'Continuous':
105
  points = click_state[0]
@@ -119,13 +102,14 @@ def get_prompt(chat_input, click_state, click_mode):
119
  raise NotImplementedError
120
 
121
  prompt = {
122
- "prompt_type":["click"],
123
- "input_point":click_state[0],
124
- "input_label":click_state[1],
125
- "multimask_output":"True",
126
  }
127
  return prompt
128
 
 
129
  def update_click_state(click_state, caption, click_mode):
130
  if click_mode == 'Continuous':
131
  click_state[2].append(caption)
@@ -134,280 +118,426 @@ def update_click_state(click_state, caption, click_mode):
134
  else:
135
  raise NotImplementedError
136
 
137
-
138
- def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
139
- if text_refiner is None:
 
 
140
  response = "Text refiner is not initilzed, please input openai api key."
141
  state = state + [(chat_input, response)]
142
- return state, state, chat_state
143
-
144
- points, labels, captions = click_state
145
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
146
- suffix = '\nHuman: {chat_input}\nAI: '
147
- qa_template = '\nHuman: {q}\nAI: {a}'
148
- # # "The image is of width {width} and height {height}."
149
- point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \n Now, let's chat!"
150
- prev_visual_context = ""
151
- pos_points = []
152
- pos_captions = []
153
- for i in range(len(points)):
154
- if labels[i] == 1:
155
- pos_points.append(f"({points[i][0]}, {points[i][0]})")
156
- pos_captions.append(captions[i])
157
- prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(pos_captions[-1], ', '.join(pos_points))
158
-
159
- context_length_thres = 500
160
- prev_history = ""
161
- for i in range(len(chat_state)):
162
- q, a = chat_state[i]
163
- if len(prev_history) < context_length_thres:
164
- prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
165
- else:
166
- break
167
- chat_prompt = point_chat_prompt.format(**{"img_caption":img_caption,"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
168
- print('\nchat_prompt: ', chat_prompt)
169
- response = text_refiner.llm(chat_prompt)
170
- state = state + [(chat_input, response)]
171
- chat_state = chat_state + [(chat_input, response)]
172
- return state, state, chat_state
173
-
174
- def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
175
- length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
176
 
 
 
 
 
 
 
 
 
 
 
 
177
  model = build_caption_anything_with_models(
178
  args,
179
  api_key="",
180
  captioner=shared_captioner,
181
  sam_model=shared_sam_model,
182
- text_refiner=text_refiner,
183
  session_id=iface.app_id
184
  )
185
-
186
- model.segmenter.image_embedding = image_embedding
187
- model.segmenter.predictor.original_size = original_size
188
- model.segmenter.predictor.input_size = input_size
189
- model.segmenter.predictor.is_image_set = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  if point_prompt == 'Positive':
192
- coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
193
  else:
194
- coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
 
 
 
 
195
 
196
  controls = {'length': length,
197
  'sentiment': sentiment,
198
  'factuality': factuality,
199
  'language': language}
200
 
201
- # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
202
- # chat_input = click_coordinate
203
- prompt = get_prompt(coordinate, click_state, click_mode)
204
- print('prompt: ', prompt, 'controls: ', controls)
205
- input_points = prompt['input_point']
206
- input_labels = prompt['input_label']
 
 
 
 
207
 
208
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
209
  out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
 
210
  state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
211
- # for k, v in out['generated_captions'].items():
212
- # state = state + [(f'{k}: {v}', None)]
213
  state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
214
  wiki = out['generated_captions'].get('wiki', "")
215
-
216
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
217
  text = out['generated_captions']['raw_caption']
218
- # draw = ImageDraw.Draw(image_input)
219
- # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
220
  input_mask = np.array(out['mask'].convert('P'))
221
  image_input = mask_painter(np.array(image_input), input_mask)
222
  origin_image_input = image_input
223
- image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]), input_mask, input_points=input_points, input_labels=input_labels)
224
-
225
- yield state, state, click_state, chat_input, image_input, wiki
 
 
 
 
 
 
 
 
226
  if not args.disable_gpt and model.text_refiner:
227
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
 
228
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
229
  new_cap = refined_caption['caption']
230
  wiki = refined_caption['wiki']
231
  state = state + [(None, f"caption: {new_cap}")]
232
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]), input_mask, input_points=input_points, input_labels=input_labels)
233
- yield state, state, click_state, chat_input, refined_image_input, wiki
 
 
234
 
235
 
236
- def upload_callback(image_input, state):
237
- chat_state = []
238
- click_state = [[], [], []]
239
- res = 1024
240
- width, height = image_input.size
241
- ratio = min(1.0 * res / max(width, height), 1.0)
242
- if ratio < 1.0:
243
- image_input = image_input.resize((int(width * ratio), int(height * ratio)))
244
- print('Scaling input image to {}'.format(image_input.size))
245
- state = [] + [(None, 'Image size: ' + str(image_input.size))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  model = build_caption_anything_with_models(
247
  args,
248
  api_key="",
249
  captioner=shared_captioner,
250
  sam_model=shared_sam_model,
 
251
  session_id=iface.app_id
252
  )
253
- model.segmenter.set_image(image_input)
254
- image_embedding = model.segmenter.image_embedding
255
- original_size = model.segmenter.predictor.original_size
256
- input_size = model.segmenter.predictor.input_size
257
- img_caption, _ = model.captioner.inference_seg(image_input)
258
- return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size, img_caption
259
-
260
- with gr.Blocks(
261
- css='''
262
- #image_upload{min-height:400px}
263
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
264
- '''
265
- ) as iface:
266
- state = gr.State([])
267
- click_state = gr.State([[],[],[]])
268
- chat_state = gr.State([])
269
- origin_image = gr.State(None)
270
- image_embedding = gr.State(None)
271
- text_refiner = gr.State(None)
272
- original_size = gr.State(None)
273
- input_size = gr.State(None)
274
- img_caption = gr.State(None)
275
-
276
- gr.Markdown(title)
277
- gr.Markdown(description)
278
-
279
- with gr.Row():
280
- with gr.Column(scale=1.0):
281
- with gr.Column(visible=False) as modules_not_need_gpt:
282
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
283
- example_image = gr.Image(type="pil", interactive=False, visible=False)
284
- with gr.Row(scale=1.0):
285
- with gr.Row(scale=0.4):
286
- point_prompt = gr.Radio(
287
- choices=["Positive", "Negative"],
288
- value="Positive",
289
- label="Point Prompt",
290
- interactive=True)
291
- click_mode = gr.Radio(
292
- choices=["Continuous", "Single"],
293
- value="Continuous",
294
- label="Clicking Mode",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  interactive=True)
296
- with gr.Row(scale=0.4):
297
- clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
298
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
299
- with gr.Column(visible=False) as modules_need_gpt:
300
- with gr.Row(scale=1.0):
301
- language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
302
- sentiment = gr.Radio(
303
- choices=["Positive", "Natural", "Negative"],
304
- value="Natural",
305
- label="Sentiment",
306
- interactive=True,
307
- )
308
- with gr.Row(scale=1.0):
309
- factuality = gr.Radio(
310
- choices=["Factual", "Imagination"],
311
- value="Factual",
312
- label="Factuality",
313
- interactive=True,
314
  )
315
- length = gr.Slider(
316
- minimum=10,
317
- maximum=80,
318
- value=10,
319
- step=1,
320
- interactive=True,
321
- label="Generated Caption Length",
322
- )
323
- enable_wiki = gr.Radio(
324
- choices=["Yes", "No"],
325
- value="No",
326
- label="Enable Wiki",
327
- interactive=True)
328
- with gr.Column(visible=True) as modules_not_need_gpt3:
329
- gr.Examples(
330
- examples=examples,
331
- inputs=[example_image],
332
- )
333
- with gr.Column(scale=0.5):
334
- openai_api_key = gr.Textbox(
335
- placeholder="Input openAI API key",
336
- show_label=False,
337
- label = "OpenAI API Key",
338
- lines=1,
339
- type="password")
340
- with gr.Row(scale=0.5):
341
- enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
342
- disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True, variant='primary')
343
- with gr.Column(visible=False) as modules_need_gpt2:
344
- wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
345
- with gr.Column(visible=False) as modules_not_need_gpt2:
346
- chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
347
- with gr.Column(visible=False) as modules_need_gpt3:
348
- chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(container=False)
349
- with gr.Row():
350
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
351
- submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
352
-
353
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
354
- enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
355
- disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
356
-
357
- clear_button_clike.click(
358
- lambda x: ([[], [], []], x, ""),
359
- [origin_image],
360
- [click_state, image_input, wiki_output],
361
- queue=False,
362
- show_progress=False
363
- )
364
- clear_button_image.click(
365
- lambda: (None, [], [], [], [[], [], []], "", "", ""),
366
- [],
367
- [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
368
- queue=False,
369
- show_progress=False
370
- )
371
- clear_button_text.click(
372
- lambda: ([], [], [[], [], [], []], []),
373
- [],
374
- [chatbot, state, click_state, chat_state],
375
- queue=False,
376
- show_progress=False
377
- )
378
- image_input.clear(
379
- lambda: (None, [], [], [], [[], [], []], "", "", ""),
380
- [],
381
- [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
382
- queue=False,
383
- show_progress=False
384
- )
385
-
386
- image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size, img_caption])
387
- chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption], [chatbot, state, chat_state])
388
- chat_input.submit(lambda: "", None, chat_input)
389
- example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size, img_caption])
390
-
391
- # select coordinate
392
- image_input.select(inference_seg_cap,
393
- inputs=[
394
- origin_image,
395
- point_prompt,
396
- click_mode,
397
- enable_wiki,
398
- language,
399
- sentiment,
400
- factuality,
401
- length,
402
- image_embedding,
403
- state,
404
- click_state,
405
- original_size,
406
- input_size,
407
- text_refiner
408
- ],
409
- outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
410
- show_progress=False, queue=True)
411
-
412
- iface.queue(concurrency_count=5, api_open=False, max_size=10)
413
- iface.launch(server_name="0.0.0.0", enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
 
 
 
2
  import json
3
+ import PIL
4
+ import gradio as gr
 
5
  import numpy as np
6
+ from gradio import processing_utils
7
+
8
+ from packaging import version
9
+ from PIL import Image, ImageDraw
10
+
11
+ from caption_anything.model import CaptionAnything
12
+ from caption_anything.utils.image_editing_utils import create_bubble_frame
13
+ from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
14
+ from caption_anything.utils.parser import parse_augment
15
+ from caption_anything.captioner import build_captioner
16
+ from caption_anything.text_refiner import build_text_refiner
17
+ from caption_anything.segmenter import build_segmenter
18
+ from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
19
  from segment_anything import sam_model_registry
20
+
21
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  args = parse_augment()
23
+ if args.segmenter_checkpoint is None:
24
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
25
+ else:
26
+ segmenter_checkpoint = args.segmenter_checkpoint
27
+
28
+ shared_captioner = build_captioner(args.captioner, args.device, args)
29
+ shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
30
+ tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
31
+ shared_chatbot_tools = build_chatbot_tools(tools_dict)
32
 
 
 
 
 
33
 
34
+ class ImageSketcher(gr.Image):
35
+ """
36
+ Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
37
+ """
38
 
39
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
40
+
41
+ def __init__(self, **kwargs):
42
+ super().__init__(tool="sketch", **kwargs)
43
+
44
+ def preprocess(self, x):
45
+ if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
46
+ assert isinstance(x, dict)
47
+ if x['mask'] is None:
48
+ decode_image = processing_utils.decode_base64_to_image(x['image'])
49
+ width, height = decode_image.size
50
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
51
+ mask[..., -1] = 255
52
+ mask = self.postprocess(mask)
53
 
54
+ x['mask'] = mask
55
 
56
+ return super().preprocess(x)
57
+
58
+
59
+ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
60
+ session_id=None):
61
  segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
62
  captioner = captioner
63
  if session_id is not None:
 
67
 
68
  def init_openai_api_key(api_key=""):
69
  text_refiner = None
70
+ visual_chatgpt = None
71
  if api_key and len(api_key) > 30:
72
  try:
73
  text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
74
+ text_refiner.llm('hi') # test
75
+ visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
76
  except:
77
  text_refiner = None
78
+ visual_chatgpt = None
79
  openai_available = text_refiner is not None
80
+ return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
81
+ visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
82
+ visible=True), text_refiner, visual_chatgpt
83
 
84
 
85
+ def get_click_prompt(chat_input, click_state, click_mode):
86
  inputs = json.loads(chat_input)
87
  if click_mode == 'Continuous':
88
  points = click_state[0]
 
102
  raise NotImplementedError
103
 
104
  prompt = {
105
+ "prompt_type": ["click"],
106
+ "input_point": click_state[0],
107
+ "input_label": click_state[1],
108
+ "multimask_output": "True",
109
  }
110
  return prompt
111
 
112
+
113
  def update_click_state(click_state, caption, click_mode):
114
  if click_mode == 'Continuous':
115
  click_state[2].append(caption)
 
118
  else:
119
  raise NotImplementedError
120
 
121
+ def chat_input_callback(*args):
122
+ visual_chatgpt, chat_input, click_state, state, aux_state = args
123
+ if visual_chatgpt is not None:
124
+ return visual_chatgpt.run_text(chat_input, state, aux_state)
125
+ else:
126
  response = "Text refiner is not initilzed, please input openai api key."
127
  state = state + [(chat_input, response)]
128
+ return state, state
129
+
130
+ def upload_callback(image_input, state, visual_chatgpt=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
133
+ image_input, mask = image_input['image'], image_input['mask']
134
+
135
+ click_state = [[], [], []]
136
+ res = 1024
137
+ width, height = image_input.size
138
+ ratio = min(1.0 * res / max(width, height), 1.0)
139
+ if ratio < 1.0:
140
+ image_input = image_input.resize((int(width * ratio), int(height * ratio)))
141
+ print('Scaling input image to {}'.format(image_input.size))
142
+
143
  model = build_caption_anything_with_models(
144
  args,
145
  api_key="",
146
  captioner=shared_captioner,
147
  sam_model=shared_sam_model,
 
148
  session_id=iface.app_id
149
  )
150
+ model.segmenter.set_image(image_input)
151
+ image_embedding = model.image_embedding
152
+ original_size = model.original_size
153
+ input_size = model.input_size
154
+
155
+ if visual_chatgpt is not None:
156
+ new_image_path = get_new_image_name('chat_image', func_name='upload')
157
+ image_input.save(new_image_path)
158
+ visual_chatgpt.current_image = new_image_path
159
+ img_caption, _ = model.captioner.inference_seg(image_input)
160
+ Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
161
+ AI_prompt = "Received."
162
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
163
+ state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
164
+
165
+ return state, state, image_input, click_state, image_input, image_input, image_embedding, \
166
+ original_size, input_size
167
+
168
+
169
+ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
170
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
171
+ evt: gr.SelectData):
172
+ click_index = evt.index
173
 
174
  if point_prompt == 'Positive':
175
+ coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
176
  else:
177
+ coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
178
+
179
+ prompt = get_click_prompt(coordinate, click_state, click_mode)
180
+ input_points = prompt['input_point']
181
+ input_labels = prompt['input_label']
182
 
183
  controls = {'length': length,
184
  'sentiment': sentiment,
185
  'factuality': factuality,
186
  'language': language}
187
 
188
+ model = build_caption_anything_with_models(
189
+ args,
190
+ api_key="",
191
+ captioner=shared_captioner,
192
+ sam_model=shared_sam_model,
193
+ text_refiner=text_refiner,
194
+ session_id=iface.app_id
195
+ )
196
+
197
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
198
 
199
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
200
  out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
201
+
202
  state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
 
 
203
  state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
204
  wiki = out['generated_captions'].get('wiki', "")
 
205
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
206
  text = out['generated_captions']['raw_caption']
 
 
207
  input_mask = np.array(out['mask'].convert('P'))
208
  image_input = mask_painter(np.array(image_input), input_mask)
209
  origin_image_input = image_input
210
+ image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
211
+ input_points=input_points, input_labels=input_labels)
212
+ x, y = input_points[-1]
213
+
214
+ if visual_chatgpt is not None:
215
+ new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
216
+ Image.open(out["crop_save_path"]).save(new_crop_save_path)
217
+ point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
218
+ visual_chatgpt.point_prompt = point_prompt
219
+
220
+ yield state, state, click_state, image_input, wiki
221
  if not args.disable_gpt and model.text_refiner:
222
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
223
+ enable_wiki=enable_wiki)
224
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
225
  new_cap = refined_caption['caption']
226
  wiki = refined_caption['wiki']
227
  state = state + [(None, f"caption: {new_cap}")]
228
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
229
+ input_mask,
230
+ input_points=input_points, input_labels=input_labels)
231
+ yield state, state, click_state, refined_image_input, wiki
232
 
233
 
234
+ def get_sketch_prompt(mask: PIL.Image.Image):
235
+ """
236
+ Get the prompt for the sketcher.
237
+ TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
238
+ """
239
+
240
+ mask = np.asarray(mask)[..., 0]
241
+
242
+ # Get the bounding box of the sketch
243
+ y, x = np.where(mask != 0)
244
+ x1, y1 = np.min(x), np.min(y)
245
+ x2, y2 = np.max(x), np.max(y)
246
+
247
+ prompt = {
248
+ 'prompt_type': ['box'],
249
+ 'input_boxes': [
250
+ [x1, y1, x2, y2]
251
+ ]
252
+ }
253
+
254
+ return prompt
255
+
256
+
257
+ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
258
+ original_size, input_size, text_refiner):
259
+ image_input, mask = sketcher_image['image'], sketcher_image['mask']
260
+
261
+ prompt = get_sketch_prompt(mask)
262
+ boxes = prompt['input_boxes']
263
+
264
+ controls = {'length': length,
265
+ 'sentiment': sentiment,
266
+ 'factuality': factuality,
267
+ 'language': language}
268
+
269
  model = build_caption_anything_with_models(
270
  args,
271
  api_key="",
272
  captioner=shared_captioner,
273
  sam_model=shared_sam_model,
274
+ text_refiner=text_refiner,
275
  session_id=iface.app_id
276
  )
277
+
278
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
279
+
280
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
281
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
282
+
283
+ # Update components and states
284
+ state.append((f'Box: {boxes}', None))
285
+ state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
286
+ wiki = out['generated_captions'].get('wiki', "")
287
+ text = out['generated_captions']['raw_caption']
288
+ input_mask = np.array(out['mask'].convert('P'))
289
+ image_input = mask_painter(np.array(image_input), input_mask)
290
+
291
+ origin_image_input = image_input
292
+
293
+ fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
294
+ image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
295
+
296
+ yield state, state, image_input, wiki
297
+
298
+ if not args.disable_gpt and model.text_refiner:
299
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
300
+ enable_wiki=enable_wiki)
301
+
302
+ new_cap = refined_caption['caption']
303
+ wiki = refined_caption['wiki']
304
+ state = state + [(None, f"caption: {new_cap}")]
305
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
306
+
307
+ yield state, state, refined_image_input, wiki
308
+
309
+ def clear_chat_memory(visual_chatgpt):
310
+ if visual_chatgpt is not None:
311
+ visual_chatgpt.memory.clear()
312
+ visual_chatgpt.current_image = None
313
+ visual_chatgpt.point_prompt = ""
314
+
315
+ def get_style():
316
+ current_version = version.parse(gr.__version__)
317
+ if current_version <= version.parse('3.24.1'):
318
+ style = '''
319
+ #image_sketcher{min-height:500px}
320
+ #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
321
+ #image_upload{min-height:500px}
322
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
323
+ '''
324
+ elif current_version <= version.parse('3.27'):
325
+ style = '''
326
+ #image_sketcher{min-height:500px}
327
+ #image_upload{min-height:500px}
328
+ '''
329
+ else:
330
+ style = None
331
+
332
+ return style
333
+
334
+
335
+ def create_ui():
336
+ title = """<p><h1 align="center">Caption-Anything</h1></p>
337
+ """
338
+ description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
339
+
340
+ examples = [
341
+ ["test_images/img35.webp"],
342
+ ["test_images/img2.jpg"],
343
+ ["test_images/img5.jpg"],
344
+ ["test_images/img12.jpg"],
345
+ ["test_images/img14.jpg"],
346
+ ["test_images/qingming3.jpeg"],
347
+ ["test_images/img1.jpg"],
348
+ ]
349
+
350
+ with gr.Blocks(
351
+ css=get_style()
352
+ ) as iface:
353
+ state = gr.State([])
354
+ click_state = gr.State([[], [], []])
355
+ # chat_state = gr.State([])
356
+ origin_image = gr.State(None)
357
+ image_embedding = gr.State(None)
358
+ text_refiner = gr.State(None)
359
+ visual_chatgpt = gr.State(None)
360
+ original_size = gr.State(None)
361
+ input_size = gr.State(None)
362
+ # img_caption = gr.State(None)
363
+ aux_state = gr.State([])
364
+
365
+ gr.Markdown(title)
366
+ gr.Markdown(description)
367
+
368
+ with gr.Row():
369
+ with gr.Column(scale=1.0):
370
+ with gr.Column(visible=False) as modules_not_need_gpt:
371
+ with gr.Tab("Click"):
372
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
373
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
374
+ with gr.Row(scale=1.0):
375
+ with gr.Row(scale=0.4):
376
+ point_prompt = gr.Radio(
377
+ choices=["Positive", "Negative"],
378
+ value="Positive",
379
+ label="Point Prompt",
380
+ interactive=True)
381
+ click_mode = gr.Radio(
382
+ choices=["Continuous", "Single"],
383
+ value="Continuous",
384
+ label="Clicking Mode",
385
+ interactive=True)
386
+ with gr.Row(scale=0.4):
387
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
388
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
389
+ with gr.Tab("Trajectory (beta)"):
390
+ sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
391
+ elem_id="image_sketcher")
392
+ with gr.Row():
393
+ submit_button_sketcher = gr.Button(value="Submit", interactive=True)
394
+
395
+ with gr.Column(visible=False) as modules_need_gpt:
396
+ with gr.Row(scale=1.0):
397
+ language = gr.Dropdown(
398
+ ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
399
+ value="English", label="Language", interactive=True)
400
+ sentiment = gr.Radio(
401
+ choices=["Positive", "Natural", "Negative"],
402
+ value="Natural",
403
+ label="Sentiment",
404
+ interactive=True,
405
+ )
406
+ with gr.Row(scale=1.0):
407
+ factuality = gr.Radio(
408
+ choices=["Factual", "Imagination"],
409
+ value="Factual",
410
+ label="Factuality",
411
+ interactive=True,
412
+ )
413
+ length = gr.Slider(
414
+ minimum=10,
415
+ maximum=80,
416
+ value=10,
417
+ step=1,
418
+ interactive=True,
419
+ label="Generated Caption Length",
420
+ )
421
+ enable_wiki = gr.Radio(
422
+ choices=["Yes", "No"],
423
+ value="No",
424
+ label="Enable Wiki",
425
  interactive=True)
426
+ with gr.Column(visible=True) as modules_not_need_gpt3:
427
+ gr.Examples(
428
+ examples=examples,
429
+ inputs=[example_image],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  )
431
+ with gr.Column(scale=0.5):
432
+ openai_api_key = gr.Textbox(
433
+ placeholder="Input openAI API key",
434
+ show_label=False,
435
+ label="OpenAI API Key",
436
+ lines=1,
437
+ type="password")
438
+ with gr.Row(scale=0.5):
439
+ enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
440
+ disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
441
+ variant='primary')
442
+ with gr.Column(visible=False) as modules_need_gpt2:
443
+ wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
444
+ with gr.Column(visible=False) as modules_not_need_gpt2:
445
+ chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
446
+ with gr.Column(visible=False) as modules_need_gpt3:
447
+ chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
448
+ container=False)
449
+ with gr.Row():
450
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
451
+ submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
452
+
453
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
454
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
455
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
456
+ enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
457
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
458
+ modules_not_need_gpt,
459
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
460
+ disable_chatGPT_button.click(init_openai_api_key,
461
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
462
+ modules_not_need_gpt,
463
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
464
+
465
+ clear_button_click.click(
466
+ lambda x: ([[], [], []], x, ""),
467
+ [origin_image],
468
+ [click_state, image_input, wiki_output],
469
+ queue=False,
470
+ show_progress=False
471
+ )
472
+ clear_button_image.click(
473
+ lambda: (None, [], [], [[], [], []], "", "", ""),
474
+ [],
475
+ [image_input, chatbot, state, click_state, wiki_output, origin_image],
476
+ queue=False,
477
+ show_progress=False
478
+ )
479
+ clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
480
+ clear_button_text.click(
481
+ lambda: ([], [], [[], [], [], []]),
482
+ [],
483
+ [chatbot, state, click_state],
484
+ queue=False,
485
+ show_progress=False
486
+ )
487
+ clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
488
+
489
+ image_input.clear(
490
+ lambda: (None, [], [], [[], [], []], "", "", ""),
491
+ [],
492
+ [image_input, chatbot, state, click_state, wiki_output, origin_image],
493
+ queue=False,
494
+ show_progress=False
495
+ )
496
+
497
+ image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
498
+
499
+
500
+ image_input.upload(upload_callback, [image_input, state, visual_chatgpt],
501
+ [chatbot, state, origin_image, click_state, image_input, sketcher_input,
502
+ image_embedding, original_size, input_size])
503
+ sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt],
504
+ [chatbot, state, origin_image, click_state, image_input, sketcher_input,
505
+ image_embedding, original_size, input_size])
506
+ chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
507
+ [chatbot, state, aux_state])
508
+ chat_input.submit(lambda: "", None, chat_input)
509
+ submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
510
+ [chatbot, state, aux_state])
511
+ submit_button_text.click(lambda: "", None, chat_input)
512
+ example_image.change(upload_callback, [example_image, state, visual_chatgpt],
513
+ [chatbot, state, origin_image, click_state, image_input, sketcher_input,
514
+ image_embedding, original_size, input_size])
515
+ example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
516
+ # select coordinate
517
+ image_input.select(
518
+ inference_click,
519
+ inputs=[
520
+ origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
521
+ image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
522
+ ],
523
+ outputs=[chatbot, state, click_state, image_input, wiki_output],
524
+ show_progress=False, queue=True
525
+ )
526
+
527
+ submit_button_sketcher.click(
528
+ inference_traject,
529
+ inputs=[
530
+ sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
531
+ original_size, input_size, text_refiner
532
+ ],
533
+ outputs=[chatbot, state, sketcher_input, wiki_output],
534
+ show_progress=False, queue=True
535
+ )
536
+
537
+ return iface
538
+
539
+
540
+ if __name__ == '__main__':
541
+ iface = create_ui()
542
+ iface.queue(concurrency_count=5, api_open=False, max_size=10)
543
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
app_huggingface.py DELETED
@@ -1,268 +0,0 @@
1
- from io import BytesIO
2
- import string
3
- import gradio as gr
4
- import requests
5
- from caption_anything import CaptionAnything
6
- import torch
7
- import json
8
- import sys
9
- import argparse
10
- from caption_anything import parse_augment
11
- import numpy as np
12
- import PIL.ImageDraw as ImageDraw
13
- from image_editing_utils import create_bubble_frame
14
- import copy
15
- from tools import mask_painter
16
- from PIL import Image
17
- import os
18
-
19
- def download_checkpoint(url, folder, filename):
20
- os.makedirs(folder, exist_ok=True)
21
- filepath = os.path.join(folder, filename)
22
-
23
- if not os.path.exists(filepath):
24
- response = requests.get(url, stream=True)
25
- with open(filepath, "wb") as f:
26
- for chunk in response.iter_content(chunk_size=8192):
27
- if chunk:
28
- f.write(chunk)
29
-
30
- return filepath
31
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
32
- folder = "segmenter"
33
- filename = "sam_vit_h_4b8939.pth"
34
-
35
- download_checkpoint(checkpoint_url, folder, filename)
36
-
37
-
38
- title = """<h1 align="center">Caption-Anything</h1>"""
39
- description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
40
- """
41
-
42
- examples = [
43
- ["test_img/img2.jpg"],
44
- ["test_img/img5.jpg"],
45
- ["test_img/img12.jpg"],
46
- ["test_img/img14.jpg"],
47
- ]
48
-
49
- args = parse_augment()
50
- args.captioner = 'blip2'
51
- args.seg_crop_mode = 'wo_bg'
52
- args.regular_box = True
53
- # args.device = 'cuda:5'
54
- # args.disable_gpt = False
55
- # args.enable_reduce_tokens = True
56
- # args.port=20322
57
- model = CaptionAnything(args)
58
-
59
- def init_openai_api_key(api_key):
60
- os.environ['OPENAI_API_KEY'] = api_key
61
- model.init_refiner()
62
-
63
-
64
- def get_prompt(chat_input, click_state):
65
- points = click_state[0]
66
- labels = click_state[1]
67
- inputs = json.loads(chat_input)
68
- for input in inputs:
69
- points.append(input[:2])
70
- labels.append(input[2])
71
-
72
- prompt = {
73
- "prompt_type":["click"],
74
- "input_point":points,
75
- "input_label":labels,
76
- "multimask_output":"True",
77
- }
78
- return prompt
79
-
80
- def chat_with_points(chat_input, click_state, state):
81
- if not hasattr(model, "text_refiner"):
82
- response = "Text refiner is not initilzed, please input openai api key."
83
- state = state + [(chat_input, response)]
84
- return state, state
85
-
86
- points, labels, captions = click_state
87
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
88
- # # "The image is of width {width} and height {height}."
89
- point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
90
- prev_visual_context = ""
91
- pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
92
- if len(captions):
93
- prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
94
- else:
95
- prev_visual_context = 'no point exists.'
96
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
97
- response = model.text_refiner.llm(chat_prompt)
98
- state = state + [(chat_input, response)]
99
- return state, state
100
-
101
- def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
102
-
103
- if point_prompt == 'Positive':
104
- coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
105
- else:
106
- coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
107
-
108
- controls = {'length': length,
109
- 'sentiment': sentiment,
110
- 'factuality': factuality,
111
- 'language': language}
112
-
113
- # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
114
- # chat_input = click_coordinate
115
- prompt = get_prompt(coordinate, click_state)
116
- print('prompt: ', prompt, 'controls: ', controls)
117
-
118
- out = model.inference(image_input, prompt, controls)
119
- state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
120
- # for k, v in out['generated_captions'].items():
121
- # state = state + [(f'{k}: {v}', None)]
122
- state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
123
- wiki = out['generated_captions'].get('wiki', "")
124
- click_state[2].append(out['generated_captions']['raw_caption'])
125
-
126
- text = out['generated_captions']['raw_caption']
127
- # draw = ImageDraw.Draw(image_input)
128
- # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
129
- input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
130
- image_input = mask_painter(np.array(image_input), input_mask)
131
- origin_image_input = image_input
132
- image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
133
-
134
- yield state, state, click_state, chat_input, image_input, wiki
135
- if not args.disable_gpt and hasattr(model, "text_refiner"):
136
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
137
- # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
138
- new_cap = refined_caption['caption']
139
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
140
- yield state, state, click_state, chat_input, refined_image_input, wiki
141
-
142
-
143
- def upload_callback(image_input, state):
144
- state = [] + [('Image size: ' + str(image_input.size), None)]
145
- click_state = [[], [], []]
146
- model.segmenter.image = None
147
- model.segmenter.image_embedding = None
148
- model.segmenter.set_image(image_input)
149
- return state, image_input, click_state
150
-
151
- with gr.Blocks(
152
- css='''
153
- #image_upload{min-height:400px}
154
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
155
- '''
156
- ) as iface:
157
- state = gr.State([])
158
- click_state = gr.State([[],[],[]])
159
- origin_image = gr.State(None)
160
-
161
- gr.Markdown(title)
162
- gr.Markdown(description)
163
-
164
- with gr.Row():
165
- with gr.Column(scale=1.0):
166
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
167
- with gr.Row(scale=1.0):
168
- point_prompt = gr.Radio(
169
- choices=["Positive", "Negative"],
170
- value="Positive",
171
- label="Point Prompt",
172
- interactive=True)
173
- clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
174
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
175
- with gr.Row(scale=1.0):
176
- language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
177
-
178
- sentiment = gr.Radio(
179
- choices=["Positive", "Natural", "Negative"],
180
- value="Natural",
181
- label="Sentiment",
182
- interactive=True,
183
- )
184
- with gr.Row(scale=1.0):
185
- factuality = gr.Radio(
186
- choices=["Factual", "Imagination"],
187
- value="Factual",
188
- label="Factuality",
189
- interactive=True,
190
- )
191
- length = gr.Slider(
192
- minimum=10,
193
- maximum=80,
194
- value=10,
195
- step=1,
196
- interactive=True,
197
- label="Length",
198
- )
199
-
200
- with gr.Column(scale=0.5):
201
- openai_api_key = gr.Textbox(
202
- placeholder="Input your openAI API key and press Enter",
203
- show_label=False,
204
- label = "OpenAI API Key",
205
- lines=1,
206
- type="password"
207
- )
208
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
209
- wiki_output = gr.Textbox(lines=6, label="Wiki")
210
- chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
211
- chat_input = gr.Textbox(lines=1, label="Chat Input")
212
- with gr.Row():
213
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
214
- submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
215
- clear_button_clike.click(
216
- lambda x: ([[], [], []], x, ""),
217
- [origin_image],
218
- [click_state, image_input, wiki_output],
219
- queue=False,
220
- show_progress=False
221
- )
222
- clear_button_image.click(
223
- lambda: (None, [], [], [[], [], []], ""),
224
- [],
225
- [image_input, chatbot, state, click_state, wiki_output],
226
- queue=False,
227
- show_progress=False
228
- )
229
- clear_button_text.click(
230
- lambda: ([], [], [[], [], []]),
231
- [],
232
- [chatbot, state, click_state],
233
- queue=False,
234
- show_progress=False
235
- )
236
- image_input.clear(
237
- lambda: (None, [], [], [[], [], []], ""),
238
- [],
239
- [image_input, chatbot, state, click_state, wiki_output],
240
- queue=False,
241
- show_progress=False
242
- )
243
-
244
- examples = gr.Examples(
245
- examples=examples,
246
- inputs=[image_input],
247
- )
248
-
249
- image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
250
- chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
251
-
252
- # select coordinate
253
- image_input.select(inference_seg_cap,
254
- inputs=[
255
- origin_image,
256
- point_prompt,
257
- language,
258
- sentiment,
259
- factuality,
260
- length,
261
- state,
262
- click_state
263
- ],
264
- outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
265
- show_progress=False, queue=True)
266
-
267
- iface.queue(concurrency_count=1, api_open=False, max_size=10)
268
- iface.launch(server_name="0.0.0.0", enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_old.py DELETED
@@ -1,261 +0,0 @@
1
- from io import BytesIO
2
- import string
3
- import gradio as gr
4
- import requests
5
- from caption_anything import CaptionAnything
6
- import torch
7
- import json
8
- import sys
9
- import argparse
10
- from caption_anything import parse_augment
11
- import os
12
-
13
- # download sam checkpoint if not downloaded
14
- def download_checkpoint(url, folder, filename):
15
- os.makedirs(folder, exist_ok=True)
16
- filepath = os.path.join(folder, filename)
17
-
18
- if not os.path.exists(filepath):
19
- response = requests.get(url, stream=True)
20
- with open(filepath, "wb") as f:
21
- for chunk in response.iter_content(chunk_size=8192):
22
- if chunk:
23
- f.write(chunk)
24
-
25
- return filepath
26
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
27
- folder = "segmenter"
28
- filename = "sam_vit_h_4b8939.pth"
29
-
30
- title = """<h1 align="center">Caption-Anything</h1>"""
31
- description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
32
- <br> <strong>Code</strong>: GitHub repo: <a href='https://github.com/ttengwang/Caption-Anything' target='_blank'></a>
33
- """
34
-
35
- examples = [
36
- ["test_img/img2.jpg", "[[1000, 700, 1]]"]
37
- ]
38
-
39
- args = parse_augment()
40
-
41
- def get_prompt(chat_input, click_state):
42
- points = click_state[0]
43
- labels = click_state[1]
44
- inputs = json.loads(chat_input)
45
- for input in inputs:
46
- points.append(input[:2])
47
- labels.append(input[2])
48
-
49
- prompt = {
50
- "prompt_type":["click"],
51
- "input_point":points,
52
- "input_label":labels,
53
- "multimask_output":"True",
54
- }
55
- return prompt
56
-
57
- def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
58
- controls = {'length': length,
59
- 'sentiment': sentiment,
60
- 'factuality': factuality,
61
- 'language': language}
62
- prompt = get_prompt(chat_input, click_state)
63
- print('prompt: ', prompt, 'controls: ', controls)
64
- out = model.inference(image_input, prompt, controls)
65
- state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
66
- for k, v in out['generated_captions'].items():
67
- state = state + [(f'{k}: {v}', None)]
68
- click_state[2].append(out['generated_captions']['raw_caption'])
69
- image_output_mask = out['mask_save_path']
70
- image_output_crop = out['crop_save_path']
71
- return state, state, click_state, image_output_mask, image_output_crop
72
-
73
-
74
- def upload_callback(image_input, state):
75
- state = state + [('Image size: ' + str(image_input.size), None)]
76
- return state
77
-
78
- # get coordinate in format [[x,y,positive/negative]]
79
- def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
80
- print("point_prompt: ", point_prompt)
81
- if point_prompt == 'Positive Point':
82
- coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
83
- else:
84
- coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
85
- return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
86
-
87
- def chat_with_points(chat_input, click_state, state):
88
- points, labels, captions = click_state
89
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
90
- # "The image is of width {width} and height {height}."
91
- point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
92
- prev_visual_context = ""
93
- pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
94
- prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
95
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
96
- response = model.text_refiner.llm(chat_prompt)
97
- state = state + [(chat_input, response)]
98
- return state, state
99
-
100
- def init_openai_api_key(api_key):
101
- # os.environ['OPENAI_API_KEY'] = api_key
102
- global model
103
- model = CaptionAnything(args, api_key)
104
-
105
- css='''
106
- #image_upload{min-height:200px}
107
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
108
- '''
109
-
110
- with gr.Blocks(css=css) as iface:
111
- state = gr.State([])
112
- click_state = gr.State([[],[],[]])
113
- caption_state = gr.State([[]])
114
- gr.Markdown(title)
115
- gr.Markdown(description)
116
-
117
- with gr.Column():
118
- openai_api_key = gr.Textbox(
119
- placeholder="Input your openAI API key and press Enter",
120
- show_label=False,
121
- lines=1,
122
- type="password",
123
- )
124
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
125
-
126
- with gr.Row():
127
- with gr.Column(scale=0.7):
128
- image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0)
129
-
130
- with gr.Row(scale=0.7):
131
- point_prompt = gr.Radio(
132
- choices=["Positive Point", "Negative Point"],
133
- value="Positive Point",
134
- label="Points",
135
- interactive=True,
136
- )
137
-
138
- # with gr.Row():
139
- language = gr.Radio(
140
- choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"],
141
- value="English",
142
- label="Language",
143
- interactive=True,
144
- )
145
- sentiment = gr.Radio(
146
- choices=["Positive", "Natural", "Negative"],
147
- value="Natural",
148
- label="Sentiment",
149
- interactive=True,
150
- )
151
- factuality = gr.Radio(
152
- choices=["Factual", "Imagination"],
153
- value="Factual",
154
- label="Factuality",
155
- interactive=True,
156
- )
157
- length = gr.Slider(
158
- minimum=5,
159
- maximum=100,
160
- value=10,
161
- step=1,
162
- interactive=True,
163
- label="Length",
164
- )
165
-
166
- with gr.Column(scale=1.5):
167
- with gr.Row():
168
- image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0)
169
- image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0)
170
- chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5)
171
-
172
- with gr.Row():
173
- with gr.Column(scale=0.7):
174
- prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])")
175
- prompt_input.submit(
176
- inference_seg_cap,
177
- [
178
- image_input,
179
- prompt_input,
180
- language,
181
- sentiment,
182
- factuality,
183
- length,
184
- state,
185
- click_state
186
- ],
187
- [chatbot, state, click_state, image_output_mask, image_output_crop],
188
- show_progress=False
189
- )
190
-
191
- image_input.upload(
192
- upload_callback,
193
- [image_input, state],
194
- [chatbot]
195
- )
196
-
197
- with gr.Row():
198
- clear_button = gr.Button(value="Clear Click", interactive=True)
199
- clear_button.click(
200
- lambda: ("", [[], [], []], None, None),
201
- [],
202
- [prompt_input, click_state, image_output_mask, image_output_crop],
203
- queue=False,
204
- show_progress=False
205
- )
206
-
207
- clear_button = gr.Button(value="Clear", interactive=True)
208
- clear_button.click(
209
- lambda: ("", [], [], [[], [], []], None, None),
210
- [],
211
- [prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
212
- queue=False,
213
- show_progress=False
214
- )
215
-
216
- submit_button = gr.Button(
217
- value="Submit", interactive=True, variant="primary"
218
- )
219
- submit_button.click(
220
- inference_seg_cap,
221
- [
222
- image_input,
223
- prompt_input,
224
- language,
225
- sentiment,
226
- factuality,
227
- length,
228
- state,
229
- click_state
230
- ],
231
- [chatbot, state, click_state, image_output_mask, image_output_crop],
232
- show_progress=False
233
- )
234
-
235
- # select coordinate
236
- image_input.select(
237
- get_select_coords,
238
- inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state],
239
- outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
240
- show_progress=False
241
- )
242
-
243
- image_input.change(
244
- lambda: ("", [], [[], [], []]),
245
- [],
246
- [chatbot, state, click_state],
247
- queue=False,
248
- )
249
-
250
- with gr.Column(scale=1.5):
251
- chat_input = gr.Textbox(lines=1, label="Chat Input")
252
- chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
253
-
254
-
255
- examples = gr.Examples(
256
- examples=examples,
257
- inputs=[image_input, prompt_input],
258
- )
259
-
260
- iface.queue(concurrency_count=1, api_open=False, max_size=10)
261
- iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_wo_langchain.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import List
4
+
5
+ import PIL
6
+ import gradio as gr
7
+ import numpy as np
8
+ from gradio import processing_utils
9
+
10
+ from packaging import version
11
+ from PIL import Image, ImageDraw
12
+
13
+ from caption_anything.model import CaptionAnything
14
+ from caption_anything.utils.image_editing_utils import create_bubble_frame
15
+ from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
16
+ from caption_anything.utils.parser import parse_augment
17
+ from caption_anything.captioner import build_captioner
18
+ from caption_anything.text_refiner import build_text_refiner
19
+ from caption_anything.segmenter import build_segmenter
20
+ from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
21
+ from segment_anything import sam_model_registry
22
+
23
+
24
+ args = parse_augment()
25
+
26
+ args = parse_augment()
27
+ if args.segmenter_checkpoint is None:
28
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
29
+ else:
30
+ segmenter_checkpoint = args.segmenter_checkpoint
31
+
32
+ shared_captioner = build_captioner(args.captioner, args.device, args)
33
+ shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
34
+
35
+
36
+ class ImageSketcher(gr.Image):
37
+ """
38
+ Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
39
+ """
40
+
41
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
42
+
43
+ def __init__(self, **kwargs):
44
+ super().__init__(tool="sketch", **kwargs)
45
+
46
+ def preprocess(self, x):
47
+ if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
48
+ assert isinstance(x, dict)
49
+ if x['mask'] is None:
50
+ decode_image = processing_utils.decode_base64_to_image(x['image'])
51
+ width, height = decode_image.size
52
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
53
+ mask[..., -1] = 255
54
+ mask = self.postprocess(mask)
55
+
56
+ x['mask'] = mask
57
+
58
+ return super().preprocess(x)
59
+
60
+
61
+ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
62
+ session_id=None):
63
+ segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
64
+ captioner = captioner
65
+ if session_id is not None:
66
+ print('Init caption anything for session {}'.format(session_id))
67
+ return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
68
+
69
+
70
+ def init_openai_api_key(api_key=""):
71
+ text_refiner = None
72
+ if api_key and len(api_key) > 30:
73
+ try:
74
+ text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
75
+ text_refiner.llm('hi') # test
76
+ except:
77
+ text_refiner = None
78
+ openai_available = text_refiner is not None
79
+ return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
80
+ visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
81
+ visible=True), text_refiner
82
+
83
+
84
+ def get_click_prompt(chat_input, click_state, click_mode):
85
+ inputs = json.loads(chat_input)
86
+ if click_mode == 'Continuous':
87
+ points = click_state[0]
88
+ labels = click_state[1]
89
+ for input in inputs:
90
+ points.append(input[:2])
91
+ labels.append(input[2])
92
+ elif click_mode == 'Single':
93
+ points = []
94
+ labels = []
95
+ for input in inputs:
96
+ points.append(input[:2])
97
+ labels.append(input[2])
98
+ click_state[0] = points
99
+ click_state[1] = labels
100
+ else:
101
+ raise NotImplementedError
102
+
103
+ prompt = {
104
+ "prompt_type": ["click"],
105
+ "input_point": click_state[0],
106
+ "input_label": click_state[1],
107
+ "multimask_output": "True",
108
+ }
109
+ return prompt
110
+
111
+
112
+ def update_click_state(click_state, caption, click_mode):
113
+ if click_mode == 'Continuous':
114
+ click_state[2].append(caption)
115
+ elif click_mode == 'Single':
116
+ click_state[2] = [caption]
117
+ else:
118
+ raise NotImplementedError
119
+
120
+
121
+ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
122
+ if text_refiner is None:
123
+ response = "Text refiner is not initilzed, please input openai api key."
124
+ state = state + [(chat_input, response)]
125
+ return state, state, chat_state
126
+
127
+ points, labels, captions = click_state
128
+ # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
129
+ suffix = '\nHuman: {chat_input}\nAI: '
130
+ qa_template = '\nHuman: {q}\nAI: {a}'
131
+ # # "The image is of width {width} and height {height}."
132
+ point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \nYou are required to use language instead of number to describe these positions. Now, let's chat!"
133
+ prev_visual_context = ""
134
+ pos_points = []
135
+ pos_captions = []
136
+
137
+ for i in range(len(points)):
138
+ if labels[i] == 1:
139
+ pos_points.append(f"(X:{points[i][0]}, Y:{points[i][1]})")
140
+ pos_captions.append(captions[i])
141
+ prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(
142
+ pos_captions[-1], ', '.join(pos_points))
143
+
144
+ context_length_thres = 500
145
+ prev_history = ""
146
+ for i in range(len(chat_state)):
147
+ q, a = chat_state[i]
148
+ if len(prev_history) < context_length_thres:
149
+ prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
150
+ else:
151
+ break
152
+ chat_prompt = point_chat_prompt.format(
153
+ **{"img_caption": img_caption, "points_with_caps": prev_visual_context}) + prev_history + suffix.format(
154
+ **{"chat_input": chat_input})
155
+ print('\nchat_prompt: ', chat_prompt)
156
+ response = text_refiner.llm(chat_prompt)
157
+ state = state + [(chat_input, response)]
158
+ chat_state = chat_state + [(chat_input, response)]
159
+ return state, state, chat_state
160
+
161
+
162
+ def upload_callback(image_input, state):
163
+ if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
164
+ image_input, mask = image_input['image'], image_input['mask']
165
+
166
+ chat_state = []
167
+ click_state = [[], [], []]
168
+ res = 1024
169
+ width, height = image_input.size
170
+ ratio = min(1.0 * res / max(width, height), 1.0)
171
+ if ratio < 1.0:
172
+ image_input = image_input.resize((int(width * ratio), int(height * ratio)))
173
+ print('Scaling input image to {}'.format(image_input.size))
174
+ state = [] + [(None, 'Image size: ' + str(image_input.size))]
175
+ model = build_caption_anything_with_models(
176
+ args,
177
+ api_key="",
178
+ captioner=shared_captioner,
179
+ sam_model=shared_sam_model,
180
+ session_id=iface.app_id
181
+ )
182
+ model.segmenter.set_image(image_input)
183
+ image_embedding = model.image_embedding
184
+ original_size = model.original_size
185
+ input_size = model.input_size
186
+ img_caption, _ = model.captioner.inference_seg(image_input)
187
+
188
+ return state, state, chat_state, image_input, click_state, image_input, image_input, image_embedding, \
189
+ original_size, input_size, img_caption
190
+
191
+
192
+ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
193
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner,
194
+ evt: gr.SelectData):
195
+ click_index = evt.index
196
+
197
+ if point_prompt == 'Positive':
198
+ coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
199
+ else:
200
+ coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
201
+
202
+ prompt = get_click_prompt(coordinate, click_state, click_mode)
203
+ input_points = prompt['input_point']
204
+ input_labels = prompt['input_label']
205
+
206
+ controls = {'length': length,
207
+ 'sentiment': sentiment,
208
+ 'factuality': factuality,
209
+ 'language': language}
210
+
211
+ model = build_caption_anything_with_models(
212
+ args,
213
+ api_key="",
214
+ captioner=shared_captioner,
215
+ sam_model=shared_sam_model,
216
+ text_refiner=text_refiner,
217
+ session_id=iface.app_id
218
+ )
219
+
220
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
221
+
222
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
223
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
224
+
225
+ state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
226
+ state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
227
+ wiki = out['generated_captions'].get('wiki', "")
228
+ update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
229
+ text = out['generated_captions']['raw_caption']
230
+ input_mask = np.array(out['mask'].convert('P'))
231
+ image_input = mask_painter(np.array(image_input), input_mask)
232
+ origin_image_input = image_input
233
+ image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
234
+ input_points=input_points, input_labels=input_labels)
235
+ yield state, state, click_state, image_input, wiki
236
+ if not args.disable_gpt and model.text_refiner:
237
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
238
+ enable_wiki=enable_wiki)
239
+ # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
240
+ new_cap = refined_caption['caption']
241
+ wiki = refined_caption['wiki']
242
+ state = state + [(None, f"caption: {new_cap}")]
243
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
244
+ input_mask,
245
+ input_points=input_points, input_labels=input_labels)
246
+ yield state, state, click_state, refined_image_input, wiki
247
+
248
+
249
+ def get_sketch_prompt(mask: PIL.Image.Image, multi_mask=True):
250
+ """
251
+ Get the prompt for the sketcher.
252
+ TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
253
+ """
254
+
255
+ mask = np.array(np.asarray(mask)[..., 0])
256
+ mask[mask > 0] = 1 # Refine the mask, let all nonzero values be 1
257
+
258
+ if not multi_mask:
259
+ y, x = np.where(mask == 1)
260
+ x1, y1 = np.min(x), np.min(y)
261
+ x2, y2 = np.max(x), np.max(y)
262
+
263
+ prompt = {
264
+ 'prompt_type': ['box'],
265
+ 'input_boxes': [
266
+ [x1, y1, x2, y2]
267
+ ]
268
+ }
269
+
270
+ return prompt
271
+
272
+ traversed = np.zeros_like(mask)
273
+ groups = np.zeros_like(mask)
274
+ max_group_id = 1
275
+
276
+ # Iterate over all pixels
277
+ for x in range(mask.shape[0]):
278
+ for y in range(mask.shape[1]):
279
+ if traversed[x, y] == 1:
280
+ continue
281
+
282
+ if mask[x, y] == 0:
283
+ traversed[x, y] = 1
284
+ else:
285
+ # If pixel is part of mask
286
+ groups[x, y] = max_group_id
287
+ stack = [(x, y)]
288
+ while stack:
289
+ i, j = stack.pop()
290
+ if traversed[i, j] == 1:
291
+ continue
292
+ traversed[i, j] = 1
293
+ if mask[i, j] == 1:
294
+ groups[i, j] = max_group_id
295
+ for di, dj in [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1)]:
296
+ ni, nj = i + di, j + dj
297
+ traversed[i, j] = 1
298
+ if 0 <= nj < mask.shape[1] and mask.shape[0] > ni >= 0 == traversed[ni, nj]:
299
+ stack.append((i + di, j + dj))
300
+ max_group_id += 1
301
+
302
+ # get the bounding box of each group
303
+ boxes = []
304
+ for group in range(1, max_group_id):
305
+ y, x = np.where(groups == group)
306
+ x1, y1 = np.min(x), np.min(y)
307
+ x2, y2 = np.max(x), np.max(y)
308
+ boxes.append([x1, y1, x2, y2])
309
+
310
+ prompt = {
311
+ 'prompt_type': ['box'],
312
+ 'input_boxes': boxes
313
+ }
314
+
315
+ return prompt
316
+
317
+
318
+ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
319
+ original_size, input_size, text_refiner):
320
+ image_input, mask = sketcher_image['image'], sketcher_image['mask']
321
+
322
+ prompt = get_sketch_prompt(mask, multi_mask=False)
323
+ boxes = prompt['input_boxes']
324
+
325
+ controls = {'length': length,
326
+ 'sentiment': sentiment,
327
+ 'factuality': factuality,
328
+ 'language': language}
329
+
330
+ model = build_caption_anything_with_models(
331
+ args,
332
+ api_key="",
333
+ captioner=shared_captioner,
334
+ sam_model=shared_sam_model,
335
+ text_refiner=text_refiner,
336
+ session_id=iface.app_id
337
+ )
338
+
339
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
340
+
341
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
342
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
343
+
344
+ # Update components and states
345
+ state.append((f'Box: {boxes}', None))
346
+ state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
347
+ wiki = out['generated_captions'].get('wiki', "")
348
+ text = out['generated_captions']['raw_caption']
349
+ input_mask = np.array(out['mask'].convert('P'))
350
+ image_input = mask_painter(np.array(image_input), input_mask)
351
+
352
+ origin_image_input = image_input
353
+
354
+ fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
355
+ image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
356
+
357
+ yield state, state, image_input, wiki
358
+
359
+ if not args.disable_gpt and model.text_refiner:
360
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
361
+ enable_wiki=enable_wiki)
362
+
363
+ new_cap = refined_caption['caption']
364
+ wiki = refined_caption['wiki']
365
+ state = state + [(None, f"caption: {new_cap}")]
366
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
367
+
368
+ yield state, state, refined_image_input, wiki
369
+
370
+
371
+ def get_style():
372
+ current_version = version.parse(gr.__version__)
373
+ if current_version <= version.parse('3.24.1'):
374
+ style = '''
375
+ #image_sketcher{min-height:500px}
376
+ #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
377
+ #image_upload{min-height:500px}
378
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
379
+ '''
380
+ elif current_version <= version.parse('3.27'):
381
+ style = '''
382
+ #image_sketcher{min-height:500px}
383
+ #image_upload{min-height:500px}
384
+ '''
385
+ else:
386
+ style = None
387
+
388
+ return style
389
+
390
+
391
+ def create_ui():
392
+ title = """<p><h1 align="center">Caption-Anything</h1></p>
393
+ """
394
+ description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
395
+
396
+ examples = [
397
+ ["test_images/img35.webp"],
398
+ ["test_images/img2.jpg"],
399
+ ["test_images/img5.jpg"],
400
+ ["test_images/img12.jpg"],
401
+ ["test_images/img14.jpg"],
402
+ ["test_images/qingming3.jpeg"],
403
+ ["test_images/img1.jpg"],
404
+ ]
405
+
406
+ with gr.Blocks(
407
+ css=get_style()
408
+ ) as iface:
409
+ state = gr.State([])
410
+ click_state = gr.State([[], [], []])
411
+ chat_state = gr.State([])
412
+ origin_image = gr.State(None)
413
+ image_embedding = gr.State(None)
414
+ text_refiner = gr.State(None)
415
+ original_size = gr.State(None)
416
+ input_size = gr.State(None)
417
+ img_caption = gr.State(None)
418
+
419
+ gr.Markdown(title)
420
+ gr.Markdown(description)
421
+
422
+ with gr.Row():
423
+ with gr.Column(scale=1.0):
424
+ with gr.Column(visible=False) as modules_not_need_gpt:
425
+ with gr.Tab("Click"):
426
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
427
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
428
+ with gr.Row(scale=1.0):
429
+ with gr.Row(scale=0.4):
430
+ point_prompt = gr.Radio(
431
+ choices=["Positive", "Negative"],
432
+ value="Positive",
433
+ label="Point Prompt",
434
+ interactive=True)
435
+ click_mode = gr.Radio(
436
+ choices=["Continuous", "Single"],
437
+ value="Continuous",
438
+ label="Clicking Mode",
439
+ interactive=True)
440
+ with gr.Row(scale=0.4):
441
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
442
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
443
+ with gr.Tab("Trajectory (Beta)"):
444
+ sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
445
+ elem_id="image_sketcher")
446
+ with gr.Row():
447
+ submit_button_sketcher = gr.Button(value="Submit", interactive=True)
448
+
449
+ with gr.Column(visible=False) as modules_need_gpt:
450
+ with gr.Row(scale=1.0):
451
+ language = gr.Dropdown(
452
+ ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
453
+ value="English", label="Language", interactive=True)
454
+ sentiment = gr.Radio(
455
+ choices=["Positive", "Natural", "Negative"],
456
+ value="Natural",
457
+ label="Sentiment",
458
+ interactive=True,
459
+ )
460
+ with gr.Row(scale=1.0):
461
+ factuality = gr.Radio(
462
+ choices=["Factual", "Imagination"],
463
+ value="Factual",
464
+ label="Factuality",
465
+ interactive=True,
466
+ )
467
+ length = gr.Slider(
468
+ minimum=10,
469
+ maximum=80,
470
+ value=10,
471
+ step=1,
472
+ interactive=True,
473
+ label="Generated Caption Length",
474
+ )
475
+ enable_wiki = gr.Radio(
476
+ choices=["Yes", "No"],
477
+ value="No",
478
+ label="Enable Wiki",
479
+ interactive=True)
480
+ with gr.Column(visible=True) as modules_not_need_gpt3:
481
+ gr.Examples(
482
+ examples=examples,
483
+ inputs=[example_image],
484
+ )
485
+ with gr.Column(scale=0.5):
486
+ openai_api_key = gr.Textbox(
487
+ placeholder="Input openAI API key",
488
+ show_label=False,
489
+ label="OpenAI API Key",
490
+ lines=1,
491
+ type="password")
492
+ with gr.Row(scale=0.5):
493
+ enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
494
+ disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
495
+ variant='primary')
496
+ with gr.Column(visible=False) as modules_need_gpt2:
497
+ wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
498
+ with gr.Column(visible=False) as modules_not_need_gpt2:
499
+ chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
500
+ with gr.Column(visible=False) as modules_need_gpt3:
501
+ chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
502
+ container=False)
503
+ with gr.Row():
504
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
505
+ submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
506
+
507
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
508
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
509
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
510
+ enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
511
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
512
+ modules_not_need_gpt,
513
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
514
+ disable_chatGPT_button.click(init_openai_api_key,
515
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
516
+ modules_not_need_gpt,
517
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
518
+
519
+ clear_button_click.click(
520
+ lambda x: ([[], [], []], x, ""),
521
+ [origin_image],
522
+ [click_state, image_input, wiki_output],
523
+ queue=False,
524
+ show_progress=False
525
+ )
526
+ clear_button_image.click(
527
+ lambda: (None, [], [], [], [[], [], []], "", "", ""),
528
+ [],
529
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
530
+ queue=False,
531
+ show_progress=False
532
+ )
533
+ clear_button_text.click(
534
+ lambda: ([], [], [[], [], [], []], []),
535
+ [],
536
+ [chatbot, state, click_state, chat_state],
537
+ queue=False,
538
+ show_progress=False
539
+ )
540
+ image_input.clear(
541
+ lambda: (None, [], [], [], [[], [], []], "", "", ""),
542
+ [],
543
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
544
+ queue=False,
545
+ show_progress=False
546
+ )
547
+
548
+ image_input.upload(upload_callback, [image_input, state],
549
+ [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
550
+ image_embedding, original_size, input_size, img_caption])
551
+ sketcher_input.upload(upload_callback, [sketcher_input, state],
552
+ [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
553
+ image_embedding, original_size, input_size, img_caption])
554
+ chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption],
555
+ [chatbot, state, chat_state])
556
+ chat_input.submit(lambda: "", None, chat_input)
557
+ example_image.change(upload_callback, [example_image, state],
558
+ [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
559
+ image_embedding, original_size, input_size, img_caption])
560
+
561
+ # select coordinate
562
+ image_input.select(
563
+ inference_click,
564
+ inputs=[
565
+ origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
566
+ image_embedding, state, click_state, original_size, input_size, text_refiner
567
+ ],
568
+ outputs=[chatbot, state, click_state, image_input, wiki_output],
569
+ show_progress=False, queue=True
570
+ )
571
+
572
+ submit_button_sketcher.click(
573
+ inference_traject,
574
+ inputs=[
575
+ sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
576
+ original_size, input_size, text_refiner
577
+ ],
578
+ outputs=[chatbot, state, sketcher_input, wiki_output],
579
+ show_progress=False, queue=True
580
+ )
581
+
582
+ return iface
583
+
584
+
585
+ if __name__ == '__main__':
586
+ iface = create_ui()
587
+ iface.queue(concurrency_count=5, api_open=False, max_size=10)
588
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
caas.py DELETED
@@ -1,114 +0,0 @@
1
- from captioner import build_captioner, BaseCaptioner
2
- from segmenter import build_segmenter
3
- from text_refiner import build_text_refiner
4
- import os
5
- import argparse
6
- import pdb
7
- import time
8
- from PIL import Image
9
-
10
- class CaptionAnything():
11
- def __init__(self, args):
12
- self.args = args
13
- self.captioner = build_captioner(args.captioner, args.device, args)
14
- self.segmenter = build_segmenter(args.segmenter, args.device, args)
15
- if not args.disable_gpt:
16
- self.init_refiner()
17
-
18
-
19
- def init_refiner(self):
20
- if os.environ.get('OPENAI_API_KEY', None):
21
- self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
22
-
23
- def inference(self, image, prompt, controls, disable_gpt=False):
24
- # segment with prompt
25
- print("CA prompt: ", prompt, "CA controls",controls)
26
- seg_mask = self.segmenter.inference(image, prompt)[0, ...]
27
- mask_save_path = f'result/mask_{time.time()}.png'
28
- if not os.path.exists(os.path.dirname(mask_save_path)):
29
- os.makedirs(os.path.dirname(mask_save_path))
30
- new_p = Image.fromarray(seg_mask.astype('int') * 255.)
31
- if new_p.mode != 'RGB':
32
- new_p = new_p.convert('RGB')
33
- new_p.save(mask_save_path)
34
- print('seg_mask path: ', mask_save_path)
35
- print("seg_mask.shape: ", seg_mask.shape)
36
- # captioning with mask
37
- if self.args.enable_reduce_tokens:
38
- caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
39
- else:
40
- caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
41
- # refining with TextRefiner
42
- context_captions = []
43
- if self.args.context_captions:
44
- context_captions.append(self.captioner.inference(image))
45
- if not disable_gpt and hasattr(self, "text_refiner"):
46
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
47
- else:
48
- refined_caption = {'raw_caption': caption}
49
- out = {'generated_captions': refined_caption,
50
- 'crop_save_path': crop_save_path,
51
- 'mask_save_path': mask_save_path,
52
- 'context_captions': context_captions}
53
- return out
54
-
55
- def parse_augment():
56
- parser = argparse.ArgumentParser()
57
- parser.add_argument('--captioner', type=str, default="blip")
58
- parser.add_argument('--segmenter', type=str, default="base")
59
- parser.add_argument('--text_refiner', type=str, default="base")
60
- parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
61
- parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
62
- parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
63
- parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption")
64
- parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
65
- parser.add_argument('--device', type=str, default="cuda:0")
66
- parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
67
- parser.add_argument('--debug', action="store_true")
68
- parser.add_argument('--gradio_share', action="store_true")
69
- parser.add_argument('--disable_gpt', action="store_true")
70
- parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
71
- parser.add_argument('--disable_reuse_features', action="store_true", default=False)
72
- args = parser.parse_args()
73
-
74
- if args.debug:
75
- print(args)
76
- return args
77
-
78
- if __name__ == "__main__":
79
- args = parse_augment()
80
- # image_path = 'test_img/img3.jpg'
81
- image_path = 'test_img/img13.jpg'
82
- prompts = [
83
- {
84
- "prompt_type":["click"],
85
- "input_point":[[500, 300], [1000, 500]],
86
- "input_label":[1, 0],
87
- "multimask_output":"True",
88
- },
89
- {
90
- "prompt_type":["click"],
91
- "input_point":[[900, 800]],
92
- "input_label":[1],
93
- "multimask_output":"True",
94
- }
95
- ]
96
- controls = {
97
- "length": "30",
98
- "sentiment": "positive",
99
- # "imagination": "True",
100
- "imagination": "False",
101
- "language": "English",
102
- }
103
-
104
- model = CaptionAnything(args)
105
- for prompt in prompts:
106
- print('*'*30)
107
- print('Image path: ', image_path)
108
- image = Image.open(image_path)
109
- print(image)
110
- print('Visual controls (SAM prompt):\n', prompt)
111
- print('Language controls:\n', controls)
112
- out = model.inference(image_path, prompt, controls)
113
-
114
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caption_anything/__init__.py ADDED
File without changes
{captioner β†’ caption_anything/captioner}/README.md RENAMED
File without changes
{captioner β†’ caption_anything/captioner}/__init__.py RENAMED
File without changes
{captioner β†’ caption_anything/captioner}/base_captioner.py RENAMED
@@ -191,7 +191,7 @@ class BaseCaptioner:
191
 
192
  if __name__ == '__main__':
193
  model = BaseCaptioner(device='cuda:0')
194
- image_path = 'test_img/img2.jpg'
195
  seg_mask = np.zeros((15,15))
196
  seg_mask[5:10, 5:10] = 1
197
  seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
 
191
 
192
  if __name__ == '__main__':
193
  model = BaseCaptioner(device='cuda:0')
194
+ image_path = 'test_images/img2.jpg'
195
  seg_mask = np.zeros((15,15))
196
  seg_mask[5:10, 5:10] = 1
197
  seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
{captioner β†’ caption_anything/captioner}/blip.py RENAMED
@@ -54,13 +54,13 @@ class BLIPCaptioner(BaseCaptioner):
54
 
55
  if __name__ == '__main__':
56
  model = BLIPCaptioner(device='cuda:0')
57
- # image_path = 'test_img/img2.jpg'
58
- image_path = '/group/30042/wybertwang/project/woa_visgpt/chatARC/image/SAM/img10.jpg'
59
  seg_mask = np.zeros((15,15))
60
  seg_mask[5:10, 5:10] = 1
61
- seg_mask = 'test_img/img10.jpg.raw_mask.png'
62
- image_path = 'test_img/img2.jpg'
63
- seg_mask = 'test_img/img2.jpg.raw_mask.png'
64
  print(f'process image {image_path}')
65
  print(model.inference_with_reduced_tokens(image_path, seg_mask))
66
 
 
54
 
55
  if __name__ == '__main__':
56
  model = BLIPCaptioner(device='cuda:0')
57
+ # image_path = 'test_images/img2.jpg'
58
+ image_path = 'image/SAM/img10.jpg'
59
  seg_mask = np.zeros((15,15))
60
  seg_mask[5:10, 5:10] = 1
61
+ seg_mask = 'test_images/img10.jpg.raw_mask.png'
62
+ image_path = 'test_images/img2.jpg'
63
+ seg_mask = 'test_images/img2.jpg.raw_mask.png'
64
  print(f'process image {image_path}')
65
  print(model.inference_with_reduced_tokens(image_path, seg_mask))
66
 
{captioner β†’ caption_anything/captioner}/blip2.py RENAMED
@@ -1,13 +1,10 @@
1
  import torch
2
- from PIL import Image, ImageDraw, ImageOps
3
- from transformers import AutoProcessor, Blip2ForConditionalGeneration
4
- import json
5
- import pdb
6
- import cv2
7
  import numpy as np
8
  from typing import Union
 
9
 
10
- from tools import is_platform_win
11
  from .base_captioner import BaseCaptioner
12
 
13
  class BLIP2Captioner(BaseCaptioner):
@@ -55,7 +52,7 @@ if __name__ == '__main__':
55
 
56
  dialogue = False
57
  model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
58
- image_path = 'test_img/img2.jpg'
59
  seg_mask = np.zeros((224,224))
60
  seg_mask[50:200, 50:200] = 1
61
  print(f'process image {image_path}')
 
1
  import torch
2
+ from PIL import Image
 
 
 
 
3
  import numpy as np
4
  from typing import Union
5
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
6
 
7
+ from caption_anything.utils.utils import is_platform_win
8
  from .base_captioner import BaseCaptioner
9
 
10
  class BLIP2Captioner(BaseCaptioner):
 
52
 
53
  dialogue = False
54
  model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
55
+ image_path = 'test_images/img2.jpg'
56
  seg_mask = np.zeros((224,224))
57
  seg_mask[50:200, 50:200] = 1
58
  print(f'process image {image_path}')
{captioner β†’ caption_anything/captioner}/git.py RENAMED
@@ -50,7 +50,7 @@ class GITCaptioner(BaseCaptioner):
50
 
51
  if __name__ == '__main__':
52
  model = GITCaptioner(device='cuda:2', enable_filter=False)
53
- image_path = 'test_img/img2.jpg'
54
  seg_mask = np.zeros((224,224))
55
  seg_mask[50:200, 50:200] = 1
56
  print(f'process image {image_path}')
 
50
 
51
  if __name__ == '__main__':
52
  model = GITCaptioner(device='cuda:2', enable_filter=False)
53
+ image_path = 'test_images/img2.jpg'
54
  seg_mask = np.zeros((224,224))
55
  seg_mask[50:200, 50:200] = 1
56
  print(f'process image {image_path}')
{captioner β†’ caption_anything/captioner}/modeling_blip.py RENAMED
File without changes
{captioner β†’ caption_anything/captioner}/modeling_git.py RENAMED
File without changes
{captioner β†’ caption_anything/captioner}/vit_pixel_masks_utils.py RENAMED
File without changes
caption_anything.py β†’ caption_anything/model.py RENAMED
@@ -1,6 +1,3 @@
1
- from captioner import build_captioner, BaseCaptioner
2
- from segmenter import build_segmenter
3
- from text_refiner import build_text_refiner
4
  import os
5
  import argparse
6
  import pdb
@@ -8,13 +5,17 @@ import time
8
  from PIL import Image
9
  import cv2
10
  import numpy as np
 
 
 
11
 
12
- class CaptionAnything():
 
13
  def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
14
  self.args = args
15
  self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
16
  self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
17
-
18
  self.text_refiner = None
19
  if not args.disable_gpt:
20
  if text_refiner is not None:
@@ -22,24 +23,54 @@ class CaptionAnything():
22
  else:
23
  self.init_refiner(api_key)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def init_refiner(self, api_key):
26
  try:
27
  self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
28
- self.text_refiner.llm('hi') # test
29
  except:
30
  self.text_refiner = None
31
  print('OpenAI GPT is not available')
32
-
33
  def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
34
  # segment with prompt
35
- print("CA prompt: ", prompt, "CA controls",controls)
36
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
37
  if self.args.enable_morphologyex:
38
  seg_mask = 255 * seg_mask.astype(np.uint8)
39
- seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
40
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
41
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
42
- seg_mask = seg_mask[:,:,0] > 0
43
  mask_save_path = f'result/mask_{time.time()}.png'
44
  if not os.path.exists(os.path.dirname(mask_save_path)):
45
  os.makedirs(os.path.dirname(mask_save_path))
@@ -51,82 +82,66 @@ class CaptionAnything():
51
  print("seg_mask.shape: ", seg_mask.shape)
52
  # captioning with mask
53
  if self.args.enable_reduce_tokens:
54
- caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
 
 
 
 
55
  else:
56
- caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
 
 
 
57
  # refining with TextRefiner
58
  context_captions = []
59
  if self.args.context_captions:
60
  context_captions.append(self.captioner.inference(image))
61
  if not disable_gpt and self.text_refiner is not None:
62
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
 
63
  else:
64
- refined_caption = {'raw_caption': caption}
65
  out = {'generated_captions': refined_caption,
66
- 'crop_save_path': crop_save_path,
67
- 'mask_save_path': mask_save_path,
68
- 'mask': seg_mask_img,
69
- 'context_captions': context_captions}
70
  return out
71
-
72
- def parse_augment():
73
- parser = argparse.ArgumentParser()
74
- parser.add_argument('--captioner', type=str, default="blip2")
75
- parser.add_argument('--segmenter', type=str, default="huge")
76
- parser.add_argument('--text_refiner', type=str, default="base")
77
- parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
78
- parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
79
- parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
80
- parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
81
- parser.add_argument('--disable_regular_box', action="store_true", default = False, help="crop image with a regular box")
82
- parser.add_argument('--device', type=str, default="cuda:0")
83
- parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
84
- parser.add_argument('--debug', action="store_true")
85
- parser.add_argument('--gradio_share', action="store_true")
86
- parser.add_argument('--disable_gpt', action="store_true")
87
- parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
88
- parser.add_argument('--disable_reuse_features', action="store_true", default=False)
89
- parser.add_argument('--enable_morphologyex', action="store_true", default=False)
90
- args = parser.parse_args()
91
 
92
- if args.debug:
93
- print(args)
94
- return args
95
 
96
  if __name__ == "__main__":
 
97
  args = parse_augment()
98
- # image_path = 'test_img/img3.jpg'
99
- image_path = 'test_img/img13.jpg'
100
  prompts = [
101
  {
102
- "prompt_type":["click"],
103
- "input_point":[[500, 300], [1000, 500]],
104
- "input_label":[1, 0],
105
- "multimask_output":"True",
106
  },
107
  {
108
- "prompt_type":["click"],
109
- "input_point":[[900, 800]],
110
- "input_label":[1],
111
- "multimask_output":"True",
112
  }
113
  ]
114
  controls = {
115
- "length": "30",
116
- "sentiment": "positive",
117
- # "imagination": "True",
118
- "imagination": "False",
119
- "language": "English",
120
- }
121
-
122
  model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
123
  for prompt in prompts:
124
- print('*'*30)
125
  print('Image path: ', image_path)
126
  image = Image.open(image_path)
127
  print(image)
128
  print('Visual controls (SAM prompt):\n', prompt)
129
  print('Language controls:\n', controls)
130
  out = model.inference(image_path, prompt, controls)
131
-
132
-
 
 
 
 
1
  import os
2
  import argparse
3
  import pdb
 
5
  from PIL import Image
6
  import cv2
7
  import numpy as np
8
+ from caption_anything.captioner import build_captioner, BaseCaptioner
9
+ from caption_anything.segmenter import build_segmenter
10
+ from caption_anything.text_refiner import build_text_refiner
11
 
12
+
13
+ class CaptionAnything:
14
  def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
15
  self.args = args
16
  self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
17
  self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
18
+
19
  self.text_refiner = None
20
  if not args.disable_gpt:
21
  if text_refiner is not None:
 
23
  else:
24
  self.init_refiner(api_key)
25
 
26
+ @property
27
+ def image_embedding(self):
28
+ return self.segmenter.image_embedding
29
+
30
+ @image_embedding.setter
31
+ def image_embedding(self, image_embedding):
32
+ self.segmenter.image_embedding = image_embedding
33
+
34
+ @property
35
+ def original_size(self):
36
+ return self.segmenter.predictor.original_size
37
+
38
+ @original_size.setter
39
+ def original_size(self, original_size):
40
+ self.segmenter.predictor.original_size = original_size
41
+
42
+ @property
43
+ def input_size(self):
44
+ return self.segmenter.predictor.input_size
45
+
46
+ @input_size.setter
47
+ def input_size(self, input_size):
48
+ self.segmenter.predictor.input_size = input_size
49
+
50
+ def setup(self, image_embedding, original_size, input_size, is_image_set):
51
+ self.image_embedding = image_embedding
52
+ self.original_size = original_size
53
+ self.input_size = input_size
54
+ self.segmenter.predictor.is_image_set = is_image_set
55
+
56
  def init_refiner(self, api_key):
57
  try:
58
  self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
59
+ self.text_refiner.llm('hi') # test
60
  except:
61
  self.text_refiner = None
62
  print('OpenAI GPT is not available')
63
+
64
  def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
65
  # segment with prompt
66
+ print("CA prompt: ", prompt, "CA controls", controls)
67
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
68
  if self.args.enable_morphologyex:
69
  seg_mask = 255 * seg_mask.astype(np.uint8)
70
+ seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
71
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
72
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
73
+ seg_mask = seg_mask[:, :, 0] > 0
74
  mask_save_path = f'result/mask_{time.time()}.png'
75
  if not os.path.exists(os.path.dirname(mask_save_path)):
76
  os.makedirs(os.path.dirname(mask_save_path))
 
82
  print("seg_mask.shape: ", seg_mask.shape)
83
  # captioning with mask
84
  if self.args.enable_reduce_tokens:
85
+ caption, crop_save_path = self.captioner. \
86
+ inference_with_reduced_tokens(image, seg_mask,
87
+ crop_mode=self.args.seg_crop_mode,
88
+ filter=self.args.clip_filter,
89
+ disable_regular_box=self.args.disable_regular_box)
90
  else:
91
+ caption, crop_save_path = self.captioner. \
92
+ inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
93
+ filter=self.args.clip_filter,
94
+ disable_regular_box=self.args.disable_regular_box)
95
  # refining with TextRefiner
96
  context_captions = []
97
  if self.args.context_captions:
98
  context_captions.append(self.captioner.inference(image))
99
  if not disable_gpt and self.text_refiner is not None:
100
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
101
+ enable_wiki=enable_wiki)
102
  else:
103
+ refined_caption = {'raw_caption': caption}
104
  out = {'generated_captions': refined_caption,
105
+ 'crop_save_path': crop_save_path,
106
+ 'mask_save_path': mask_save_path,
107
+ 'mask': seg_mask_img,
108
+ 'context_captions': context_captions}
109
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
 
 
 
111
 
112
  if __name__ == "__main__":
113
+ from caption_anything.utils.parser import parse_augment
114
  args = parse_augment()
115
+ # image_path = 'test_images/img3.jpg'
116
+ image_path = 'test_images/img1.jpg'
117
  prompts = [
118
  {
119
+ "prompt_type": ["click"],
120
+ "input_point": [[500, 300], [200, 500]],
121
+ "input_label": [1, 0],
122
+ "multimask_output": "True",
123
  },
124
  {
125
+ "prompt_type": ["click"],
126
+ "input_point": [[300, 800]],
127
+ "input_label": [1],
128
+ "multimask_output": "True",
129
  }
130
  ]
131
  controls = {
132
+ "length": "30",
133
+ "sentiment": "positive",
134
+ # "imagination": "True",
135
+ "imagination": "False",
136
+ "language": "English",
137
+ }
138
+
139
  model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
140
  for prompt in prompts:
141
+ print('*' * 30)
142
  print('Image path: ', image_path)
143
  image = Image.open(image_path)
144
  print(image)
145
  print('Visual controls (SAM prompt):\n', prompt)
146
  print('Language controls:\n', controls)
147
  out = model.inference(image_path, prompt, controls)
 
 
caption_anything/segmenter/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .base_segmenter import BaseSegmenter
2
+ from caption_anything.utils.utils import seg_model_map
3
+
4
+ def build_segmenter(model_name, device, args=None, model=None):
5
+ return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model)
{segmenter β†’ caption_anything/segmenter}/base_segmenter.py RENAMED
@@ -5,19 +5,22 @@ from PIL import Image, ImageDraw, ImageOps
5
  import numpy as np
6
  from typing import Union
7
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
 
8
  import matplotlib.pyplot as plt
9
  import PIL
10
 
 
11
  class BaseSegmenter:
12
- def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True, model=None):
13
  print(f"Initializing BaseSegmenter to {device}")
14
  self.device = device
15
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
16
  self.processor = None
17
- self.model_type = model_type
18
  if model is None:
 
 
 
19
  self.checkpoint = checkpoint
20
- self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
21
  self.model.to(device=self.device)
22
  else:
23
  self.model = model
@@ -27,26 +30,57 @@ class BaseSegmenter:
27
  self.image_embedding = None
28
  self.image = None
29
 
30
-
31
- @torch.no_grad()
32
- def set_image(self, image: Union[np.ndarray, Image.Image, str]):
33
- if type(image) == str: # input path
34
  image = Image.open(image)
35
  image = np.array(image)
36
  elif type(image) == Image.Image:
37
  image = np.array(image)
 
 
 
 
 
 
 
 
 
38
  self.image = image
39
  if self.reuse_feature:
40
  self.predictor.set_image(image)
41
  self.image_embedding = self.predictor.get_image_embedding()
42
  print(self.image_embedding.shape)
43
 
44
-
45
  @torch.no_grad()
46
- def inference(self, image, control):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  if 'everything' in control['prompt_type']:
48
  masks = self.mask_generator.generate(image)
49
- new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
50
  return new_masks
51
  else:
52
  if not self.reuse_feature or self.image_embedding is None:
@@ -55,17 +89,17 @@ class BaseSegmenter:
55
  else:
56
  assert self.image_embedding is not None
57
  self.predictor.features = self.image_embedding
58
-
59
  if 'mutimask_output' in control:
60
  masks, scores, logits = self.predictor.predict(
61
- point_coords = np.array(control['input_point']),
62
- point_labels = np.array(control['input_label']),
63
- multimask_output = True,
64
  )
65
  elif 'input_boxes' in control:
66
  transformed_boxes = self.predictor.transform.apply_boxes_torch(
67
  torch.tensor(control["input_boxes"], device=self.predictor.device),
68
- image.shape[:2]
69
  )
70
  masks, _, _ = self.predictor.predict_torch(
71
  point_coords=None,
@@ -74,31 +108,32 @@ class BaseSegmenter:
74
  multimask_output=False,
75
  )
76
  masks = masks.squeeze(1).cpu().numpy()
77
-
78
  else:
79
  input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
80
  input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
81
  input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
82
-
83
  masks, scores, logits = self.predictor.predict(
84
- point_coords = input_point,
85
- point_labels = input_label,
86
- box = input_box,
87
- multimask_output = False,
88
  )
89
-
90
  if 0 in control['input_label']:
91
  mask_input = logits[np.argmax(scores), :, :]
92
  masks, scores, logits = self.predictor.predict(
93
  point_coords=input_point,
94
  point_labels=input_label,
95
- box = input_box,
96
  mask_input=mask_input[None, :, :],
97
  multimask_output=False,
98
  )
99
-
100
  return masks
101
 
 
102
  if __name__ == "__main__":
103
  image_path = 'segmenter/images/truck.jpg'
104
  prompts = [
@@ -109,9 +144,9 @@ if __name__ == "__main__":
109
  # "multimask_output":"True",
110
  # },
111
  {
112
- "prompt_type":["click"],
113
- "input_point":[[1000, 600], [1325, 625]],
114
- "input_label":[1, 0],
115
  },
116
  # {
117
  # "prompt_type":["click", "box"],
@@ -132,7 +167,7 @@ if __name__ == "__main__":
132
  # "prompt_type":["everything"]
133
  # },
134
  ]
135
-
136
  init_time = time.time()
137
  segmenter = BaseSegmenter(
138
  device='cuda',
@@ -142,8 +177,8 @@ if __name__ == "__main__":
142
  reuse_feature=True
143
  )
144
  print(f'init time: {time.time() - init_time}')
145
-
146
- image_path = 'test_img/img2.jpg'
147
  infer_time = time.time()
148
  for i, prompt in enumerate(prompts):
149
  print(f'{prompt["prompt_type"]} mode')
@@ -152,5 +187,5 @@ if __name__ == "__main__":
152
  masks = segmenter.inference(np.array(image), prompt)
153
  Image.fromarray(masks[0]).save('seg.png')
154
  print(masks.shape)
155
-
156
  print(f'infer time: {time.time() - infer_time}')
 
5
  import numpy as np
6
  from typing import Union
7
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ from caption_anything.utils.utils import prepare_segmenter, seg_model_map
9
  import matplotlib.pyplot as plt
10
  import PIL
11
 
12
+
13
  class BaseSegmenter:
14
+ def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None):
15
  print(f"Initializing BaseSegmenter to {device}")
16
  self.device = device
17
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
18
  self.processor = None
 
19
  if model is None:
20
+ if checkpoint is None:
21
+ _, checkpoint = prepare_segmenter(model_name)
22
+ self.model = sam_model_registry[seg_model_map[model_name]](checkpoint=checkpoint)
23
  self.checkpoint = checkpoint
 
24
  self.model.to(device=self.device)
25
  else:
26
  self.model = model
 
30
  self.image_embedding = None
31
  self.image = None
32
 
33
+ def read_image(self, image: Union[np.ndarray, Image.Image, str]):
34
+ if type(image) == str: # input path
 
 
35
  image = Image.open(image)
36
  image = np.array(image)
37
  elif type(image) == Image.Image:
38
  image = np.array(image)
39
+ elif type(image) == np.ndarray:
40
+ image = image
41
+ else:
42
+ raise TypeError
43
+ return image
44
+
45
+ @torch.no_grad()
46
+ def set_image(self, image: Union[np.ndarray, Image.Image, str]):
47
+ image = self.read_image(image)
48
  self.image = image
49
  if self.reuse_feature:
50
  self.predictor.set_image(image)
51
  self.image_embedding = self.predictor.get_image_embedding()
52
  print(self.image_embedding.shape)
53
 
 
54
  @torch.no_grad()
55
+ def inference(self, image: Union[np.ndarray, Image.Image, str], control: dict):
56
+ """
57
+ SAM inference of image according to control.
58
+ Args:
59
+ image: str or PIL.Image or np.ndarray
60
+ control:
61
+ prompt_type:
62
+ 1. {control['prompt_type'] = ['everything']} to segment everything in the image.
63
+ 2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
64
+ 3. {control['prompt_type'] = ['click'] to segment according to click.
65
+ 4. {control['prompt_type'] = ['box'] to segment according to box.
66
+ input_point: list of [x, y] coordinates of click.
67
+ input_label: List of labels for points accordingly, 0 for negative, 1 for positive.
68
+ input_box: List of [x1, y1, x2, y2] coordinates of box.
69
+ multimask_output:
70
+ If true, the model will return three masks.
71
+ For ambiguous input prompts (such as a single click), this will often
72
+ produce better masks than a single prediction. If only a single
73
+ mask is needed, the model's predicted quality score can be used
74
+ to select the best mask. For non-ambiguous prompts, such as multiple
75
+ input prompts, multimask_output=False can give better results.
76
+ Returns:
77
+ masks: np.ndarray of shape [num_masks, height, width]
78
+
79
+ """
80
+ image = self.read_image(image) # Turn image into np.ndarray
81
  if 'everything' in control['prompt_type']:
82
  masks = self.mask_generator.generate(image)
83
+ new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
84
  return new_masks
85
  else:
86
  if not self.reuse_feature or self.image_embedding is None:
 
89
  else:
90
  assert self.image_embedding is not None
91
  self.predictor.features = self.image_embedding
92
+
93
  if 'mutimask_output' in control:
94
  masks, scores, logits = self.predictor.predict(
95
+ point_coords=np.array(control['input_point']),
96
+ point_labels=np.array(control['input_label']),
97
+ multimask_output=True,
98
  )
99
  elif 'input_boxes' in control:
100
  transformed_boxes = self.predictor.transform.apply_boxes_torch(
101
  torch.tensor(control["input_boxes"], device=self.predictor.device),
102
+ image.shape[1::-1] # Reverse shape because numpy is (W, H) and function need (H, W)
103
  )
104
  masks, _, _ = self.predictor.predict_torch(
105
  point_coords=None,
 
108
  multimask_output=False,
109
  )
110
  masks = masks.squeeze(1).cpu().numpy()
111
+
112
  else:
113
  input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
114
  input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
115
  input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
116
+
117
  masks, scores, logits = self.predictor.predict(
118
+ point_coords=input_point,
119
+ point_labels=input_label,
120
+ box=input_box,
121
+ multimask_output=False,
122
  )
123
+
124
  if 0 in control['input_label']:
125
  mask_input = logits[np.argmax(scores), :, :]
126
  masks, scores, logits = self.predictor.predict(
127
  point_coords=input_point,
128
  point_labels=input_label,
129
+ box=input_box,
130
  mask_input=mask_input[None, :, :],
131
  multimask_output=False,
132
  )
133
+
134
  return masks
135
 
136
+
137
  if __name__ == "__main__":
138
  image_path = 'segmenter/images/truck.jpg'
139
  prompts = [
 
144
  # "multimask_output":"True",
145
  # },
146
  {
147
+ "prompt_type": ["click"],
148
+ "input_point": [[1000, 600], [1325, 625]],
149
+ "input_label": [1, 0],
150
  },
151
  # {
152
  # "prompt_type":["click", "box"],
 
167
  # "prompt_type":["everything"]
168
  # },
169
  ]
170
+
171
  init_time = time.time()
172
  segmenter = BaseSegmenter(
173
  device='cuda',
 
177
  reuse_feature=True
178
  )
179
  print(f'init time: {time.time() - init_time}')
180
+
181
+ image_path = 'test_images/img2.jpg'
182
  infer_time = time.time()
183
  for i, prompt in enumerate(prompts):
184
  print(f'{prompt["prompt_type"]} mode')
 
187
  masks = segmenter.inference(np.array(image), prompt)
188
  Image.fromarray(masks[0]).save('seg.png')
189
  print(masks.shape)
190
+
191
  print(f'infer time: {time.time() - infer_time}')
{segmenter β†’ caption_anything/segmenter}/readme.md RENAMED
File without changes
{text_refiner β†’ caption_anything/text_refiner}/README.md RENAMED
File without changes
{text_refiner β†’ caption_anything/text_refiner}/__init__.py RENAMED
@@ -1,4 +1,4 @@
1
- from text_refiner.text_refiner import TextRefiner
2
 
3
 
4
  def build_text_refiner(type, device, args=None, api_key=""):
 
1
+ from .text_refiner import TextRefiner
2
 
3
 
4
  def build_text_refiner(type, device, args=None, api_key=""):
{text_refiner β†’ caption_anything/text_refiner}/text_refiner.py RENAMED
File without changes
caption_anything/utils/chatbot.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft
2
+ # Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
3
+
4
+ import os
5
+ import gradio as gr
6
+ import re
7
+ import uuid
8
+ from PIL import Image, ImageDraw, ImageOps
9
+ import numpy as np
10
+ import argparse
11
+ import inspect
12
+
13
+ from langchain.agents.initialize import initialize_agent
14
+ from langchain.agents.tools import Tool
15
+ from langchain.chains.conversation.memory import ConversationBufferMemory
16
+ from langchain.llms.openai import OpenAI
17
+ import torch
18
+ from PIL import Image, ImageDraw, ImageOps
19
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
20
+
21
+ VISUAL_CHATGPT_PREFIX = """
22
+ Caption Anything Chatbox (short as CATchat) is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. CATchat is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
23
+
24
+ As a language model, CATchat can not directly read images, but it has a list of tools to finish different visual tasks. CATchat can invoke different tools to indirectly understand pictures.
25
+
26
+ Visual ChatGPT has access to the following tools:"""
27
+
28
+
29
+ # VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
30
+
31
+ # Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "chat_image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name.
32
+
33
+ # Visual ChatGPT is aware of the coordinate of an object in the image, which is represented as a point (X, Y) on the object. Note that (0, 0) represents the bottom-left corner of the image.
34
+
35
+ # Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
36
+
37
+ # Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
38
+
39
+
40
+ # TOOLS:
41
+ # ------
42
+
43
+ # Visual ChatGPT has access to the following tools:"""
44
+
45
+ VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
46
+
47
+ "Thought: Do I need to use a tool? Yes
48
+ Action: the action to take, should be one of [{tool_names}], remember the action must to be one tool
49
+ Action Input: the input to the action
50
+ Observation: the result of the action"
51
+
52
+ When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
53
+
54
+ "Thought: Do I need to use a tool? No
55
+ {ai_prefix}: [your response here]"
56
+
57
+ """
58
+
59
+ VISUAL_CHATGPT_SUFFIX = """
60
+ Begin Chatting!
61
+
62
+ Previous conversation history:
63
+ {chat_history}
64
+
65
+ New input: {input}
66
+ Since CATchat is a text language model, CATchat must use tools iteratively to observe images rather than imagination.
67
+ The thoughts and observations are only visible for CATchat, CATchat should remember to repeat important information in the final response for Human.
68
+
69
+ Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
70
+
71
+ os.makedirs('chat_image', exist_ok=True)
72
+
73
+
74
+ def prompts(name, description):
75
+ def decorator(func):
76
+ func.name = name
77
+ func.description = description
78
+ return func
79
+ return decorator
80
+
81
+ def cut_dialogue_history(history_memory, keep_last_n_words=500):
82
+ if history_memory is None or len(history_memory) == 0:
83
+ return history_memory
84
+ tokens = history_memory.split()
85
+ n_tokens = len(tokens)
86
+ print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
87
+ if n_tokens < keep_last_n_words:
88
+ return history_memory
89
+ paragraphs = history_memory.split('\n')
90
+ last_n_tokens = n_tokens
91
+ while last_n_tokens >= keep_last_n_words:
92
+ last_n_tokens -= len(paragraphs[0].split(' '))
93
+ paragraphs = paragraphs[1:]
94
+ return '\n' + '\n'.join(paragraphs)
95
+
96
+ def get_new_image_name(folder='chat_image', func_name="update"):
97
+ this_new_uuid = str(uuid.uuid4())[:8]
98
+ new_file_name = f'{func_name}_{this_new_uuid}.png'
99
+ return os.path.join(folder, new_file_name)
100
+
101
+ class VisualQuestionAnswering:
102
+ def __init__(self, device):
103
+ print(f"Initializing VisualQuestionAnswering to {device}")
104
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
105
+ self.device = device
106
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
107
+ self.model = BlipForQuestionAnswering.from_pretrained(
108
+ "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
109
+ # self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
110
+ # self.model = BlipForQuestionAnswering.from_pretrained(
111
+ # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
112
+
113
+ @prompts(name="Answer Question About The Image",
114
+ description="useful when you need an answer for a question based on an image. "
115
+ "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
116
+ "The input to this tool should be a comma separated string of two, representing the image_path and the question")
117
+ def inference(self, inputs):
118
+ image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
119
+ raw_image = Image.open(image_path).convert('RGB')
120
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
121
+ out = self.model.generate(**inputs)
122
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
123
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
124
+ f"Output Answer: {answer}")
125
+ return answer
126
+
127
+ def build_chatbot_tools(load_dict):
128
+ print(f"Initializing ChatBot, load_dict={load_dict}")
129
+ models = {}
130
+ # Load Basic Foundation Models
131
+ for class_name, device in load_dict.items():
132
+ models[class_name] = globals()[class_name](device=device)
133
+
134
+ # Load Template Foundation Models
135
+ for class_name, module in globals().items():
136
+ if getattr(module, 'template_model', False):
137
+ template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
138
+ loaded_names = set([type(e).__name__ for e in models.values()])
139
+ if template_required_names.issubset(loaded_names):
140
+ models[class_name] = globals()[class_name](
141
+ **{name: models[name] for name in template_required_names})
142
+
143
+ tools = []
144
+ for instance in models.values():
145
+ for e in dir(instance):
146
+ if e.startswith('inference'):
147
+ func = getattr(instance, e)
148
+ tools.append(Tool(name=func.name, description=func.description, func=func))
149
+ return tools
150
+
151
+ class ConversationBot:
152
+ def __init__(self, tools, api_key=""):
153
+ # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
154
+ llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
155
+ self.llm = llm
156
+ self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
157
+ self.tools = tools
158
+ self.current_image = None
159
+ self.point_prompt = ""
160
+ self.agent = initialize_agent(
161
+ self.tools,
162
+ self.llm,
163
+ agent="conversational-react-description",
164
+ verbose=True,
165
+ memory=self.memory,
166
+ return_intermediate_steps=True,
167
+ agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
168
+ 'suffix': VISUAL_CHATGPT_SUFFIX}, )
169
+
170
+ def constructe_intermediate_steps(self, agent_res):
171
+ ans = []
172
+ for action, output in agent_res:
173
+ if hasattr(action, "tool_input"):
174
+ use_tool = "Yes"
175
+ act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
176
+ else:
177
+ use_tool = "No"
178
+ act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
179
+ act= list(map(lambda x: x.replace('\n', '<br>'), act))
180
+ ans.append(act)
181
+ return ans
182
+
183
+ def run_text(self, text, state, aux_state):
184
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
185
+ if self.point_prompt != "":
186
+ Human_prompt = f'\nHuman: {self.point_prompt}\n'
187
+ AI_prompt = 'Ok'
188
+ self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
189
+ self.point_prompt = ""
190
+ res = self.agent({"input": text})
191
+ res['output'] = res['output'].replace("\\", "/")
192
+ response = re.sub('(chat_image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
193
+ state = state + [(text, response)]
194
+
195
+ aux_state = aux_state + [(f"User Input: {text}", None)]
196
+ aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
197
+ print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
198
+ f"Current Memory: {self.agent.memory.buffer}\n"
199
+ f"Aux state: {aux_state}\n"
200
+ )
201
+ return state, state, aux_state, aux_state
202
+
203
+
204
+ if __name__ == '__main__':
205
+ parser = argparse.ArgumentParser()
206
+ parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
207
+ parser.add_argument('--port', type=int, default=1015)
208
+
209
+ args = parser.parse_args()
210
+ load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
211
+ tools = build_chatbot_tools(load_dict)
212
+ bot = ConversationBot(tools)
213
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
214
+ with gr.Row():
215
+ chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT").style(height=1000,scale=0.5)
216
+ auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
217
+ state = gr.State([])
218
+ aux_state = gr.State([])
219
+ with gr.Row():
220
+ with gr.Column(scale=0.7):
221
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
222
+ container=False)
223
+ with gr.Column(scale=0.15, min_width=0):
224
+ clear = gr.Button("Clear")
225
+ with gr.Column(scale=0.15, min_width=0):
226
+ btn = gr.UploadButton("Upload", file_types=["image"])
227
+
228
+ txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
229
+ txt.submit(lambda: "", None, txt)
230
+ btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
231
+ clear.click(bot.memory.clear)
232
+ clear.click(lambda: [], None, chatbot)
233
+ clear.click(lambda: [], None, auxwindow)
234
+ clear.click(lambda: [], None, state)
235
+ clear.click(lambda: [], None, aux_state)
236
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)
image_editing_utils.py β†’ caption_anything/utils/image_editing_utils.py RENAMED
@@ -1,7 +1,8 @@
1
  from PIL import Image, ImageDraw, ImageFont
2
  import copy
3
  import numpy as np
4
- import cv2
 
5
 
6
  def wrap_text(text, font, max_width):
7
  lines = []
@@ -18,11 +19,18 @@ def wrap_text(text, font, max_width):
18
  lines.append(current_line)
19
  return lines
20
 
21
- def create_bubble_frame(image, text, point, segmask, input_points, input_labels, font_path='times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
 
 
22
  # Load the image
 
 
 
 
 
23
  if type(image) == np.ndarray:
24
  image = Image.fromarray(image)
25
-
26
  image = copy.deepcopy(image)
27
  width, height = image.size
28
 
@@ -47,19 +55,19 @@ def create_bubble_frame(image, text, point, segmask, input_points, input_labels,
47
  bubble_height = text_height + 2 * padding
48
 
49
  # Create a new image for the bubble frame
50
- bubble = Image.new('RGBA', (bubble_width, bubble_height), (255,248, 220, 0))
51
 
52
  # Draw the bubble frame on the new image
53
  draw = ImageDraw.Draw(bubble)
54
  # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
55
- draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
56
- fill=(255,248, 220, 120), outline=None, width=2)
57
  # Draw the wrapped text line by line
58
  y_text = padding
59
  for line in lines:
60
  draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
61
  y_text += font.getsize(line)[1]
62
-
63
  # Determine the point by the min area rect of mask
64
  try:
65
  ret, thresh = cv2.threshold(segmask, 127, 255, 0)
@@ -109,7 +117,11 @@ def draw_rounded_rectangle(draw, xy, corner_radius, fill=None, outline=None, wid
109
  width=width
110
  )
111
 
112
- draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline, width=width)
113
- draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline, width=width)
114
- draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline, width=width)
115
- draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline, width=width)
 
 
 
 
 
1
  from PIL import Image, ImageDraw, ImageFont
2
  import copy
3
  import numpy as np
4
+ import cv2
5
+
6
 
7
  def wrap_text(text, font, max_width):
8
  lines = []
 
19
  lines.append(current_line)
20
  return lines
21
 
22
+
23
+ def create_bubble_frame(image, text, point, segmask, input_points=(), input_labels=(),
24
+ font_path='assets/times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
25
  # Load the image
26
+ if input_points is None:
27
+ input_points = []
28
+ if input_labels is None:
29
+ input_labels = []
30
+
31
  if type(image) == np.ndarray:
32
  image = Image.fromarray(image)
33
+
34
  image = copy.deepcopy(image)
35
  width, height = image.size
36
 
 
55
  bubble_height = text_height + 2 * padding
56
 
57
  # Create a new image for the bubble frame
58
+ bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 248, 220, 0))
59
 
60
  # Draw the bubble frame on the new image
61
  draw = ImageDraw.Draw(bubble)
62
  # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
63
+ draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
64
+ fill=(255, 248, 220, 120), outline=None, width=2)
65
  # Draw the wrapped text line by line
66
  y_text = padding
67
  for line in lines:
68
  draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
69
  y_text += font.getsize(line)[1]
70
+
71
  # Determine the point by the min area rect of mask
72
  try:
73
  ret, thresh = cv2.threshold(segmask, 127, 255, 0)
 
117
  width=width
118
  )
119
 
120
+ draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline,
121
+ width=width)
122
+ draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline,
123
+ width=width)
124
+ draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline,
125
+ width=width)
126
+ draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline,
127
+ width=width)
caption_anything/utils/parser.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ def parse_augment():
4
+ parser = argparse.ArgumentParser()
5
+ parser.add_argument('--captioner', type=str, default="blip2")
6
+ parser.add_argument('--segmenter', type=str, default="huge")
7
+ parser.add_argument('--text_refiner', type=str, default="base")
8
+ parser.add_argument('--segmenter_checkpoint', type=str, default=None, help="SAM checkpoint path")
9
+ parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'],
10
+ help="whether to add or remove background of the image when captioning")
11
+ parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
12
+ parser.add_argument('--context_captions', action="store_true",
13
+ help="use surrounding captions to enhance current caption (TODO)")
14
+ parser.add_argument('--disable_regular_box', action="store_true", default=False,
15
+ help="crop image with a regular box")
16
+ parser.add_argument('--device', type=str, default="cuda:0")
17
+ parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
18
+ parser.add_argument('--debug', action="store_true")
19
+ parser.add_argument('--gradio_share', action="store_true")
20
+ parser.add_argument('--disable_gpt', action="store_true")
21
+ parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
22
+ parser.add_argument('--disable_reuse_features', action="store_true", default=False)
23
+ parser.add_argument('--enable_morphologyex', action="store_true", default=False)
24
+ parser.add_argument('--chat_tools_dict', type=str, default='VisualQuestionAnswering_cuda:0', help='Visual ChatGPT tools, only useful when running gradio applications')
25
+ args = parser.parse_args()
26
+
27
+ if args.debug:
28
+ print(args)
29
+ return args
caption_anything/utils/utils.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import requests
4
+ import numpy as np
5
+ from PIL import Image
6
+ import time
7
+ import sys
8
+ import urllib
9
+ from tqdm import tqdm
10
+ import hashlib
11
+
12
+ def is_platform_win():
13
+ return sys.platform == "win32"
14
+
15
+
16
+ def colormap(rgb=True):
17
+ color_list = np.array(
18
+ [
19
+ 0.000, 0.000, 0.000,
20
+ 1.000, 1.000, 1.000,
21
+ 1.000, 0.498, 0.313,
22
+ 0.392, 0.581, 0.929,
23
+ 0.000, 0.447, 0.741,
24
+ 0.850, 0.325, 0.098,
25
+ 0.929, 0.694, 0.125,
26
+ 0.494, 0.184, 0.556,
27
+ 0.466, 0.674, 0.188,
28
+ 0.301, 0.745, 0.933,
29
+ 0.635, 0.078, 0.184,
30
+ 0.300, 0.300, 0.300,
31
+ 0.600, 0.600, 0.600,
32
+ 1.000, 0.000, 0.000,
33
+ 1.000, 0.500, 0.000,
34
+ 0.749, 0.749, 0.000,
35
+ 0.000, 1.000, 0.000,
36
+ 0.000, 0.000, 1.000,
37
+ 0.667, 0.000, 1.000,
38
+ 0.333, 0.333, 0.000,
39
+ 0.333, 0.667, 0.000,
40
+ 0.333, 1.000, 0.000,
41
+ 0.667, 0.333, 0.000,
42
+ 0.667, 0.667, 0.000,
43
+ 0.667, 1.000, 0.000,
44
+ 1.000, 0.333, 0.000,
45
+ 1.000, 0.667, 0.000,
46
+ 1.000, 1.000, 0.000,
47
+ 0.000, 0.333, 0.500,
48
+ 0.000, 0.667, 0.500,
49
+ 0.000, 1.000, 0.500,
50
+ 0.333, 0.000, 0.500,
51
+ 0.333, 0.333, 0.500,
52
+ 0.333, 0.667, 0.500,
53
+ 0.333, 1.000, 0.500,
54
+ 0.667, 0.000, 0.500,
55
+ 0.667, 0.333, 0.500,
56
+ 0.667, 0.667, 0.500,
57
+ 0.667, 1.000, 0.500,
58
+ 1.000, 0.000, 0.500,
59
+ 1.000, 0.333, 0.500,
60
+ 1.000, 0.667, 0.500,
61
+ 1.000, 1.000, 0.500,
62
+ 0.000, 0.333, 1.000,
63
+ 0.000, 0.667, 1.000,
64
+ 0.000, 1.000, 1.000,
65
+ 0.333, 0.000, 1.000,
66
+ 0.333, 0.333, 1.000,
67
+ 0.333, 0.667, 1.000,
68
+ 0.333, 1.000, 1.000,
69
+ 0.667, 0.000, 1.000,
70
+ 0.667, 0.333, 1.000,
71
+ 0.667, 0.667, 1.000,
72
+ 0.667, 1.000, 1.000,
73
+ 1.000, 0.000, 1.000,
74
+ 1.000, 0.333, 1.000,
75
+ 1.000, 0.667, 1.000,
76
+ 0.167, 0.000, 0.000,
77
+ 0.333, 0.000, 0.000,
78
+ 0.500, 0.000, 0.000,
79
+ 0.667, 0.000, 0.000,
80
+ 0.833, 0.000, 0.000,
81
+ 1.000, 0.000, 0.000,
82
+ 0.000, 0.167, 0.000,
83
+ 0.000, 0.333, 0.000,
84
+ 0.000, 0.500, 0.000,
85
+ 0.000, 0.667, 0.000,
86
+ 0.000, 0.833, 0.000,
87
+ 0.000, 1.000, 0.000,
88
+ 0.000, 0.000, 0.167,
89
+ 0.000, 0.000, 0.333,
90
+ 0.000, 0.000, 0.500,
91
+ 0.000, 0.000, 0.667,
92
+ 0.000, 0.000, 0.833,
93
+ 0.000, 0.000, 1.000,
94
+ 0.143, 0.143, 0.143,
95
+ 0.286, 0.286, 0.286,
96
+ 0.429, 0.429, 0.429,
97
+ 0.571, 0.571, 0.571,
98
+ 0.714, 0.714, 0.714,
99
+ 0.857, 0.857, 0.857
100
+ ]
101
+ ).astype(np.float32)
102
+ color_list = color_list.reshape((-1, 3)) * 255
103
+ if not rgb:
104
+ color_list = color_list[:, ::-1]
105
+ return color_list
106
+
107
+
108
+ color_list = colormap()
109
+ color_list = color_list.astype('uint8').tolist()
110
+
111
+
112
+ def vis_add_mask(image, mask, color, alpha, kernel_size):
113
+ color = np.array(color)
114
+ mask = mask.astype('float').copy()
115
+ mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
116
+ for i in range(3):
117
+ image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
118
+ return image
119
+
120
+
121
+ def vis_add_mask_wo_blur(image, mask, color, alpha):
122
+ color = np.array(color)
123
+ mask = mask.astype('float').copy()
124
+ for i in range(3):
125
+ image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
126
+ return image
127
+
128
+
129
+ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
130
+ background_color = np.array(background_color)
131
+ contour_color = np.array(contour_color)
132
+
133
+ # background_mask = 1 - background_mask
134
+ # contour_mask = 1 - contour_mask
135
+
136
+ for i in range(3):
137
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
138
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
139
+
140
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
141
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
142
+
143
+ return image.astype('uint8')
144
+
145
+
146
+ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
147
+ """
148
+ add color mask to the background/foreground area
149
+ input_image: numpy array (w, h, C)
150
+ input_mask: numpy array (w, h)
151
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
152
+ background_blur_radius: radius of background blur, must be odd number
153
+ contour_width: width of mask contour, must be odd number
154
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
155
+ background_color: color index of the background (area with input_mask == False)
156
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
157
+ paint_foreground: True for paint on foreground, False for background. Default: Flase
158
+
159
+ Output:
160
+ painted_image: numpy array
161
+ """
162
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
163
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
164
+
165
+ # 0: background, 1: foreground
166
+ input_mask[input_mask>0] = 255
167
+ if paint_foreground:
168
+ painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
169
+ else:
170
+ # mask background
171
+ painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
172
+ # mask contour
173
+ contour_mask = input_mask.copy()
174
+ contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
175
+ # widden contour
176
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
177
+ contour_mask = cv2.dilate(contour_mask, kernel)
178
+ painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
179
+ return painted_image
180
+
181
+
182
+ def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
183
+ """
184
+ paint color mask on the all foreground area
185
+ input_image: numpy array with shape (w, h, C)
186
+ input_mask: list of masks, each mask is a numpy array with shape (w,h)
187
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
188
+ background_blur_radius: radius of background blur, must be odd number
189
+ contour_width: width of mask contour, must be odd number
190
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
191
+ background_color: color index of the background (area with input_mask == False)
192
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193
+
194
+ Output:
195
+ painted_image: numpy array
196
+ """
197
+
198
+ for i, input_mask in enumerate(input_masks):
199
+ input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
200
+ return input_image
201
+
202
+ def mask_generator_00(mask, background_radius, contour_radius):
203
+ # no background width when '00'
204
+ # distance map
205
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
206
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
207
+ dist_map = dist_transform_fore - dist_transform_back
208
+ # ...:::!!!:::...
209
+ contour_radius += 2
210
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
211
+ contour_mask = contour_mask / np.max(contour_mask)
212
+ contour_mask[contour_mask>0.5] = 1.
213
+
214
+ return mask, contour_mask
215
+
216
+
217
+ def mask_generator_01(mask, background_radius, contour_radius):
218
+ # no background width when '00'
219
+ # distance map
220
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
221
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
222
+ dist_map = dist_transform_fore - dist_transform_back
223
+ # ...:::!!!:::...
224
+ contour_radius += 2
225
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
226
+ contour_mask = contour_mask / np.max(contour_mask)
227
+ return mask, contour_mask
228
+
229
+
230
+ def mask_generator_10(mask, background_radius, contour_radius):
231
+ # distance map
232
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
233
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
234
+ dist_map = dist_transform_fore - dist_transform_back
235
+ # .....:::::!!!!!
236
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
237
+ background_mask = (background_mask - np.min(background_mask))
238
+ background_mask = background_mask / np.max(background_mask)
239
+ # ...:::!!!:::...
240
+ contour_radius += 2
241
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
242
+ contour_mask = contour_mask / np.max(contour_mask)
243
+ contour_mask[contour_mask>0.5] = 1.
244
+ return background_mask, contour_mask
245
+
246
+
247
+ def mask_generator_11(mask, background_radius, contour_radius):
248
+ # distance map
249
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
250
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
251
+ dist_map = dist_transform_fore - dist_transform_back
252
+ # .....:::::!!!!!
253
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
254
+ background_mask = (background_mask - np.min(background_mask))
255
+ background_mask = background_mask / np.max(background_mask)
256
+ # ...:::!!!:::...
257
+ contour_radius += 2
258
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
259
+ contour_mask = contour_mask / np.max(contour_mask)
260
+ return background_mask, contour_mask
261
+
262
+
263
+ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
264
+ """
265
+ Input:
266
+ input_image: numpy array
267
+ input_mask: numpy array
268
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
269
+ background_blur_radius: radius of background blur, must be odd number
270
+ contour_width: width of mask contour, must be odd number
271
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
272
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
273
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
274
+
275
+ Output:
276
+ painted_image: numpy array
277
+ """
278
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
279
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
280
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
281
+
282
+ # downsample input image and mask
283
+ width, height = input_image.shape[0], input_image.shape[1]
284
+ res = 1024
285
+ ratio = min(1.0 * res / max(width, height), 1.0)
286
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
287
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
288
+
289
+ # 0: background, 1: foreground
290
+ msk = np.clip(input_mask, 0, 1)
291
+
292
+ # generate masks for background and contour pixels
293
+ background_radius = (background_blur_radius - 1) // 2
294
+ contour_radius = (contour_width - 1) // 2
295
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
296
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
297
+
298
+ # paint
299
+ painted_image = vis_add_mask_wo_gaussian \
300
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
301
+
302
+ return painted_image
303
+
304
+
305
+ if __name__ == '__main__':
306
+
307
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
308
+ background_blur_radius = 31 # radius of background blur, must be odd number
309
+ contour_width = 11 # contour width, must be odd number
310
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
311
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
312
+
313
+ # load input image and mask
314
+ input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
315
+ input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))
316
+
317
+ # paint
318
+ overall_time_1 = 0
319
+ overall_time_2 = 0
320
+ overall_time_3 = 0
321
+ overall_time_4 = 0
322
+ overall_time_5 = 0
323
+
324
+ for i in range(50):
325
+ t2 = time.time()
326
+ painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
327
+ e2 = time.time()
328
+
329
+ t3 = time.time()
330
+ painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
331
+ e3 = time.time()
332
+
333
+ t1 = time.time()
334
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
335
+ e1 = time.time()
336
+
337
+ t4 = time.time()
338
+ painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
339
+ e4 = time.time()
340
+
341
+ t5 = time.time()
342
+ painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
343
+ e5 = time.time()
344
+
345
+ overall_time_1 += (e1 - t1)
346
+ overall_time_2 += (e2 - t2)
347
+ overall_time_3 += (e3 - t3)
348
+ overall_time_4 += (e4 - t4)
349
+ overall_time_5 += (e5 - t5)
350
+
351
+ print(f'average time w gaussian: {overall_time_1/50}')
352
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
353
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
354
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
355
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
356
+
357
+ # save
358
+ painted_image_00 = Image.fromarray(painted_image_00)
359
+ painted_image_00.save('./test_images/painter_output_image_00.png')
360
+
361
+ painted_image_10 = Image.fromarray(painted_image_10)
362
+ painted_image_10.save('./test_images/painter_output_image_10.png')
363
+
364
+ painted_image_01 = Image.fromarray(painted_image_01)
365
+ painted_image_01.save('./test_images/painter_output_image_01.png')
366
+
367
+ painted_image_11 = Image.fromarray(painted_image_11)
368
+ painted_image_11.save('./test_images/painter_output_image_11.png')
369
+
370
+
371
+ seg_model_map = {
372
+ 'base': 'vit_b',
373
+ 'large': 'vit_l',
374
+ 'huge': 'vit_h'
375
+ }
376
+ ckpt_url_map = {
377
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
378
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
379
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
380
+ }
381
+ expected_sha256_map = {
382
+ 'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
383
+ 'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
384
+ 'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
385
+ }
386
+ def prepare_segmenter(segmenter = "huge", download_root: str = None):
387
+ """
388
+ Prepare segmenter model and download checkpoint if necessary.
389
+
390
+ Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
391
+
392
+ """
393
+
394
+ os.makedirs('result', exist_ok=True)
395
+ seg_model_name = seg_model_map[segmenter]
396
+ checkpoint_url = ckpt_url_map[seg_model_name]
397
+ folder = download_root or os.path.expanduser("~/.cache/SAM")
398
+ filename = os.path.basename(checkpoint_url)
399
+ segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
400
+
401
+ return seg_model_name, segmenter_checkpoint
402
+
403
+
404
+ def download_checkpoint(url, folder, filename, expected_sha256):
405
+ os.makedirs(folder, exist_ok=True)
406
+ download_target = os.path.join(folder, filename)
407
+ if os.path.isfile(download_target):
408
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
409
+ return download_target
410
+
411
+ print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
412
+ with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
413
+ progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
414
+ for data in response.iter_content(chunk_size=1024):
415
+ size = output.write(data)
416
+ progress.update(size)
417
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
418
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
419
+ return download_target
env.sh DELETED
@@ -1,6 +0,0 @@
1
- conda create -n caption_anything python=3.8 -y
2
- source activate caption_anything
3
- pip install -r requirements.txt
4
5
- # wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
6
-
 
 
 
 
 
 
 
segmenter/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from segmenter.base_segmenter import BaseSegmenter
2
-
3
-
4
- def build_segmenter(type, device, args=None, model=None):
5
- return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
 
 
 
 
 
 
segmenter/images/truck.jpg DELETED
Binary file (271 kB)
 
segmenter/sam_vit_h_4b8939.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
- size 2564550879
 
 
 
 
test_img/img0.png DELETED
Binary file (185 kB)
 
test_img/img1.jpg DELETED
Binary file (501 kB)
 
test_img/img1.jpg.raw_mask.png DELETED
Binary file (114 kB)
 
test_img/img10.jpg DELETED
Binary file (376 kB)
 
test_img/img10.jpg.raw_mask.png DELETED
Binary file (24.3 kB)
 
test_img/img11.jpg DELETED
Binary file (616 kB)
 
test_img/img12.jpg DELETED
Binary file (277 kB)
 
test_img/img12.jpg.raw_mask.png DELETED
Binary file (29.1 kB)
 
test_img/img13.jpg DELETED
Binary file (335 kB)
 
test_img/img13.jpg.raw_mask.png DELETED
Binary file (22.9 kB)
 
test_img/img14.jpg DELETED
Binary file (741 kB)
 
test_img/img14.jpg.raw_mask.png DELETED
Binary file (26.9 kB)
 
test_img/img15.jpg DELETED
Binary file (376 kB)
 
test_img/img15.jpg.raw_mask.png DELETED
Binary file (114 kB)
 
test_img/img16.jpg DELETED
Binary file (337 kB)