diff --git a/.gitignore b/.gitignore
index 2e47b28d7377d2fce93c928410efc4250b591beb..f14f513523270e0912c3518d307bd414af2f3cd6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,14 @@ result/
model_cache/
*.pth
teng_grad_start.sh
+*.jpg
+*.jpeg
+*.png
+*.svg
+*.gif
+*.tiff
+*.webp
+
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/DejaVuSansCondensed-Bold.ttf b/DejaVuSansCondensed-Bold.ttf
deleted file mode 100644
index 437f2f5c0e5b44efcb98f31c1a279458b354fb83..0000000000000000000000000000000000000000
Binary files a/DejaVuSansCondensed-Bold.ttf and /dev/null differ
diff --git a/Image/demo1.svg b/Image/demo1.svg
deleted file mode 100644
index b8435a71ad9f9d5b66d7f7e991ef490271fb6c62..0000000000000000000000000000000000000000
--- a/Image/demo1.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/Image/demo2.svg b/Image/demo2.svg
deleted file mode 100644
index bd7f5f1c2400efbc6eb0122c46c109c1514bdfab..0000000000000000000000000000000000000000
--- a/Image/demo2.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/Image/title.svg b/Image/title.svg
deleted file mode 100644
index 87fcc5fe890c431b9ea6488172b5539b6a959695..0000000000000000000000000000000000000000
--- a/Image/title.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/app.py b/app.py
index 06747f1cb11d2ac87a49d82a468ff83e2d3b36c1..a513db11be1b8f0217019979d2898bd458ac3cef 100644
--- a/app.py
+++ b/app.py
@@ -1,85 +1,63 @@
-from io import BytesIO
-import string
-import gradio as gr
-import requests
-from caption_anything import CaptionAnything
-import torch
+import os
import json
-import sys
-import argparse
-from caption_anything import parse_augment
+import PIL
+import gradio as gr
import numpy as np
-import PIL.ImageDraw as ImageDraw
-from image_editing_utils import create_bubble_frame
-import copy
-from tools import mask_painter
-from PIL import Image
-import os
-from captioner import build_captioner
+from gradio import processing_utils
+
+from packaging import version
+from PIL import Image, ImageDraw
+
+from caption_anything.model import CaptionAnything
+from caption_anything.utils.image_editing_utils import create_bubble_frame
+from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
+from caption_anything.utils.parser import parse_augment
+from caption_anything.captioner import build_captioner
+from caption_anything.text_refiner import build_text_refiner
+from caption_anything.segmenter import build_segmenter
+from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
from segment_anything import sam_model_registry
-from text_refiner import build_text_refiner
-from segmenter import build_segmenter
-
-
-def download_checkpoint(url, folder, filename):
- os.makedirs(folder, exist_ok=True)
- filepath = os.path.join(folder, filename)
-
- if not os.path.exists(filepath):
- response = requests.get(url, stream=True)
- with open(filepath, "wb") as f:
- for chunk in response.iter_content(chunk_size=8192):
- if chunk:
- f.write(chunk)
-
- return filepath
-
-
-title = """
-Caption-Anything
-"""
-description = """
-Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
-"""
-
-examples = [
- ["test_img/img35.webp"],
- ["test_img/img2.jpg"],
- ["test_img/img5.jpg"],
- ["test_img/img12.jpg"],
- ["test_img/img14.jpg"],
- ["test_img/img0.png"],
- ["test_img/img1.jpg"],
-]
-
-seg_model_map = {
- 'base': 'vit_b',
- 'large': 'vit_l',
- 'huge': 'vit_h'
-}
-ckpt_url_map = {
- 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
- 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
- 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
-}
-os.makedirs('result', exist_ok=True)
+
+
args = parse_augment()
+if args.segmenter_checkpoint is None:
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
+else:
+ segmenter_checkpoint = args.segmenter_checkpoint
+
+shared_captioner = build_captioner(args.captioner, args.device, args)
+shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
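+# args.chat_tools_dict is a comma-separated list of "ToolName_device" entries; parse it into a {tool_name: device} mapping for the chatbot tools.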
+tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
+shared_chatbot_tools = build_chatbot_tools(tools_dict)
-checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
-folder = "segmenter"
-filename = os.path.basename(checkpoint_url)
-args.segmenter_checkpoint = os.path.join(folder, filename)
-download_checkpoint(checkpoint_url, folder, filename)
+class ImageSketcher(gr.Image):
+ """
+    Work around the bug in gradio.Image that prevents uploading when tool == 'sketch'.
+ """
-# args.device = 'cuda:5'
-# args.disable_gpt = True
-# args.enable_reduce_tokens = False
-# args.port=20322
-# args.captioner = 'blip'
-# args.regular_box = True
-shared_captioner = build_captioner(args.captioner, args.device, args)
-shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=args.segmenter_checkpoint).to(args.device)
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
+
+ def __init__(self, **kwargs):
+ super().__init__(tool="sketch", **kwargs)
+
+ def preprocess(self, x):
+ if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
+ assert isinstance(x, dict)
+ if x['mask'] is None:
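+                # A plain upload can arrive without a sketch; synthesize an empty RGBA mask so downstream code can always read x['mask'].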
+ decode_image = processing_utils.decode_base64_to_image(x['image'])
+ width, height = decode_image.size
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
+ mask[..., -1] = 255
+ mask = self.postprocess(mask)
+ x['mask'] = mask
-def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
+ return super().preprocess(x)
+
+
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+ session_id=None):
segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
captioner = captioner
if session_id is not None:
@@ -89,17 +67,22 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
def init_openai_api_key(api_key=""):
text_refiner = None
+ visual_chatgpt = None
if api_key and len(api_key) > 30:
try:
text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
- text_refiner.llm('hi') # test
+ text_refiner.llm('hi') # test
+ visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
except:
text_refiner = None
+ visual_chatgpt = None
openai_available = text_refiner is not None
- return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
+ return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
+ visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
+ visible=True), text_refiner, visual_chatgpt
-def get_prompt(chat_input, click_state, click_mode):
+def get_click_prompt(chat_input, click_state, click_mode):
inputs = json.loads(chat_input)
if click_mode == 'Continuous':
points = click_state[0]
@@ -119,13 +102,14 @@ def get_prompt(chat_input, click_state, click_mode):
raise NotImplementedError
prompt = {
- "prompt_type":["click"],
- "input_point":click_state[0],
- "input_label":click_state[1],
- "multimask_output":"True",
+ "prompt_type": ["click"],
+ "input_point": click_state[0],
+ "input_label": click_state[1],
+ "multimask_output": "True",
}
return prompt
+
def update_click_state(click_state, caption, click_mode):
if click_mode == 'Continuous':
click_state[2].append(caption)
@@ -134,280 +118,426 @@ def update_click_state(click_state, caption, click_mode):
else:
raise NotImplementedError
-
-def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
- if text_refiner is None:
+def chat_input_callback(*args):
+ visual_chatgpt, chat_input, click_state, state, aux_state = args
+ if visual_chatgpt is not None:
+ return visual_chatgpt.run_text(chat_input, state, aux_state)
+ else:
response = "Text refiner is not initialized, please input an OpenAI API key."
state = state + [(chat_input, response)]
- return state, state, chat_state
-
- points, labels, captions = click_state
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
- suffix = '\nHuman: {chat_input}\nAI: '
- qa_template = '\nHuman: {q}\nAI: {a}'
- # # "The image is of width {width} and height {height}."
- point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \n Now, let's chat!"
- prev_visual_context = ""
- pos_points = []
- pos_captions = []
- for i in range(len(points)):
- if labels[i] == 1:
- pos_points.append(f"({points[i][0]}, {points[i][0]})")
- pos_captions.append(captions[i])
- prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(pos_captions[-1], ', '.join(pos_points))
-
- context_length_thres = 500
- prev_history = ""
- for i in range(len(chat_state)):
- q, a = chat_state[i]
- if len(prev_history) < context_length_thres:
- prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
- else:
- break
- chat_prompt = point_chat_prompt.format(**{"img_caption":img_caption,"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
- print('\nchat_prompt: ', chat_prompt)
- response = text_refiner.llm(chat_prompt)
- state = state + [(chat_input, response)]
- chat_state = chat_state + [(chat_input, response)]
- return state, state, chat_state
-
-def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
- length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
+ return state, state
+
+def upload_callback(image_input, state, visual_chatgpt=None):
+ if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
+ image_input, mask = image_input['image'], image_input['mask']
+
+ click_state = [[], [], []]
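+    # Downscale the image so its longest side is at most 1024 px before computing the SAM embedding.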
+ res = 1024
+ width, height = image_input.size
+ ratio = min(1.0 * res / max(width, height), 1.0)
+ if ratio < 1.0:
+ image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+ print('Scaling input image to {}'.format(image_input.size))
+
model = build_caption_anything_with_models(
args,
api_key="",
captioner=shared_captioner,
sam_model=shared_sam_model,
- text_refiner=text_refiner,
session_id=iface.app_id
)
-
- model.segmenter.image_embedding = image_embedding
- model.segmenter.predictor.original_size = original_size
- model.segmenter.predictor.input_size = input_size
- model.segmenter.predictor.is_image_set = True
+ model.segmenter.set_image(image_input)
+ image_embedding = model.image_embedding
+ original_size = model.original_size
+ input_size = model.input_size
+
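+    # Register the uploaded image with the chatbot: save it under a new path and seed the agent's memory with a synthetic Human/AI exchange describing it.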
+ if visual_chatgpt is not None:
+ new_image_path = get_new_image_name('chat_image', func_name='upload')
+ image_input.save(new_image_path)
+ visual_chatgpt.current_image = new_image_path
+ img_caption, _ = model.captioner.inference_seg(image_input)
+        Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish the following tasks, rather than imagining directly from my description. If you understand, say \"Received\". \n'
+ AI_prompt = "Received."
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
+        state = [(None, 'Received new image, resized it to width {} and height {}.'.format(image_input.size[0], image_input.size[1]))]
+
+ return state, state, image_input, click_state, image_input, image_input, image_embedding, \
+ original_size, input_size
+
+
+def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
+ evt: gr.SelectData):
+ click_index = evt.index
if point_prompt == 'Positive':
- coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
+ coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
else:
- coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
+ coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
+
+ prompt = get_click_prompt(coordinate, click_state, click_mode)
+ input_points = prompt['input_point']
+ input_labels = prompt['input_label']
controls = {'length': length,
'sentiment': sentiment,
'factuality': factuality,
'language': language}
- # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
- # chat_input = click_coordinate
- prompt = get_prompt(coordinate, click_state, click_mode)
- print('prompt: ', prompt, 'controls: ', controls)
- input_points = prompt['input_point']
- input_labels = prompt['input_label']
+ model = build_caption_anything_with_models(
+ args,
+ api_key="",
+ captioner=shared_captioner,
+ sam_model=shared_sam_model,
+ text_refiner=text_refiner,
+ session_id=iface.app_id
+ )
+
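+    # Reuse the image embedding computed in upload_callback so each click skips re-encoding the image with SAM.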
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+
state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
- # for k, v in out['generated_captions'].items():
- # state = state + [(f'{k}: {v}', None)]
state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
wiki = out['generated_captions'].get('wiki', "")
-
update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
text = out['generated_captions']['raw_caption']
- # draw = ImageDraw.Draw(image_input)
- # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
input_mask = np.array(out['mask'].convert('P'))
image_input = mask_painter(np.array(image_input), input_mask)
origin_image_input = image_input
- image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]), input_mask, input_points=input_points, input_labels=input_labels)
-
- yield state, state, click_state, chat_input, image_input, wiki
+ image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
+ input_points=input_points, input_labels=input_labels)
+ x, y = input_points[-1]
+
+ if visual_chatgpt is not None:
+ new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
+ Image.open(out["crop_save_path"]).save(new_crop_save_path)
+        point_prompt = f'You should primarily use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is part of the whole image (path: {visual_chatgpt.current_image}). If the human mentions objects that are not in the selected region, you can use tools on the whole image.'
+ visual_chatgpt.point_prompt = point_prompt
+
+ yield state, state, click_state, image_input, wiki
if not args.disable_gpt and model.text_refiner:
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
+ enable_wiki=enable_wiki)
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
new_cap = refined_caption['caption']
wiki = refined_caption['wiki']
state = state + [(None, f"caption: {new_cap}")]
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]), input_mask, input_points=input_points, input_labels=input_labels)
- yield state, state, click_state, chat_input, refined_image_input, wiki
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
+ input_mask,
+ input_points=input_points, input_labels=input_labels)
+ yield state, state, click_state, refined_image_input, wiki
-def upload_callback(image_input, state):
- chat_state = []
- click_state = [[], [], []]
- res = 1024
- width, height = image_input.size
- ratio = min(1.0 * res / max(width, height), 1.0)
- if ratio < 1.0:
- image_input = image_input.resize((int(width * ratio), int(height * ratio)))
- print('Scaling input image to {}'.format(image_input.size))
- state = [] + [(None, 'Image size: ' + str(image_input.size))]
+def get_sketch_prompt(mask: PIL.Image.Image):
+ """
+ Get the prompt for the sketcher.
+ TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
+ """
+
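+    # Use a single channel of the sketch mask; any nonzero pixel is treated as part of the drawn trajectory.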
+ mask = np.asarray(mask)[..., 0]
+
+ # Get the bounding box of the sketch
+ y, x = np.where(mask != 0)
+ x1, y1 = np.min(x), np.min(y)
+ x2, y2 = np.max(x), np.max(y)
+
+ prompt = {
+ 'prompt_type': ['box'],
+ 'input_boxes': [
+ [x1, y1, x2, y2]
+ ]
+ }
+
+ return prompt
+
+
+def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
+ original_size, input_size, text_refiner):
+ image_input, mask = sketcher_image['image'], sketcher_image['mask']
+
+ prompt = get_sketch_prompt(mask)
+ boxes = prompt['input_boxes']
+
+ controls = {'length': length,
+ 'sentiment': sentiment,
+ 'factuality': factuality,
+ 'language': language}
+
model = build_caption_anything_with_models(
args,
api_key="",
captioner=shared_captioner,
sam_model=shared_sam_model,
+ text_refiner=text_refiner,
session_id=iface.app_id
)
- model.segmenter.set_image(image_input)
- image_embedding = model.segmenter.image_embedding
- original_size = model.segmenter.predictor.original_size
- input_size = model.segmenter.predictor.input_size
- img_caption, _ = model.captioner.inference_seg(image_input)
- return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size, img_caption
-
-with gr.Blocks(
- css='''
- #image_upload{min-height:400px}
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
- '''
-) as iface:
- state = gr.State([])
- click_state = gr.State([[],[],[]])
- chat_state = gr.State([])
- origin_image = gr.State(None)
- image_embedding = gr.State(None)
- text_refiner = gr.State(None)
- original_size = gr.State(None)
- input_size = gr.State(None)
- img_caption = gr.State(None)
-
- gr.Markdown(title)
- gr.Markdown(description)
-
- with gr.Row():
- with gr.Column(scale=1.0):
- with gr.Column(visible=False) as modules_not_need_gpt:
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
- example_image = gr.Image(type="pil", interactive=False, visible=False)
- with gr.Row(scale=1.0):
- with gr.Row(scale=0.4):
- point_prompt = gr.Radio(
- choices=["Positive", "Negative"],
- value="Positive",
- label="Point Prompt",
- interactive=True)
- click_mode = gr.Radio(
- choices=["Continuous", "Single"],
- value="Continuous",
- label="Clicking Mode",
+
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
+
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+
+ # Update components and states
+ state.append((f'Box: {boxes}', None))
+ state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
+ wiki = out['generated_captions'].get('wiki', "")
+ text = out['generated_captions']['raw_caption']
+ input_mask = np.array(out['mask'].convert('P'))
+ image_input = mask_painter(np.array(image_input), input_mask)
+
+ origin_image_input = image_input
+
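+    # A sketch has no single click point, so anchor the caption bubble at the center of the bounding box.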
+ fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
+ image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
+
+ yield state, state, image_input, wiki
+
+ if not args.disable_gpt and model.text_refiner:
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
+ enable_wiki=enable_wiki)
+
+ new_cap = refined_caption['caption']
+ wiki = refined_caption['wiki']
+ state = state + [(None, f"caption: {new_cap}")]
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
+
+ yield state, state, refined_image_input, wiki
+
+def clear_chat_memory(visual_chatgpt):
+ if visual_chatgpt is not None:
+ visual_chatgpt.memory.clear()
+ visual_chatgpt.current_image = None
+ visual_chatgpt.point_prompt = ""
+
+def get_style():
+ current_version = version.parse(gr.__version__)
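+    # Pick CSS overrides that match the installed gradio version; newer versions need no custom style here.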
+ if current_version <= version.parse('3.24.1'):
+ style = '''
+ #image_sketcher{min-height:500px}
+ #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
+ #image_upload{min-height:500px}
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
+ '''
+ elif current_version <= version.parse('3.27'):
+ style = '''
+ #image_sketcher{min-height:500px}
+ #image_upload{min-height:500px}
+ '''
+ else:
+ style = None
+
+ return style
+
+
+def create_ui():
+    title = """
+    Caption-Anything
+    """
+    description = """
+    Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
+    """
-description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
-"""
-
-examples = [
- ["test_img/img2.jpg"],
- ["test_img/img5.jpg"],
- ["test_img/img12.jpg"],
- ["test_img/img14.jpg"],
-]
-
-args = parse_augment()
-args.captioner = 'blip2'
-args.seg_crop_mode = 'wo_bg'
-args.regular_box = True
-# args.device = 'cuda:5'
-# args.disable_gpt = False
-# args.enable_reduce_tokens = True
-# args.port=20322
-model = CaptionAnything(args)
-
-def init_openai_api_key(api_key):
- os.environ['OPENAI_API_KEY'] = api_key
- model.init_refiner()
-
-
-def get_prompt(chat_input, click_state):
- points = click_state[0]
- labels = click_state[1]
- inputs = json.loads(chat_input)
- for input in inputs:
- points.append(input[:2])
- labels.append(input[2])
-
- prompt = {
- "prompt_type":["click"],
- "input_point":points,
- "input_label":labels,
- "multimask_output":"True",
- }
- return prompt
-
-def chat_with_points(chat_input, click_state, state):
- if not hasattr(model, "text_refiner"):
- response = "Text refiner is not initilzed, please input openai api key."
- state = state + [(chat_input, response)]
- return state, state
-
- points, labels, captions = click_state
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
- # # "The image is of width {width} and height {height}."
- point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
- prev_visual_context = ""
- pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
- if len(captions):
- prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
- else:
- prev_visual_context = 'no point exists.'
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
- response = model.text_refiner.llm(chat_prompt)
- state = state + [(chat_input, response)]
- return state, state
-
-def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
-
- if point_prompt == 'Positive':
- coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
- else:
- coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
-
- controls = {'length': length,
- 'sentiment': sentiment,
- 'factuality': factuality,
- 'language': language}
-
- # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
- # chat_input = click_coordinate
- prompt = get_prompt(coordinate, click_state)
- print('prompt: ', prompt, 'controls: ', controls)
-
- out = model.inference(image_input, prompt, controls)
- state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
- # for k, v in out['generated_captions'].items():
- # state = state + [(f'{k}: {v}', None)]
- state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
- wiki = out['generated_captions'].get('wiki', "")
- click_state[2].append(out['generated_captions']['raw_caption'])
-
- text = out['generated_captions']['raw_caption']
- # draw = ImageDraw.Draw(image_input)
- # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
- input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
- image_input = mask_painter(np.array(image_input), input_mask)
- origin_image_input = image_input
- image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
-
- yield state, state, click_state, chat_input, image_input, wiki
- if not args.disable_gpt and hasattr(model, "text_refiner"):
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
- # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
- new_cap = refined_caption['caption']
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
- yield state, state, click_state, chat_input, refined_image_input, wiki
-
-
-def upload_callback(image_input, state):
- state = [] + [('Image size: ' + str(image_input.size), None)]
- click_state = [[], [], []]
- model.segmenter.image = None
- model.segmenter.image_embedding = None
- model.segmenter.set_image(image_input)
- return state, image_input, click_state
-
-with gr.Blocks(
- css='''
- #image_upload{min-height:400px}
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
- '''
-) as iface:
- state = gr.State([])
- click_state = gr.State([[],[],[]])
- origin_image = gr.State(None)
-
- gr.Markdown(title)
- gr.Markdown(description)
-
- with gr.Row():
- with gr.Column(scale=1.0):
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
- with gr.Row(scale=1.0):
- point_prompt = gr.Radio(
- choices=["Positive", "Negative"],
- value="Positive",
- label="Point Prompt",
- interactive=True)
- clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
- with gr.Row(scale=1.0):
- language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
-
- sentiment = gr.Radio(
- choices=["Positive", "Natural", "Negative"],
- value="Natural",
- label="Sentiment",
- interactive=True,
- )
- with gr.Row(scale=1.0):
- factuality = gr.Radio(
- choices=["Factual", "Imagination"],
- value="Factual",
- label="Factuality",
- interactive=True,
- )
- length = gr.Slider(
- minimum=10,
- maximum=80,
- value=10,
- step=1,
- interactive=True,
- label="Length",
- )
-
- with gr.Column(scale=0.5):
- openai_api_key = gr.Textbox(
- placeholder="Input your openAI API key and press Enter",
- show_label=False,
- label = "OpenAI API Key",
- lines=1,
- type="password"
- )
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
- wiki_output = gr.Textbox(lines=6, label="Wiki")
- chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
- chat_input = gr.Textbox(lines=1, label="Chat Input")
- with gr.Row():
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
- submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
- clear_button_clike.click(
- lambda x: ([[], [], []], x, ""),
- [origin_image],
- [click_state, image_input, wiki_output],
- queue=False,
- show_progress=False
- )
- clear_button_image.click(
- lambda: (None, [], [], [[], [], []], ""),
- [],
- [image_input, chatbot, state, click_state, wiki_output],
- queue=False,
- show_progress=False
- )
- clear_button_text.click(
- lambda: ([], [], [[], [], []]),
- [],
- [chatbot, state, click_state],
- queue=False,
- show_progress=False
- )
- image_input.clear(
- lambda: (None, [], [], [[], [], []], ""),
- [],
- [image_input, chatbot, state, click_state, wiki_output],
- queue=False,
- show_progress=False
- )
-
- examples = gr.Examples(
- examples=examples,
- inputs=[image_input],
- )
-
- image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
- chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
-
- # select coordinate
- image_input.select(inference_seg_cap,
- inputs=[
- origin_image,
- point_prompt,
- language,
- sentiment,
- factuality,
- length,
- state,
- click_state
- ],
- outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
- show_progress=False, queue=True)
-
-iface.queue(concurrency_count=1, api_open=False, max_size=10)
-iface.launch(server_name="0.0.0.0", enable_queue=True)
\ No newline at end of file
diff --git a/app_old.py b/app_old.py
deleted file mode 100644
index 531f681a77e177eaffa5026a7851cfa241fb65b2..0000000000000000000000000000000000000000
--- a/app_old.py
+++ /dev/null
@@ -1,261 +0,0 @@
-from io import BytesIO
-import string
-import gradio as gr
-import requests
-from caption_anything import CaptionAnything
-import torch
-import json
-import sys
-import argparse
-from caption_anything import parse_augment
-import os
-
-# download sam checkpoint if not downloaded
-def download_checkpoint(url, folder, filename):
- os.makedirs(folder, exist_ok=True)
- filepath = os.path.join(folder, filename)
-
- if not os.path.exists(filepath):
- response = requests.get(url, stream=True)
- with open(filepath, "wb") as f:
- for chunk in response.iter_content(chunk_size=8192):
- if chunk:
- f.write(chunk)
-
- return filepath
-checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-folder = "segmenter"
-filename = "sam_vit_h_4b8939.pth"
-
-title = """
-Caption-Anything
-"""
-description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
- Code: GitHub repo:
-"""
-
-examples = [
- ["test_img/img2.jpg", "[[1000, 700, 1]]"]
-]
-
-args = parse_augment()
-
-def get_prompt(chat_input, click_state):
- points = click_state[0]
- labels = click_state[1]
- inputs = json.loads(chat_input)
- for input in inputs:
- points.append(input[:2])
- labels.append(input[2])
-
- prompt = {
- "prompt_type":["click"],
- "input_point":points,
- "input_label":labels,
- "multimask_output":"True",
- }
- return prompt
-
-def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
- controls = {'length': length,
- 'sentiment': sentiment,
- 'factuality': factuality,
- 'language': language}
- prompt = get_prompt(chat_input, click_state)
- print('prompt: ', prompt, 'controls: ', controls)
- out = model.inference(image_input, prompt, controls)
- state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
- for k, v in out['generated_captions'].items():
- state = state + [(f'{k}: {v}', None)]
- click_state[2].append(out['generated_captions']['raw_caption'])
- image_output_mask = out['mask_save_path']
- image_output_crop = out['crop_save_path']
- return state, state, click_state, image_output_mask, image_output_crop
-
-
-def upload_callback(image_input, state):
- state = state + [('Image size: ' + str(image_input.size), None)]
- return state
-
-# get coordinate in format [[x,y,positive/negative]]
-def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
- print("point_prompt: ", point_prompt)
- if point_prompt == 'Positive Point':
- coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
- else:
- coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
- return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
-
-def chat_with_points(chat_input, click_state, state):
- points, labels, captions = click_state
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
- # "The image is of width {width} and height {height}."
- point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
- prev_visual_context = ""
- pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
- prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
- response = model.text_refiner.llm(chat_prompt)
- state = state + [(chat_input, response)]
- return state, state
-
-def init_openai_api_key(api_key):
- # os.environ['OPENAI_API_KEY'] = api_key
- global model
- model = CaptionAnything(args, api_key)
-
-css='''
-#image_upload{min-height:200px}
-#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
-'''
-
-with gr.Blocks(css=css) as iface:
- state = gr.State([])
- click_state = gr.State([[],[],[]])
- caption_state = gr.State([[]])
- gr.Markdown(title)
- gr.Markdown(description)
-
- with gr.Column():
- openai_api_key = gr.Textbox(
- placeholder="Input your openAI API key and press Enter",
- show_label=False,
- lines=1,
- type="password",
- )
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
-
- with gr.Row():
- with gr.Column(scale=0.7):
- image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0)
-
- with gr.Row(scale=0.7):
- point_prompt = gr.Radio(
- choices=["Positive Point", "Negative Point"],
- value="Positive Point",
- label="Points",
- interactive=True,
- )
-
- # with gr.Row():
- language = gr.Radio(
- choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"],
- value="English",
- label="Language",
- interactive=True,
- )
- sentiment = gr.Radio(
- choices=["Positive", "Natural", "Negative"],
- value="Natural",
- label="Sentiment",
- interactive=True,
- )
- factuality = gr.Radio(
- choices=["Factual", "Imagination"],
- value="Factual",
- label="Factuality",
- interactive=True,
- )
- length = gr.Slider(
- minimum=5,
- maximum=100,
- value=10,
- step=1,
- interactive=True,
- label="Length",
- )
-
- with gr.Column(scale=1.5):
- with gr.Row():
- image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0)
- image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0)
- chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5)
-
- with gr.Row():
- with gr.Column(scale=0.7):
- prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])")
- prompt_input.submit(
- inference_seg_cap,
- [
- image_input,
- prompt_input,
- language,
- sentiment,
- factuality,
- length,
- state,
- click_state
- ],
- [chatbot, state, click_state, image_output_mask, image_output_crop],
- show_progress=False
- )
-
- image_input.upload(
- upload_callback,
- [image_input, state],
- [chatbot]
- )
-
- with gr.Row():
- clear_button = gr.Button(value="Clear Click", interactive=True)
- clear_button.click(
- lambda: ("", [[], [], []], None, None),
- [],
- [prompt_input, click_state, image_output_mask, image_output_crop],
- queue=False,
- show_progress=False
- )
-
- clear_button = gr.Button(value="Clear", interactive=True)
- clear_button.click(
- lambda: ("", [], [], [[], [], []], None, None),
- [],
- [prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
- queue=False,
- show_progress=False
- )
-
- submit_button = gr.Button(
- value="Submit", interactive=True, variant="primary"
- )
- submit_button.click(
- inference_seg_cap,
- [
- image_input,
- prompt_input,
- language,
- sentiment,
- factuality,
- length,
- state,
- click_state
- ],
- [chatbot, state, click_state, image_output_mask, image_output_crop],
- show_progress=False
- )
-
- # select coordinate
- image_input.select(
- get_select_coords,
- inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state],
- outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
- show_progress=False
- )
-
- image_input.change(
- lambda: ("", [], [[], [], []]),
- [],
- [chatbot, state, click_state],
- queue=False,
- )
-
- with gr.Column(scale=1.5):
- chat_input = gr.Textbox(lines=1, label="Chat Input")
- chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
-
-
- examples = gr.Examples(
- examples=examples,
- inputs=[image_input, prompt_input],
- )
-
-iface.queue(concurrency_count=1, api_open=False, max_size=10)
-iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
diff --git a/app_wo_langchain.py b/app_wo_langchain.py
new file mode 100644
index 0000000000000000000000000000000000000000..e511153274c35e9bf545fb3f5e7244c327de618c
--- /dev/null
+++ b/app_wo_langchain.py
@@ -0,0 +1,588 @@
+import os
+import json
+from typing import List
+
+import PIL
+import gradio as gr
+import numpy as np
+from gradio import processing_utils
+
+from packaging import version
+from PIL import Image, ImageDraw
+
+from caption_anything.model import CaptionAnything
+from caption_anything.utils.image_editing_utils import create_bubble_frame
+from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
+from caption_anything.utils.parser import parse_augment
+from caption_anything.captioner import build_captioner
+from caption_anything.text_refiner import build_text_refiner
+from caption_anything.segmenter import build_segmenter
+from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
+from segment_anything import sam_model_registry
+
+
+args = parse_augment()
+if args.segmenter_checkpoint is None:
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
+else:
+ segmenter_checkpoint = args.segmenter_checkpoint
+
+shared_captioner = build_captioner(args.captioner, args.device, args)
+shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+
+
+class ImageSketcher(gr.Image):
+ """
+    Work around the bug in gradio.Image that prevents uploading when tool == 'sketch'.
+ """
+
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
+
+ def __init__(self, **kwargs):
+ super().__init__(tool="sketch", **kwargs)
+
+ def preprocess(self, x):
+ if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
+ assert isinstance(x, dict)
+ if x['mask'] is None:
+ decode_image = processing_utils.decode_base64_to_image(x['image'])
+ width, height = decode_image.size
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
+ mask[..., -1] = 255
+ mask = self.postprocess(mask)
+
+ x['mask'] = mask
+
+ return super().preprocess(x)
+
+
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+ session_id=None):
+ segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
+ captioner = captioner
+ if session_id is not None:
+ print('Init caption anything for session {}'.format(session_id))
+ return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+
+
+def init_openai_api_key(api_key=""):
+ text_refiner = None
+ if api_key and len(api_key) > 30:
+ try:
+ text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
+ text_refiner.llm('hi') # test
+ except:
+ text_refiner = None
+ openai_available = text_refiner is not None
+ return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
+ visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
+ visible=True), text_refiner
+
+
+def get_click_prompt(chat_input, click_state, click_mode):
+ inputs = json.loads(chat_input)
+ if click_mode == 'Continuous':
+ points = click_state[0]
+ labels = click_state[1]
+ for input in inputs:
+ points.append(input[:2])
+ labels.append(input[2])
+ elif click_mode == 'Single':
+ points = []
+ labels = []
+ for input in inputs:
+ points.append(input[:2])
+ labels.append(input[2])
+ click_state[0] = points
+ click_state[1] = labels
+ else:
+ raise NotImplementedError
+
+ prompt = {
+ "prompt_type": ["click"],
+ "input_point": click_state[0],
+ "input_label": click_state[1],
+ "multimask_output": "True",
+ }
+ return prompt
+
+
+def update_click_state(click_state, caption, click_mode):
+ if click_mode == 'Continuous':
+ click_state[2].append(caption)
+ elif click_mode == 'Single':
+ click_state[2] = [caption]
+ else:
+ raise NotImplementedError
+
+
+def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
+ if text_refiner is None:
+        response = "Text refiner is not initialized, please input an OpenAI API key."
+ state = state + [(chat_input, response)]
+ return state, state, chat_state
+
+ points, labels, captions = click_state
+    # point_chat_prompt = "I want you to act as a chat bot in terms of the image. I will give you some points (w, h) in the image and tell you what happened at each point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers to the height. You should chat with me based on the facts in the image instead of imagination. Now I tell you the points with their visual descriptions:\n{points_with_caps}\nNow begin chatting!"
+ suffix = '\nHuman: {chat_input}\nAI: '
+ qa_template = '\nHuman: {q}\nAI: {a}'
+ # # "The image is of width {width} and height {height}."
+    point_chat_prompt = "I am an AI trained to chat with you about an image. I am great at understanding what is going on in any image based on the image information you provide. The overall image description is \"{img_caption}\". You will also provide me with objects in the image in detail, i.e., their locations and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \nYou are required to use language instead of numbers to describe these positions. Now, let's chat!"
+ prev_visual_context = ""
+ pos_points = []
+ pos_captions = []
+
+ for i in range(len(points)):
+ if labels[i] == 1:
+ pos_points.append(f"(X:{points[i][0]}, Y:{points[i][1]})")
+ pos_captions.append(captions[i])
+            prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" located at {}'.format(
+ pos_captions[-1], ', '.join(pos_points))
+
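+    # Include previous Q/A turns as context until roughly 500 characters of history have been accumulated.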
+ context_length_thres = 500
+ prev_history = ""
+ for i in range(len(chat_state)):
+ q, a = chat_state[i]
+ if len(prev_history) < context_length_thres:
+ prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
+ else:
+ break
+ chat_prompt = point_chat_prompt.format(
+ **{"img_caption": img_caption, "points_with_caps": prev_visual_context}) + prev_history + suffix.format(
+ **{"chat_input": chat_input})
+ print('\nchat_prompt: ', chat_prompt)
+ response = text_refiner.llm(chat_prompt)
+ state = state + [(chat_input, response)]
+ chat_state = chat_state + [(chat_input, response)]
+ return state, state, chat_state
+
+
+def upload_callback(image_input, state):
+ if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
+ image_input, mask = image_input['image'], image_input['mask']
+
+ chat_state = []
+ click_state = [[], [], []]
+ res = 1024
+ width, height = image_input.size
+ ratio = min(1.0 * res / max(width, height), 1.0)
+ if ratio < 1.0:
+ image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+ print('Scaling input image to {}'.format(image_input.size))
+ state = [] + [(None, 'Image size: ' + str(image_input.size))]
+ model = build_caption_anything_with_models(
+ args,
+ api_key="",
+ captioner=shared_captioner,
+ sam_model=shared_sam_model,
+ session_id=iface.app_id
+ )
+ model.segmenter.set_image(image_input)
+ image_embedding = model.image_embedding
+ original_size = model.original_size
+ input_size = model.input_size
+ img_caption, _ = model.captioner.inference_seg(image_input)
+
+ return state, state, chat_state, image_input, click_state, image_input, image_input, image_embedding, \
+ original_size, input_size, img_caption
+
+
+def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner,
+ evt: gr.SelectData):
+ click_index = evt.index
+
+ if point_prompt == 'Positive':
+ coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
+ else:
+ coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
+
+ prompt = get_click_prompt(coordinate, click_state, click_mode)
+ input_points = prompt['input_point']
+ input_labels = prompt['input_label']
+
+ controls = {'length': length,
+ 'sentiment': sentiment,
+ 'factuality': factuality,
+ 'language': language}
+
+ model = build_caption_anything_with_models(
+ args,
+ api_key="",
+ captioner=shared_captioner,
+ sam_model=shared_sam_model,
+ text_refiner=text_refiner,
+ session_id=iface.app_id
+ )
+
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
+
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+
+ state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
+ state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
+ wiki = out['generated_captions'].get('wiki', "")
+ update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
+ text = out['generated_captions']['raw_caption']
+ input_mask = np.array(out['mask'].convert('P'))
+ image_input = mask_painter(np.array(image_input), input_mask)
+ origin_image_input = image_input
+ image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
+ input_points=input_points, input_labels=input_labels)
+ yield state, state, click_state, image_input, wiki
+ if not args.disable_gpt and model.text_refiner:
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
+ enable_wiki=enable_wiki)
+ # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
+ new_cap = refined_caption['caption']
+ wiki = refined_caption['wiki']
+ state = state + [(None, f"caption: {new_cap}")]
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
+ input_mask,
+ input_points=input_points, input_labels=input_labels)
+ yield state, state, click_state, refined_image_input, wiki
+
+
+def get_sketch_prompt(mask: PIL.Image.Image, multi_mask=True):
+ """
+ Get the prompt for the sketcher.
+ TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
+ """
+
+ mask = np.array(np.asarray(mask)[..., 0])
+ mask[mask > 0] = 1 # Refine the mask, let all nonzero values be 1
+
+ if not multi_mask:
+ y, x = np.where(mask == 1)
+ x1, y1 = np.min(x), np.min(y)
+ x2, y2 = np.max(x), np.max(y)
+
+ prompt = {
+ 'prompt_type': ['box'],
+ 'input_boxes': [
+ [x1, y1, x2, y2]
+ ]
+ }
+
+ return prompt
+
+ traversed = np.zeros_like(mask)
+ groups = np.zeros_like(mask)
+ max_group_id = 1
+
+ # Iterate over all pixels
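+    # Label each 8-connected group of sketch pixels with a DFS flood fill so every stroke gets its own bounding box.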
+ for x in range(mask.shape[0]):
+ for y in range(mask.shape[1]):
+ if traversed[x, y] == 1:
+ continue
+
+ if mask[x, y] == 0:
+ traversed[x, y] = 1
+ else:
+ # If pixel is part of mask
+ groups[x, y] = max_group_id
+ stack = [(x, y)]
+ while stack:
+ i, j = stack.pop()
+ if traversed[i, j] == 1:
+ continue
+ traversed[i, j] = 1
+ if mask[i, j] == 1:
+ groups[i, j] = max_group_id
+                        for di, dj in [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1)]:
+                            ni, nj = i + di, j + dj
+                            if 0 <= ni < mask.shape[0] and 0 <= nj < mask.shape[1] and traversed[ni, nj] == 0:
+                                stack.append((ni, nj))
+ max_group_id += 1
+
+ # get the bounding box of each group
+ boxes = []
+ for group in range(1, max_group_id):
+ y, x = np.where(groups == group)
+ x1, y1 = np.min(x), np.min(y)
+ x2, y2 = np.max(x), np.max(y)
+ boxes.append([x1, y1, x2, y2])
+
+ prompt = {
+ 'prompt_type': ['box'],
+ 'input_boxes': boxes
+ }
+
+ return prompt
+
+
+def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
+ original_size, input_size, text_refiner):
+ image_input, mask = sketcher_image['image'], sketcher_image['mask']
+
+ prompt = get_sketch_prompt(mask, multi_mask=False)
+ boxes = prompt['input_boxes']
+
+ controls = {'length': length,
+ 'sentiment': sentiment,
+ 'factuality': factuality,
+ 'language': language}
+
+ model = build_caption_anything_with_models(
+ args,
+ api_key="",
+ captioner=shared_captioner,
+ sam_model=shared_sam_model,
+ text_refiner=text_refiner,
+ session_id=iface.app_id
+ )
+
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
+
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+
+ # Update components and states
+ state.append((f'Box: {boxes}', None))
+ state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
+ wiki = out['generated_captions'].get('wiki', "")
+ text = out['generated_captions']['raw_caption']
+ input_mask = np.array(out['mask'].convert('P'))
+ image_input = mask_painter(np.array(image_input), input_mask)
+
+ origin_image_input = image_input
+
+ fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
+ image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
+
+ yield state, state, image_input, wiki
+
+ if not args.disable_gpt and model.text_refiner:
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
+ enable_wiki=enable_wiki)
+
+ new_cap = refined_caption['caption']
+ wiki = refined_caption['wiki']
+ state = state + [(None, f"caption: {new_cap}")]
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
+
+ yield state, state, refined_image_input, wiki
+
+
+def get_style():
+ current_version = version.parse(gr.__version__)
+ if current_version <= version.parse('3.24.1'):
+ style = '''
+ #image_sketcher{min-height:500px}
+ #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
+ #image_upload{min-height:500px}
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
+ '''
+ elif current_version <= version.parse('3.27'):
+ style = '''
+ #image_sketcher{min-height:500px}
+ #image_upload{min-height:500px}
+ '''
+ else:
+ style = None
+
+ return style
+
+
+def create_ui():
+    title = """
+    Caption-Anything
+    """
+    description = """
+    Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
+    """
+
+ examples = [
+ ["test_images/img35.webp"],
+ ["test_images/img2.jpg"],
+ ["test_images/img5.jpg"],
+ ["test_images/img12.jpg"],
+ ["test_images/img14.jpg"],
+ ["test_images/qingming3.jpeg"],
+ ["test_images/img1.jpg"],
+ ]
+
+ with gr.Blocks(
+ css=get_style()
+ ) as iface:
+ state = gr.State([])
+ click_state = gr.State([[], [], []])
+ chat_state = gr.State([])
+ origin_image = gr.State(None)
+ image_embedding = gr.State(None)
+ text_refiner = gr.State(None)
+ original_size = gr.State(None)
+ input_size = gr.State(None)
+ img_caption = gr.State(None)
+
+ gr.Markdown(title)
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=1.0):
+ with gr.Column(visible=False) as modules_not_need_gpt:
+ with gr.Tab("Click"):
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
+ with gr.Row(scale=1.0):
+ with gr.Row(scale=0.4):
+ point_prompt = gr.Radio(
+ choices=["Positive", "Negative"],
+ value="Positive",
+ label="Point Prompt",
+ interactive=True)
+ click_mode = gr.Radio(
+ choices=["Continuous", "Single"],
+ value="Continuous",
+ label="Clicking Mode",
+ interactive=True)
+ with gr.Row(scale=0.4):
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
+ with gr.Tab("Trajectory (Beta)"):
+ sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
+ elem_id="image_sketcher")
+ with gr.Row():
+ submit_button_sketcher = gr.Button(value="Submit", interactive=True)
+
+ with gr.Column(visible=False) as modules_need_gpt:
+ with gr.Row(scale=1.0):
+ language = gr.Dropdown(
+ ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
+ value="English", label="Language", interactive=True)
+ sentiment = gr.Radio(
+ choices=["Positive", "Natural", "Negative"],
+ value="Natural",
+ label="Sentiment",
+ interactive=True,
+ )
+ with gr.Row(scale=1.0):
+ factuality = gr.Radio(
+ choices=["Factual", "Imagination"],
+ value="Factual",
+ label="Factuality",
+ interactive=True,
+ )
+ length = gr.Slider(
+ minimum=10,
+ maximum=80,
+ value=10,
+ step=1,
+ interactive=True,
+ label="Generated Caption Length",
+ )
+ enable_wiki = gr.Radio(
+ choices=["Yes", "No"],
+ value="No",
+ label="Enable Wiki",
+ interactive=True)
+ with gr.Column(visible=True) as modules_not_need_gpt3:
+ gr.Examples(
+ examples=examples,
+ inputs=[example_image],
+ )
+ with gr.Column(scale=0.5):
+ openai_api_key = gr.Textbox(
+ placeholder="Input openAI API key",
+ show_label=False,
+ label="OpenAI API Key",
+ lines=1,
+ type="password")
+ with gr.Row(scale=0.5):
+ enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
+ disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
+ variant='primary')
+ with gr.Column(visible=False) as modules_need_gpt2:
+ wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
+ with gr.Column(visible=False) as modules_not_need_gpt2:
+ chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
+ with gr.Column(visible=False) as modules_need_gpt3:
+ chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
+ container=False)
+ with gr.Row():
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
+ submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
+
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+ enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
+ modules_not_need_gpt,
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+ disable_chatGPT_button.click(init_openai_api_key,
+ outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
+ modules_not_need_gpt,
+ modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+
+ clear_button_click.click(
+ lambda x: ([[], [], []], x, ""),
+ [origin_image],
+ [click_state, image_input, wiki_output],
+ queue=False,
+ show_progress=False
+ )
+ clear_button_image.click(
+ lambda: (None, [], [], [], [[], [], []], "", "", ""),
+ [],
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
+ queue=False,
+ show_progress=False
+ )
+ clear_button_text.click(
+ lambda: ([], [], [[], [], []], []),
+ [],
+ [chatbot, state, click_state, chat_state],
+ queue=False,
+ show_progress=False
+ )
+ image_input.clear(
+ lambda: (None, [], [], [], [[], [], []], "", "", ""),
+ [],
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
+ queue=False,
+ show_progress=False
+ )
+
+ image_input.upload(upload_callback, [image_input, state],
+ [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
+ image_embedding, original_size, input_size, img_caption])
+ sketcher_input.upload(upload_callback, [sketcher_input, state],
+ [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
+ image_embedding, original_size, input_size, img_caption])
+ chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption],
+ [chatbot, state, chat_state])
+ chat_input.submit(lambda: "", None, chat_input)
+ example_image.change(upload_callback, [example_image, state],
+ [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
+ image_embedding, original_size, input_size, img_caption])
+
+ # select coordinate
+ image_input.select(
+ inference_click,
+ inputs=[
+ origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
+ image_embedding, state, click_state, original_size, input_size, text_refiner
+ ],
+ outputs=[chatbot, state, click_state, image_input, wiki_output],
+ show_progress=False, queue=True
+ )
+
+ submit_button_sketcher.click(
+ inference_traject,
+ inputs=[
+ sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
+ original_size, input_size, text_refiner
+ ],
+ outputs=[chatbot, state, sketcher_input, wiki_output],
+ show_progress=False, queue=True
+ )
+
+ return iface
+
+
+if __name__ == '__main__':
+ iface = create_ui()
+ iface.queue(concurrency_count=5, api_open=False, max_size=10)
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
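
The `image_input.select(inference_click, ...)` wiring above never lists the click coordinates among its `inputs`; the position presumably reaches `inference_click` through Gradio's implicit `gr.SelectData` event argument (that handler is defined outside this hunk). A minimal sketch of that pattern, assuming Gradio >= 3.24 event semantics; the component and function names below are placeholders, not the app's:

import gradio as gr

def on_select(click_log, evt: gr.SelectData):
    # evt.index is the (x, y) pixel that was clicked on the image
    click_log = click_log + [evt.index]
    return click_log, f"clicks so far: {click_log}"

with gr.Blocks() as demo:
    click_log = gr.State([])                    # shared state, analogous to click_state above
    image = gr.Image(type="pil", interactive=True)
    log_box = gr.Textbox(label="Click log")
    # the typed SelectData parameter is injected by Gradio and is not part of `inputs`
    image.select(on_select, inputs=[click_log], outputs=[click_log, log_box])

if __name__ == "__main__":
    demo.queue().launch()
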
diff --git a/caas.py b/caas.py
deleted file mode 100644
index a2ac7c91b24682a7cd255b99eba2729e29d3aa0f..0000000000000000000000000000000000000000
--- a/caas.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from captioner import build_captioner, BaseCaptioner
-from segmenter import build_segmenter
-from text_refiner import build_text_refiner
-import os
-import argparse
-import pdb
-import time
-from PIL import Image
-
-class CaptionAnything():
- def __init__(self, args):
- self.args = args
- self.captioner = build_captioner(args.captioner, args.device, args)
- self.segmenter = build_segmenter(args.segmenter, args.device, args)
- if not args.disable_gpt:
- self.init_refiner()
-
-
- def init_refiner(self):
- if os.environ.get('OPENAI_API_KEY', None):
- self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
-
- def inference(self, image, prompt, controls, disable_gpt=False):
- # segment with prompt
- print("CA prompt: ", prompt, "CA controls",controls)
- seg_mask = self.segmenter.inference(image, prompt)[0, ...]
- mask_save_path = f'result/mask_{time.time()}.png'
- if not os.path.exists(os.path.dirname(mask_save_path)):
- os.makedirs(os.path.dirname(mask_save_path))
- new_p = Image.fromarray(seg_mask.astype('int') * 255.)
- if new_p.mode != 'RGB':
- new_p = new_p.convert('RGB')
- new_p.save(mask_save_path)
- print('seg_mask path: ', mask_save_path)
- print("seg_mask.shape: ", seg_mask.shape)
- # captioning with mask
- if self.args.enable_reduce_tokens:
- caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
- else:
- caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
- # refining with TextRefiner
- context_captions = []
- if self.args.context_captions:
- context_captions.append(self.captioner.inference(image))
- if not disable_gpt and hasattr(self, "text_refiner"):
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
- else:
- refined_caption = {'raw_caption': caption}
- out = {'generated_captions': refined_caption,
- 'crop_save_path': crop_save_path,
- 'mask_save_path': mask_save_path,
- 'context_captions': context_captions}
- return out
-
-def parse_augment():
- parser = argparse.ArgumentParser()
- parser.add_argument('--captioner', type=str, default="blip")
- parser.add_argument('--segmenter', type=str, default="base")
- parser.add_argument('--text_refiner', type=str, default="base")
- parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
- parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
- parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
- parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption")
- parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
- parser.add_argument('--device', type=str, default="cuda:0")
- parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
- parser.add_argument('--debug', action="store_true")
- parser.add_argument('--gradio_share', action="store_true")
- parser.add_argument('--disable_gpt', action="store_true")
- parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
- parser.add_argument('--disable_reuse_features', action="store_true", default=False)
- args = parser.parse_args()
-
- if args.debug:
- print(args)
- return args
-
-if __name__ == "__main__":
- args = parse_augment()
- # image_path = 'test_img/img3.jpg'
- image_path = 'test_img/img13.jpg'
- prompts = [
- {
- "prompt_type":["click"],
- "input_point":[[500, 300], [1000, 500]],
- "input_label":[1, 0],
- "multimask_output":"True",
- },
- {
- "prompt_type":["click"],
- "input_point":[[900, 800]],
- "input_label":[1],
- "multimask_output":"True",
- }
- ]
- controls = {
- "length": "30",
- "sentiment": "positive",
- # "imagination": "True",
- "imagination": "False",
- "language": "English",
- }
-
- model = CaptionAnything(args)
- for prompt in prompts:
- print('*'*30)
- print('Image path: ', image_path)
- image = Image.open(image_path)
- print(image)
- print('Visual controls (SAM prompt):\n', prompt)
- print('Language controls:\n', controls)
- out = model.inference(image_path, prompt, controls)
-
-
\ No newline at end of file
diff --git a/caption_anything.py b/caption_anything.py
deleted file mode 100644
index fe0cb9869e191c2ffda865d6410117911b8ebef8..0000000000000000000000000000000000000000
--- a/caption_anything.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from captioner import build_captioner, BaseCaptioner
-from segmenter import build_segmenter
-from text_refiner import build_text_refiner
-import os
-import argparse
-import pdb
-import time
-from PIL import Image
-import cv2
-import numpy as np
-
-class CaptionAnything():
- def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
- self.args = args
- self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
- self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
-
- self.text_refiner = None
- if not args.disable_gpt:
- if text_refiner is not None:
- self.text_refiner = text_refiner
- else:
- self.init_refiner(api_key)
-
- def init_refiner(self, api_key):
- try:
- self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
- self.text_refiner.llm('hi') # test
- except:
- self.text_refiner = None
- print('OpenAI GPT is not available')
-
- def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
- # segment with prompt
- print("CA prompt: ", prompt, "CA controls",controls)
- seg_mask = self.segmenter.inference(image, prompt)[0, ...]
- if self.args.enable_morphologyex:
- seg_mask = 255 * seg_mask.astype(np.uint8)
- seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
- seg_mask = seg_mask[:,:,0] > 0
- mask_save_path = f'result/mask_{time.time()}.png'
- if not os.path.exists(os.path.dirname(mask_save_path)):
- os.makedirs(os.path.dirname(mask_save_path))
- seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
- if seg_mask_img.mode != 'RGB':
- seg_mask_img = seg_mask_img.convert('RGB')
- seg_mask_img.save(mask_save_path)
- print('seg_mask path: ', mask_save_path)
- print("seg_mask.shape: ", seg_mask.shape)
- # captioning with mask
- if self.args.enable_reduce_tokens:
- caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
- else:
- caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
- # refining with TextRefiner
- context_captions = []
- if self.args.context_captions:
- context_captions.append(self.captioner.inference(image))
- if not disable_gpt and self.text_refiner is not None:
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
- else:
- refined_caption = {'raw_caption': caption}
- out = {'generated_captions': refined_caption,
- 'crop_save_path': crop_save_path,
- 'mask_save_path': mask_save_path,
- 'mask': seg_mask_img,
- 'context_captions': context_captions}
- return out
-
-def parse_augment():
- parser = argparse.ArgumentParser()
- parser.add_argument('--captioner', type=str, default="blip2")
- parser.add_argument('--segmenter', type=str, default="huge")
- parser.add_argument('--text_refiner', type=str, default="base")
- parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
- parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
- parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
- parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
- parser.add_argument('--disable_regular_box', action="store_true", default = False, help="crop image with a regular box")
- parser.add_argument('--device', type=str, default="cuda:0")
- parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
- parser.add_argument('--debug', action="store_true")
- parser.add_argument('--gradio_share', action="store_true")
- parser.add_argument('--disable_gpt', action="store_true")
- parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
- parser.add_argument('--disable_reuse_features', action="store_true", default=False)
- parser.add_argument('--enable_morphologyex', action="store_true", default=False)
- args = parser.parse_args()
-
- if args.debug:
- print(args)
- return args
-
-if __name__ == "__main__":
- args = parse_augment()
- # image_path = 'test_img/img3.jpg'
- image_path = 'test_img/img13.jpg'
- prompts = [
- {
- "prompt_type":["click"],
- "input_point":[[500, 300], [1000, 500]],
- "input_label":[1, 0],
- "multimask_output":"True",
- },
- {
- "prompt_type":["click"],
- "input_point":[[900, 800]],
- "input_label":[1],
- "multimask_output":"True",
- }
- ]
- controls = {
- "length": "30",
- "sentiment": "positive",
- # "imagination": "True",
- "imagination": "False",
- "language": "English",
- }
-
- model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
- for prompt in prompts:
- print('*'*30)
- print('Image path: ', image_path)
- image = Image.open(image_path)
- print(image)
- print('Visual controls (SAM prompt):\n', prompt)
- print('Language controls:\n', controls)
- out = model.inference(image_path, prompt, controls)
-
-
diff --git a/caption_anything/__init__.py b/caption_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/captioner/README.md b/caption_anything/captioner/README.md
similarity index 100%
rename from captioner/README.md
rename to caption_anything/captioner/README.md
diff --git a/captioner/__init__.py b/caption_anything/captioner/__init__.py
similarity index 100%
rename from captioner/__init__.py
rename to caption_anything/captioner/__init__.py
diff --git a/captioner/base_captioner.py b/caption_anything/captioner/base_captioner.py
similarity index 99%
rename from captioner/base_captioner.py
rename to caption_anything/captioner/base_captioner.py
index 6da6f78368af7e57bcb6ae72abebebddecb22beb..eb2f59e1849a04e489789be107236117455d6704 100644
--- a/captioner/base_captioner.py
+++ b/caption_anything/captioner/base_captioner.py
@@ -191,7 +191,7 @@ class BaseCaptioner:
if __name__ == '__main__':
model = BaseCaptioner(device='cuda:0')
- image_path = 'test_img/img2.jpg'
+ image_path = 'test_images/img2.jpg'
seg_mask = np.zeros((15,15))
seg_mask[5:10, 5:10] = 1
seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
diff --git a/captioner/blip.py b/caption_anything/captioner/blip.py
similarity index 91%
rename from captioner/blip.py
rename to caption_anything/captioner/blip.py
index f9ee55d465b3ba705742b89cd11fd5cd0a6585ee..6ff54be847a70535b75557e44b57ed5a080d0afe 100644
--- a/captioner/blip.py
+++ b/caption_anything/captioner/blip.py
@@ -54,13 +54,13 @@ class BLIPCaptioner(BaseCaptioner):
if __name__ == '__main__':
model = BLIPCaptioner(device='cuda:0')
- # image_path = 'test_img/img2.jpg'
- image_path = '/group/30042/wybertwang/project/woa_visgpt/chatARC/image/SAM/img10.jpg'
+ # image_path = 'test_images/img2.jpg'
+ image_path = 'image/SAM/img10.jpg'
seg_mask = np.zeros((15,15))
seg_mask[5:10, 5:10] = 1
- seg_mask = 'test_img/img10.jpg.raw_mask.png'
- image_path = 'test_img/img2.jpg'
- seg_mask = 'test_img/img2.jpg.raw_mask.png'
+ seg_mask = 'test_images/img10.jpg.raw_mask.png'
+ image_path = 'test_images/img2.jpg'
+ seg_mask = 'test_images/img2.jpg.raw_mask.png'
print(f'process image {image_path}')
print(model.inference_with_reduced_tokens(image_path, seg_mask))
\ No newline at end of file
diff --git a/captioner/blip2.py b/caption_anything/captioner/blip2.py
similarity index 94%
rename from captioner/blip2.py
rename to caption_anything/captioner/blip2.py
index b63cfad7eebbda7ccbe6b92eafca44e625adb2be..3ea39eb87e572b28ec5ff330916ba2e8e0d90876 100644
--- a/captioner/blip2.py
+++ b/caption_anything/captioner/blip2.py
@@ -1,13 +1,10 @@
import torch
-from PIL import Image, ImageDraw, ImageOps
-from transformers import AutoProcessor, Blip2ForConditionalGeneration
-import json
-import pdb
-import cv2
+from PIL import Image
import numpy as np
from typing import Union
+from transformers import AutoProcessor, Blip2ForConditionalGeneration
-from tools import is_platform_win
+from caption_anything.utils.utils import is_platform_win
from .base_captioner import BaseCaptioner
class BLIP2Captioner(BaseCaptioner):
@@ -55,7 +52,7 @@ if __name__ == '__main__':
dialogue = False
model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
- image_path = 'test_img/img2.jpg'
+ image_path = 'test_images/img2.jpg'
seg_mask = np.zeros((224,224))
seg_mask[50:200, 50:200] = 1
print(f'process image {image_path}')
diff --git a/captioner/git.py b/caption_anything/captioner/git.py
similarity index 98%
rename from captioner/git.py
rename to caption_anything/captioner/git.py
index fd377378dc5fbeef20d40e64ff686ad62b9de292..9e0079038f534d36ddc833605a5823c38422b6cc 100644
--- a/captioner/git.py
+++ b/caption_anything/captioner/git.py
@@ -50,7 +50,7 @@ class GITCaptioner(BaseCaptioner):
if __name__ == '__main__':
model = GITCaptioner(device='cuda:2', enable_filter=False)
- image_path = 'test_img/img2.jpg'
+ image_path = 'test_images/img2.jpg'
seg_mask = np.zeros((224,224))
seg_mask[50:200, 50:200] = 1
print(f'process image {image_path}')
diff --git a/captioner/modeling_blip.py b/caption_anything/captioner/modeling_blip.py
similarity index 100%
rename from captioner/modeling_blip.py
rename to caption_anything/captioner/modeling_blip.py
diff --git a/captioner/modeling_git.py b/caption_anything/captioner/modeling_git.py
similarity index 100%
rename from captioner/modeling_git.py
rename to caption_anything/captioner/modeling_git.py
diff --git a/captioner/vit_pixel_masks_utils.py b/caption_anything/captioner/vit_pixel_masks_utils.py
similarity index 100%
rename from captioner/vit_pixel_masks_utils.py
rename to caption_anything/captioner/vit_pixel_masks_utils.py
diff --git a/caption_anything/model.py b/caption_anything/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b349229a1de99f939fde07dc3a5fdc8a31520b
--- /dev/null
+++ b/caption_anything/model.py
@@ -0,0 +1,147 @@
+import os
+import argparse
+import pdb
+import time
+from PIL import Image
+import cv2
+import numpy as np
+from caption_anything.captioner import build_captioner, BaseCaptioner
+from caption_anything.segmenter import build_segmenter
+from caption_anything.text_refiner import build_text_refiner
+
+
+class CaptionAnything:
+ def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
+ self.args = args
+ self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
+ self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
+
+ self.text_refiner = None
+ if not args.disable_gpt:
+ if text_refiner is not None:
+ self.text_refiner = text_refiner
+ else:
+ self.init_refiner(api_key)
+
+ @property
+ def image_embedding(self):
+ return self.segmenter.image_embedding
+
+ @image_embedding.setter
+ def image_embedding(self, image_embedding):
+ self.segmenter.image_embedding = image_embedding
+
+ @property
+ def original_size(self):
+ return self.segmenter.predictor.original_size
+
+ @original_size.setter
+ def original_size(self, original_size):
+ self.segmenter.predictor.original_size = original_size
+
+ @property
+ def input_size(self):
+ return self.segmenter.predictor.input_size
+
+ @input_size.setter
+ def input_size(self, input_size):
+ self.segmenter.predictor.input_size = input_size
+
+ def setup(self, image_embedding, original_size, input_size, is_image_set):
+ self.image_embedding = image_embedding
+ self.original_size = original_size
+ self.input_size = input_size
+ self.segmenter.predictor.is_image_set = is_image_set
+
+ def init_refiner(self, api_key):
+ try:
+ self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
+ self.text_refiner.llm('hi') # test
+ except Exception as e:
+ self.text_refiner = None
+ print(f'OpenAI GPT is not available: {e}')
+
+ def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
+ # segment with prompt
+ print("CA prompt: ", prompt, "CA controls", controls)
+ seg_mask = self.segmenter.inference(image, prompt)[0, ...]
+ if self.args.enable_morphologyex:
+ seg_mask = 255 * seg_mask.astype(np.uint8)
+ seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
+ seg_mask = seg_mask[:, :, 0] > 0
+ mask_save_path = f'result/mask_{time.time()}.png'
+ if not os.path.exists(os.path.dirname(mask_save_path)):
+ os.makedirs(os.path.dirname(mask_save_path))
+ seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
+ if seg_mask_img.mode != 'RGB':
+ seg_mask_img = seg_mask_img.convert('RGB')
+ seg_mask_img.save(mask_save_path)
+ print('seg_mask path: ', mask_save_path)
+ print("seg_mask.shape: ", seg_mask.shape)
+ # captioning with mask
+ if self.args.enable_reduce_tokens:
+ caption, crop_save_path = self.captioner. \
+ inference_with_reduced_tokens(image, seg_mask,
+ crop_mode=self.args.seg_crop_mode,
+ filter=self.args.clip_filter,
+ disable_regular_box=self.args.disable_regular_box)
+ else:
+ caption, crop_save_path = self.captioner. \
+ inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
+ filter=self.args.clip_filter,
+ disable_regular_box=self.args.disable_regular_box)
+ # refining with TextRefiner
+ context_captions = []
+ if self.args.context_captions:
+ context_captions.append(self.captioner.inference(image))
+ if not disable_gpt and self.text_refiner is not None:
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
+ enable_wiki=enable_wiki)
+ else:
+ refined_caption = {'raw_caption': caption}
+ out = {'generated_captions': refined_caption,
+ 'crop_save_path': crop_save_path,
+ 'mask_save_path': mask_save_path,
+ 'mask': seg_mask_img,
+ 'context_captions': context_captions}
+ return out
+
+
+if __name__ == "__main__":
+ from caption_anything.utils.parser import parse_augment
+ args = parse_augment()
+ # image_path = 'test_images/img3.jpg'
+ image_path = 'test_images/img1.jpg'
+ prompts = [
+ {
+ "prompt_type": ["click"],
+ "input_point": [[500, 300], [200, 500]],
+ "input_label": [1, 0],
+ "multimask_output": "True",
+ },
+ {
+ "prompt_type": ["click"],
+ "input_point": [[300, 800]],
+ "input_label": [1],
+ "multimask_output": "True",
+ }
+ ]
+ controls = {
+ "length": "30",
+ "sentiment": "positive",
+ # "imagination": "True",
+ "imagination": "False",
+ "language": "English",
+ }
+
+ model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
+ for prompt in prompts:
+ print('*' * 30)
+ print('Image path: ', image_path)
+ image = Image.open(image_path)
+ print(image)
+ print('Visual controls (SAM prompt):\n', prompt)
+ print('Language controls:\n', controls)
+ out = model.inference(image_path, prompt, controls)
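
The `image_embedding`, `original_size`, and `input_size` properties together with `setup()` let the caller compute the SAM image encoding once and hand it back on every subsequent request instead of re-encoding the image. A rough usage sketch under that assumption; the image path is a placeholder, and running it requires the SAM and captioner weights (and a GPU for the default `cuda:0` device):

from caption_anything.model import CaptionAnything
from caption_anything.utils.parser import parse_augment

args = parse_augment()
model = CaptionAnything(args)                      # no API key: the text refiner falls back to None
model.segmenter.set_image("test_images/img2.jpg")  # encode the image once with SAM

# stash the expensive per-image state exposed by the properties above ...
cached = (model.image_embedding, model.original_size, model.input_size)

# ... and restore it later (e.g. for the next click on the same image) without re-encoding
model.setup(*cached, is_image_set=True)
prompt = {"prompt_type": ["click"], "input_point": [[100, 200]],
          "input_label": [1], "multimask_output": "True"}
controls = {"length": "30", "sentiment": "positive",
            "imagination": "False", "language": "English"}
out = model.inference("test_images/img2.jpg", prompt, controls, disable_gpt=True)
print(out["generated_captions"]["raw_caption"])
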
diff --git a/caption_anything/segmenter/__init__.py b/caption_anything/segmenter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..067d43f7fdd09fa8a53cc7fa7e468d682bca9707
--- /dev/null
+++ b/caption_anything/segmenter/__init__.py
@@ -0,0 +1,5 @@
+from .base_segmenter import BaseSegmenter
+from caption_anything.utils.utils import seg_model_map
+
+def build_segmenter(model_name, device, args=None, model=None):
+ return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model)
\ No newline at end of file
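
`build_segmenter` only reads `segmenter_checkpoint` and `disable_reuse_features` from `args`, so a bare `Namespace` suffices when it is used outside `parse_augment()`. A small sketch, assuming `checkpoint=None` makes `prepare_segmenter` resolve (and, if needed, download) the SAM weights for the chosen model size, as `BaseSegmenter.__init__` in the next file suggests:

from argparse import Namespace
from caption_anything.segmenter import build_segmenter

args = Namespace(segmenter_checkpoint=None,      # None: let prepare_segmenter fetch the SAM weights
                 disable_reuse_features=False)   # keep the image-embedding cache enabled
segmenter = build_segmenter("huge", device="cuda:0", args=args)
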
diff --git a/segmenter/base_segmenter.py b/caption_anything/segmenter/base_segmenter.py
similarity index 62%
rename from segmenter/base_segmenter.py
rename to caption_anything/segmenter/base_segmenter.py
index aebd3ef067c5c1876282602e6004d0f964afa69f..6a5cdcbb17822c9bc55016241ba1c5b02614ffdb 100644
--- a/segmenter/base_segmenter.py
+++ b/caption_anything/segmenter/base_segmenter.py
@@ -5,19 +5,22 @@ from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+from caption_anything.utils.utils import prepare_segmenter, seg_model_map
import matplotlib.pyplot as plt
import PIL
+
class BaseSegmenter:
- def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True, model=None):
+ def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None):
print(f"Initializing BaseSegmenter to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.processor = None
- self.model_type = model_type
if model is None:
+ if checkpoint is None:
+ _, checkpoint = prepare_segmenter(model_name)
+ self.model = sam_model_registry[seg_model_map[model_name]](checkpoint=checkpoint)
self.checkpoint = checkpoint
- self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
self.model.to(device=self.device)
else:
self.model = model
@@ -27,26 +30,57 @@ class BaseSegmenter:
self.image_embedding = None
self.image = None
-
- @torch.no_grad()
- def set_image(self, image: Union[np.ndarray, Image.Image, str]):
- if type(image) == str: # input path
+ def read_image(self, image: Union[np.ndarray, Image.Image, str]):
+ if type(image) == str: # input path
image = Image.open(image)
image = np.array(image)
elif type(image) == Image.Image:
image = np.array(image)
+ elif type(image) == np.ndarray:
+ image = image
+ else:
+ raise TypeError
+ return image
+
+ @torch.no_grad()
+ def set_image(self, image: Union[np.ndarray, Image.Image, str]):
+ image = self.read_image(image)
self.image = image
if self.reuse_feature:
self.predictor.set_image(image)
self.image_embedding = self.predictor.get_image_embedding()
print(self.image_embedding.shape)
-
@torch.no_grad()
- def inference(self, image, control):
+ def inference(self, image: Union[np.ndarray, Image.Image, str], control: dict):
+ """
+ SAM inference of image according to control.
+ Args:
+ image: str or PIL.Image or np.ndarray
+ control:
+ prompt_type:
+ 1. {control['prompt_type'] = ['everything']} to segment everything in the image.
+ 2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
+ 3. {control['prompt_type'] = ['click']} to segment according to click.
+ 4. {control['prompt_type'] = ['box']} to segment according to box.
+ input_point: list of [x, y] coordinates of click.
+ input_label: list of labels for the corresponding points, 0 for negative, 1 for positive.
+ input_box: List of [x1, y1, x2, y2] coordinates of box.
+ multimask_output:
+ If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ Returns:
+ masks: np.ndarray of shape [num_masks, height, width]
+
+ """
+ image = self.read_image(image) # Turn image into np.ndarray
if 'everything' in control['prompt_type']:
masks = self.mask_generator.generate(image)
- new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
+ new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
return new_masks
else:
if not self.reuse_feature or self.image_embedding is None:
@@ -55,17 +89,17 @@ class BaseSegmenter:
else:
assert self.image_embedding is not None
self.predictor.features = self.image_embedding
-
+
if 'mutimask_output' in control:
masks, scores, logits = self.predictor.predict(
- point_coords = np.array(control['input_point']),
- point_labels = np.array(control['input_label']),
- multimask_output = True,
+ point_coords=np.array(control['input_point']),
+ point_labels=np.array(control['input_label']),
+ multimask_output=True,
)
elif 'input_boxes' in control:
transformed_boxes = self.predictor.transform.apply_boxes_torch(
torch.tensor(control["input_boxes"], device=self.predictor.device),
- image.shape[:2]
+ image.shape[1::-1] # Reverse the shape because numpy gives (W, H) and the function needs (H, W)
)
masks, _, _ = self.predictor.predict_torch(
point_coords=None,
@@ -74,31 +108,32 @@ class BaseSegmenter:
multimask_output=False,
)
masks = masks.squeeze(1).cpu().numpy()
-
+
else:
input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
-
+
masks, scores, logits = self.predictor.predict(
- point_coords = input_point,
- point_labels = input_label,
- box = input_box,
- multimask_output = False,
+ point_coords=input_point,
+ point_labels=input_label,
+ box=input_box,
+ multimask_output=False,
)
-
+
if 0 in control['input_label']:
mask_input = logits[np.argmax(scores), :, :]
masks, scores, logits = self.predictor.predict(
point_coords=input_point,
point_labels=input_label,
- box = input_box,
+ box=input_box,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
-
+
return masks
+
if __name__ == "__main__":
image_path = 'segmenter/images/truck.jpg'
prompts = [
@@ -109,9 +144,9 @@ if __name__ == "__main__":
# "multimask_output":"True",
# },
{
- "prompt_type":["click"],
- "input_point":[[1000, 600], [1325, 625]],
- "input_label":[1, 0],
+ "prompt_type": ["click"],
+ "input_point": [[1000, 600], [1325, 625]],
+ "input_label": [1, 0],
},
# {
# "prompt_type":["click", "box"],
@@ -132,7 +167,7 @@ if __name__ == "__main__":
# "prompt_type":["everything"]
# },
]
-
+
init_time = time.time()
segmenter = BaseSegmenter(
device='cuda',
@@ -142,8 +177,8 @@ if __name__ == "__main__":
reuse_feature=True
)
print(f'init time: {time.time() - init_time}')
-
- image_path = 'test_img/img2.jpg'
+
+ image_path = 'test_images/img2.jpg'
infer_time = time.time()
for i, prompt in enumerate(prompts):
print(f'{prompt["prompt_type"]} mode')
@@ -152,5 +187,5 @@ if __name__ == "__main__":
masks = segmenter.inference(np.array(image), prompt)
Image.fromarray(masks[0]).save('seg.png')
print(masks.shape)
-
+
print(f'infer time: {time.time() - infer_time}')
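
The docstring above defines the `control` dictionaries that `inference` accepts; a short sketch of the click and box variants (the coordinates and image path are placeholders, and a CUDA device plus the SAM checkpoint are assumed to be available):

from caption_anything.segmenter.base_segmenter import BaseSegmenter

segmenter = BaseSegmenter(device="cuda:0", checkpoint=None, model_name="huge")

click_control = {
    "prompt_type": ["click"],
    "input_point": [[1000, 600], [1325, 625]],   # (x, y) pixel coordinates
    "input_label": [1, 0],                       # 1 = positive click, 0 = negative
    "multimask_output": "False",
}
box_control = {
    "prompt_type": ["box"],
    "input_box": [425, 600, 700, 875],           # [x1, y1, x2, y2]
    "input_label": [],                           # no point labels for a pure box prompt
    "multimask_output": "False",
}

masks = segmenter.inference("test_images/img2.jpg", click_control)
print(masks.shape)                               # (num_masks, height, width)
masks = segmenter.inference("test_images/img2.jpg", box_control)
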
diff --git a/segmenter/readme.md b/caption_anything/segmenter/readme.md
similarity index 100%
rename from segmenter/readme.md
rename to caption_anything/segmenter/readme.md
diff --git a/text_refiner/README.md b/caption_anything/text_refiner/README.md
similarity index 100%
rename from text_refiner/README.md
rename to caption_anything/text_refiner/README.md
diff --git a/text_refiner/__init__.py b/caption_anything/text_refiner/__init__.py
similarity index 72%
rename from text_refiner/__init__.py
rename to caption_anything/text_refiner/__init__.py
index 99ed26932b683fe6eccfa232c6d6131d7432f816..853c07b880c00c5336b9f1e4e3c1f5e8d4c2ca74 100644
--- a/text_refiner/__init__.py
+++ b/caption_anything/text_refiner/__init__.py
@@ -1,4 +1,4 @@
-from text_refiner.text_refiner import TextRefiner
+from .text_refiner import TextRefiner
def build_text_refiner(type, device, args=None, api_key=""):
diff --git a/text_refiner/text_refiner.py b/caption_anything/text_refiner/text_refiner.py
similarity index 100%
rename from text_refiner/text_refiner.py
rename to caption_anything/text_refiner/text_refiner.py
diff --git a/caption_anything/utils/chatbot.py b/caption_anything/utils/chatbot.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8cb6069ef52f943dc593d6900aa5d2698f1b16
--- /dev/null
+++ b/caption_anything/utils/chatbot.py
@@ -0,0 +1,236 @@
+# Copyright (c) Microsoft
+# Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
+
+import os
+import gradio as gr
+import re
+import uuid
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+import argparse
+import inspect
+
+from langchain.agents.initialize import initialize_agent
+from langchain.agents.tools import Tool
+from langchain.chains.conversation.memory import ConversationBufferMemory
+from langchain.llms.openai import OpenAI
+import torch
+from PIL import Image, ImageDraw, ImageOps
+from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
+
+VISUAL_CHATGPT_PREFIX = """
+ Caption Anything Chatbox (short as CATchat) is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. CATchat is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
+
+ As a language model, CATchat can not directly read images, but it has a list of tools to finish different visual tasks. CATchat can invoke different tools to indirectly understand pictures.
+
+ CATchat has access to the following tools:"""
+
+
+# VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
+
+# Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "chat_image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name.
+
+# Visual ChatGPT is aware of the coordinate of an object in the image, which is represented as a point (X, Y) on the object. Note that (0, 0) represents the bottom-left corner of the image.
+
+# Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
+
+# Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
+
+
+# TOOLS:
+# ------
+
+# Visual ChatGPT has access to the following tools:"""
+
+VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
+
+"Thought: Do I need to use a tool? Yes
+Action: the action to take, should be one of [{tool_names}]; remember that the action must be exactly one tool
+Action Input: the input to the action
+Observation: the result of the action"
+
+When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
+
+"Thought: Do I need to use a tool? No
+{ai_prefix}: [your response here]"
+
+"""
+
+VISUAL_CHATGPT_SUFFIX = """
+Begin Chatting!
+
+Previous conversation history:
+{chat_history}
+
+New input: {input}
+Since CATchat is a text language model, CATchat must use tools iteratively to observe images rather than imagining them.
+The thoughts and observations are visible only to CATchat; CATchat should remember to repeat important information in the final response for the Human.
+
+Thought: Do I need to use a tool? {agent_scratchpad} (You must strictly use the aforementioned "Thought/Action/Action Input/Observation" format in your answer.)
+
+os.makedirs('chat_image', exist_ok=True)
+
+
+def prompts(name, description):
+ def decorator(func):
+ func.name = name
+ func.description = description
+ return func
+ return decorator
+
+def cut_dialogue_history(history_memory, keep_last_n_words=500):
+ if history_memory is None or len(history_memory) == 0:
+ return history_memory
+ tokens = history_memory.split()
+ n_tokens = len(tokens)
+ print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
+ if n_tokens < keep_last_n_words:
+ return history_memory
+ paragraphs = history_memory.split('\n')
+ last_n_tokens = n_tokens
+ while last_n_tokens >= keep_last_n_words:
+ last_n_tokens -= len(paragraphs[0].split(' '))
+ paragraphs = paragraphs[1:]
+ return '\n' + '\n'.join(paragraphs)
+
+def get_new_image_name(folder='chat_image', func_name="update"):
+ this_new_uuid = str(uuid.uuid4())[:8]
+ new_file_name = f'{func_name}_{this_new_uuid}.png'
+ return os.path.join(folder, new_file_name)
+
+class VisualQuestionAnswering:
+ def __init__(self, device):
+ print(f"Initializing VisualQuestionAnswering to {device}")
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+ self.device = device
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+ self.model = BlipForQuestionAnswering.from_pretrained(
+ "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
+ # self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
+ # self.model = BlipForQuestionAnswering.from_pretrained(
+ # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
+
+ @prompts(name="Answer Question About The Image",
+ description="useful when you need an answer for a question based on an image. "
+ "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
+ "The input to this tool should be a comma separated string of two, representing the image_path and the question")
+ def inference(self, inputs):
+ image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+ raw_image = Image.open(image_path).convert('RGB')
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
+ out = self.model.generate(**inputs)
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
+ f"Output Answer: {answer}")
+ return answer
+
+def build_chatbot_tools(load_dict):
+ print(f"Initializing ChatBot, load_dict={load_dict}")
+ models = {}
+ # Load Basic Foundation Models
+ for class_name, device in load_dict.items():
+ models[class_name] = globals()[class_name](device=device)
+
+ # Load Template Foundation Models
+ for class_name, module in globals().items():
+ if getattr(module, 'template_model', False):
+ template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
+ loaded_names = set([type(e).__name__ for e in models.values()])
+ if template_required_names.issubset(loaded_names):
+ models[class_name] = globals()[class_name](
+ **{name: models[name] for name in template_required_names})
+
+ tools = []
+ for instance in models.values():
+ for e in dir(instance):
+ if e.startswith('inference'):
+ func = getattr(instance, e)
+ tools.append(Tool(name=func.name, description=func.description, func=func))
+ return tools
+
+class ConversationBot:
+ def __init__(self, tools, api_key=""):
+ # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
+ llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
+ self.llm = llm
+ self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
+ self.tools = tools
+ self.current_image = None
+ self.point_prompt = ""
+ self.agent = initialize_agent(
+ self.tools,
+ self.llm,
+ agent="conversational-react-description",
+ verbose=True,
+ memory=self.memory,
+ return_intermediate_steps=True,
+ agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
+ 'suffix': VISUAL_CHATGPT_SUFFIX}, )
+
+ def constructe_intermediate_steps(self, agent_res):
+ ans = []
+ for action, output in agent_res:
+ if hasattr(action, "tool_input"):
+ use_tool = "Yes"
+ act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
+ else:
+ use_tool = "No"
+ act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
+ act= list(map(lambda x: x.replace('\n', ' '), act))
+ ans.append(act)
+ return ans
+
+ def run_text(self, text, state, aux_state):
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
+ if self.point_prompt != "":
+ Human_prompt = f'\nHuman: {self.point_prompt}\n'
+ AI_prompt = 'Ok'
+ self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
+ self.point_prompt = ""
+ res = self.agent({"input": text})
+ res['output'] = res['output'].replace("\\", "/")
+ response = re.sub('(chat_image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
+ state = state + [(text, response)]
+
+ aux_state = aux_state + [(f"User Input: {text}", None)]
+ aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
+ print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
+ f"Current Memory: {self.agent.memory.buffer}\n"
+ f"Aux state: {aux_state}\n"
+ )
+ return state, state, aux_state, aux_state
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
+ parser.add_argument('--port', type=int, default=1015)
+
+ args = parser.parse_args()
+ load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
+ tools = build_chatbot_tools(load_dict)
+ bot = ConversationBot(tools)
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
+ with gr.Row():
+ chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT").style(height=1000,scale=0.5)
+ auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
+ state = gr.State([])
+ aux_state = gr.State([])
+ with gr.Row():
+ with gr.Column(scale=0.7):
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
+ container=False)
+ with gr.Column(scale=0.15, min_width=0):
+ clear = gr.Button("Clear")
+ with gr.Column(scale=0.15, min_width=0):
+ btn = gr.UploadButton("Upload", file_types=["image"])
+
+ txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
+ txt.submit(lambda: "", None, txt)
+ btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
+ clear.click(bot.memory.clear)
+ clear.click(lambda: [], None, chatbot)
+ clear.click(lambda: [], None, auxwindow)
+ clear.click(lambda: [], None, state)
+ clear.click(lambda: [], None, aux_state)
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)
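
`build_chatbot_tools` turns every `@prompts`-decorated `inference*` method it finds into a LangChain `Tool`, and `ConversationBot` wraps those tools in a conversational agent. A minimal sketch of that flow outside the demo UI; the OpenAI key is a placeholder, and the BLIP-VQA weights plus a valid OpenAI account are assumed:

from caption_anything.utils.chatbot import build_chatbot_tools, ConversationBot

tools = build_chatbot_tools({"VisualQuestionAnswering": "cuda:0"})
bot = ConversationBot(tools, api_key="sk-...")   # placeholder OpenAI key

# optional: prime the agent with the object the user last clicked on
bot.point_prompt = "The user clicked on a dog at point (120, 340)."

state, _, aux_state, _ = bot.run_text("What did the user select?",
                                      state=[], aux_state=[])
print(state[-1][1])        # the agent's reply
print(aux_state)           # the Thought/Action/Observation trace
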
diff --git a/image_editing_utils.py b/caption_anything/utils/image_editing_utils.py
similarity index 85%
rename from image_editing_utils.py
rename to caption_anything/utils/image_editing_utils.py
index 31adc4bd7939b2133908ccb362b7e295ebc9ee24..d5806447a94532f8ce981a92d02a2c181569ed47 100644
--- a/image_editing_utils.py
+++ b/caption_anything/utils/image_editing_utils.py
@@ -1,7 +1,8 @@
from PIL import Image, ImageDraw, ImageFont
import copy
import numpy as np
-import cv2
+import cv2
+
def wrap_text(text, font, max_width):
lines = []
@@ -18,11 +19,18 @@ def wrap_text(text, font, max_width):
lines.append(current_line)
return lines
-def create_bubble_frame(image, text, point, segmask, input_points, input_labels, font_path='times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
+
+def create_bubble_frame(image, text, point, segmask, input_points=(), input_labels=(),
+ font_path='assets/times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
# Load the image
+ if input_points is None:
+ input_points = []
+ if input_labels is None:
+ input_labels = []
+
if type(image) == np.ndarray:
image = Image.fromarray(image)
-
+
image = copy.deepcopy(image)
width, height = image.size
@@ -47,19 +55,19 @@ def create_bubble_frame(image, text, point, segmask, input_points, input_labels,
bubble_height = text_height + 2 * padding
# Create a new image for the bubble frame
- bubble = Image.new('RGBA', (bubble_width, bubble_height), (255,248, 220, 0))
+ bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 248, 220, 0))
# Draw the bubble frame on the new image
draw = ImageDraw.Draw(bubble)
# draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
- draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
- fill=(255,248, 220, 120), outline=None, width=2)
+ draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
+ fill=(255, 248, 220, 120), outline=None, width=2)
# Draw the wrapped text line by line
y_text = padding
for line in lines:
draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
y_text += font.getsize(line)[1]
-
+
# Determine the point by the min area rect of mask
try:
ret, thresh = cv2.threshold(segmask, 127, 255, 0)
@@ -109,7 +117,11 @@ def draw_rounded_rectangle(draw, xy, corner_radius, fill=None, outline=None, wid
width=width
)
- draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline, width=width)
- draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline, width=width)
- draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline, width=width)
- draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline, width=width)
+ draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline,
+ width=width)
+ draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline,
+ width=width)
+ draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline,
+ width=width)
+ draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline,
+ width=width)
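
With the new default arguments, `create_bubble_frame` can be called with just an image, a caption, a click point, and a mask. A small sketch with synthetic inputs; it assumes the bundled font at `assets/times_with_simsun.ttf` is available and that the function returns the composited PIL image (its return statement falls outside this hunk):

import numpy as np
from PIL import Image
from caption_anything.utils.image_editing_utils import create_bubble_frame

image = Image.new("RGB", (640, 480), "white")          # stand-in for the user's photo
mask = np.zeros((480, 640), dtype=np.uint8)
mask[200:320, 250:420] = 255                           # stand-in segmentation mask

framed = create_bubble_frame(image, "a cardboard box on a table",
                             point=(300, 250), segmask=mask)
framed.save("bubble_demo.png")
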
diff --git a/caption_anything/utils/parser.py b/caption_anything/utils/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c1ededcb153e672c78c2ad74174084099cf958
--- /dev/null
+++ b/caption_anything/utils/parser.py
@@ -0,0 +1,29 @@
+import argparse
+
+def parse_augment():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--captioner', type=str, default="blip2")
+ parser.add_argument('--segmenter', type=str, default="huge")
+ parser.add_argument('--text_refiner', type=str, default="base")
+ parser.add_argument('--segmenter_checkpoint', type=str, default=None, help="SAM checkpoint path")
+ parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'],
+ help="whether to add or remove background of the image when captioning")
+ parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
+ parser.add_argument('--context_captions', action="store_true",
+ help="use surrounding captions to enhance current caption (TODO)")
+ parser.add_argument('--disable_regular_box', action="store_true", default=False,
+ help="crop image with a regular box")
+ parser.add_argument('--device', type=str, default="cuda:0")
+ parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
+ parser.add_argument('--debug', action="store_true")
+ parser.add_argument('--gradio_share', action="store_true")
+ parser.add_argument('--disable_gpt', action="store_true")
+ parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
+ parser.add_argument('--disable_reuse_features', action="store_true", default=False)
+ parser.add_argument('--enable_morphologyex', action="store_true", default=False)
+ parser.add_argument('--chat_tools_dict', type=str, default='VisualQuestionAnswering_cuda:0', help='Visual ChatGPT tools, only useful when running gradio applications')
+ args = parser.parse_args()
+
+ if args.debug:
+ print(args)
+ return args
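
These are the flags that `app.py` reads at launch time (`args.port`, `args.gradio_share`) and that `CaptionAnything` consults during inference. A quick sketch of the parser defaults and of overriding them on the command line (the invocation in the comment is illustrative):

# python app.py --captioner blip2 --segmenter huge --port 6086 --gradio_share
from caption_anything.utils.parser import parse_augment

args = parse_augment()
print(args.captioner, args.segmenter)   # defaults: 'blip2', 'huge'
print(args.port, args.gradio_share)     # defaults: 6086, False
print(args.chat_tools_dict)             # default: 'VisualQuestionAnswering_cuda:0'
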
diff --git a/caption_anything/utils/utils.py b/caption_anything/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1dde8a6b9de3a84cc93e741a2dd20ad7dccd173
--- /dev/null
+++ b/caption_anything/utils/utils.py
@@ -0,0 +1,419 @@
+import os
+import cv2
+import requests
+import numpy as np
+from PIL import Image
+import time
+import sys
+import urllib
+from tqdm import tqdm
+import hashlib
+
+def is_platform_win():
+ return sys.platform == "win32"
+
+
+def colormap(rgb=True):
+ color_list = np.array(
+ [
+ 0.000, 0.000, 0.000,
+ 1.000, 1.000, 1.000,
+ 1.000, 0.498, 0.313,
+ 0.392, 0.581, 0.929,
+ 0.000, 0.447, 0.741,
+ 0.850, 0.325, 0.098,
+ 0.929, 0.694, 0.125,
+ 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188,
+ 0.301, 0.745, 0.933,
+ 0.635, 0.078, 0.184,
+ 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600,
+ 1.000, 0.000, 0.000,
+ 1.000, 0.500, 0.000,
+ 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000,
+ 0.333, 1.000, 0.000,
+ 0.667, 0.333, 0.000,
+ 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000,
+ 1.000, 0.333, 0.000,
+ 1.000, 0.667, 0.000,
+ 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500,
+ 0.000, 0.667, 0.500,
+ 0.000, 1.000, 0.500,
+ 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500,
+ 0.333, 0.667, 0.500,
+ 0.333, 1.000, 0.500,
+ 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500,
+ 0.667, 0.667, 0.500,
+ 0.667, 1.000, 0.500,
+ 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500,
+ 1.000, 0.667, 0.500,
+ 1.000, 1.000, 0.500,
+ 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000,
+ 0.000, 1.000, 1.000,
+ 0.333, 0.000, 1.000,
+ 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000,
+ 0.333, 1.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000,
+ 0.667, 1.000, 1.000,
+ 1.000, 0.000, 1.000,
+ 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000,
+ 0.167, 0.000, 0.000,
+ 0.333, 0.000, 0.000,
+ 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000,
+ 0.833, 0.000, 0.000,
+ 1.000, 0.000, 0.000,
+ 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000,
+ 0.000, 0.500, 0.000,
+ 0.000, 0.667, 0.000,
+ 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 0.167,
+ 0.000, 0.000, 0.333,
+ 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667,
+ 0.000, 0.000, 0.833,
+ 0.000, 0.000, 1.000,
+ 0.143, 0.143, 0.143,
+ 0.286, 0.286, 0.286,
+ 0.429, 0.429, 0.429,
+ 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714,
+ 0.857, 0.857, 0.857
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3)) * 255
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, mask, color, alpha, kernel_size):
+ color = np.array(color)
+ mask = mask.astype('float').copy()
+ mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
+ for i in range(3):
+ image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
+ return image
+
+
+def vis_add_mask_wo_blur(image, mask, color, alpha):
+ color = np.array(color)
+ mask = mask.astype('float').copy()
+ for i in range(3):
+ image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
+ return image
+
+
+def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
+ background_color = np.array(background_color)
+ contour_color = np.array(contour_color)
+
+ # background_mask = 1 - background_mask
+ # contour_mask = 1 - contour_mask
+
+ for i in range(3):
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
+
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
+
+ return image.astype('uint8')
+
+
+def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
+ """
+ add color mask to the background/foreground area
+ input_image: numpy array (w, h, C)
+ input_mask: numpy array (w, h)
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+ background_blur_radius: radius of background blur, must be odd number
+ contour_width: width of mask contour, must be odd number
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+ background_color: color index of the background (area with input_mask == False)
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+ paint_foreground: True to paint the foreground, False to paint the background. Default: False
+
+ Output:
+ painted_image: numpy array
+ """
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
+
+ # 0: background, 1: foreground
+ input_mask[input_mask>0] = 255
+ if paint_foreground:
+ painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
+ else:
+ # mask background
+ painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
+ # mask contour
+ contour_mask = input_mask.copy()
+ contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
+ # widen the contour
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
+ contour_mask = cv2.dilate(contour_mask, kernel)
+ painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
+ return painted_image
+
+
+def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
+ """
+ paint color masks on all foreground areas
+ input_image: numpy array with shape (w, h, C)
+ input_masks: list of masks, each a numpy array with shape (w, h)
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+ background_blur_radius: radius of background blur, must be odd number
+ contour_width: width of mask contour, must be odd number
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+ (background_color is chosen automatically per mask: color map index 2, 3, ... for the 1st, 2nd, ... mask)
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+
+ Output:
+ painted_image: numpy array
+ """
+
+ for i, input_mask in enumerate(input_masks):
+ input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
+ return input_image
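+
+# Illustrative call with several masks (hypothetical variables): every mask is painted
+# on the foreground with its own color, taken from the color map starting at index 2.
+#
+# masks = [person_mask, dog_mask] # each (H, W), non-zero = that object
+# painted = mask_painter_foreground_all(image.copy(), masks, background_alpha=0.6)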
+
+def mask_generator_00(mask, background_radius, contour_radius):
+ # no background width when '00'
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ return mask, contour_mask
+
+
+def mask_generator_01(mask, background_radius, contour_radius):
+ # no background width when '01'
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ return mask, contour_mask
+
+
+def mask_generator_10(mask, background_radius, contour_radius):
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # .....:::::!!!!!
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
+ background_mask = (background_mask - np.min(background_mask))
+ background_mask = background_mask / np.max(background_mask)
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+ return background_mask, contour_mask
+
+
+def mask_generator_11(mask, background_radius, contour_radius):
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # .....:::::!!!!!
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
+ background_mask = (background_mask - np.min(background_mask))
+ background_mask = background_mask / np.max(background_mask)
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ return background_mask, contour_mask
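+
+# The four mask_generator_* helpers above share the same signed distance map (positive
+# inside the mask, negative outside) and differ only in how hard the edges are: the
+# first mode digit picks the background map (0 = raw mask, 1 = a normalized band of
+# roughly 2*background_radius around the edge) and the second picks the contour map
+# (0 = thresholded at 0.5 for a harder edge, 1 = left soft).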
+
+
+def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
+ """
+ Input:
+ input_image: numpy array
+ input_mask: numpy array
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+ background_blur_radius: radius of background blur, must be odd number
+ contour_width: width of mask contour, must be odd number
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
+
+ Output:
+ painted_image: numpy array
+ """
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
+ assert background_blur_radius % 2 == 1 and contour_width % 2 == 1, 'background_blur_radius and contour_width must be ODD'
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
+
+ # downsample input image and mask
+ height, width = input_image.shape[0], input_image.shape[1]
+ res = 1024
+ ratio = min(1.0 * res / max(width, height), 1.0)
+ input_image = cv2.resize(input_image, (int(width*ratio), int(height*ratio)))
+ input_mask = cv2.resize(input_mask, (int(width*ratio), int(height*ratio)))
+
+ # 0: background, 1: foreground
+ msk = np.clip(input_mask, 0, 1)
+
+ # generate masks for background and contour pixels
+ background_radius = (background_blur_radius - 1) // 2
+ contour_radius = (contour_width - 1) // 2
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
+
+ # paint
+ painted_image = vis_add_mask_wo_gaussian \
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
+
+ return painted_image
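+
+# Illustrative usage (hypothetical arrays; the __main__ block below runs a timing
+# comparison of all four modes against the Gaussian-based mask_painter):
+#
+# painted = mask_painter_wo_gaussian(image.copy(), mask, background_alpha=0.5,
+# contour_width=3, contour_color=3, mode='10')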
+
+
+if __name__ == '__main__':
+
+ background_alpha = 0.7 # transparency of background, 1: all black, 0: do nothing
+ background_blur_radius = 31 # radius of background blur, must be odd number
+ contour_width = 11 # contour width, must be odd number
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
+ contour_alpha = 1 # transparency of contour, 0: no contour highlighted
+
+ # load input image and mask
+ input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
+ input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))
+
+ # paint
+ overall_time_1 = 0
+ overall_time_2 = 0
+ overall_time_3 = 0
+ overall_time_4 = 0
+ overall_time_5 = 0
+
+ for i in range(50):
+ t2 = time.time()
+ painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
+ e2 = time.time()
+
+ t3 = time.time()
+ painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
+ e3 = time.time()
+
+ t1 = time.time()
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
+ e1 = time.time()
+
+ t4 = time.time()
+ painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
+ e4 = time.time()
+
+ t5 = time.time()
+ painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
+ e5 = time.time()
+
+ overall_time_1 += (e1 - t1)
+ overall_time_2 += (e2 - t2)
+ overall_time_3 += (e3 - t3)
+ overall_time_4 += (e4 - t4)
+ overall_time_5 += (e5 - t5)
+
+ print(f'average time w gaussian: {overall_time_1/50}')
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
+
+ # save
+ painted_image_00 = Image.fromarray(painted_image_00)
+ painted_image_00.save('./test_images/painter_output_image_00.png')
+
+ painted_image_10 = Image.fromarray(painted_image_10)
+ painted_image_10.save('./test_images/painter_output_image_10.png')
+
+ painted_image_01 = Image.fromarray(painted_image_01)
+ painted_image_01.save('./test_images/painter_output_image_01.png')
+
+ painted_image_11 = Image.fromarray(painted_image_11)
+ painted_image_11.save('./test_images/painter_output_image_11.png')
+
+
+seg_model_map = {
+ 'base': 'vit_b',
+ 'large': 'vit_l',
+ 'huge': 'vit_h'
+}
+ckpt_url_map = {
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
+}
+expected_sha256_map = {
+ 'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
+ 'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
+ 'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
+}
+
+
+def prepare_segmenter(segmenter="huge", download_root: str = None):
+ """
+ Prepare segmenter model and download checkpoint if necessary.
+
+ Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
+
+ """
+
+ os.makedirs('result', exist_ok=True)
+ seg_model_name = seg_model_map[segmenter]
+ checkpoint_url = ckpt_url_map[seg_model_name]
+ folder = download_root or os.path.expanduser("~/.cache/SAM")
+ filename = os.path.basename(checkpoint_url)
+ segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
+
+ return seg_model_name, segmenter_checkpoint
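+
+# Illustrative usage (a sketch): 'huge' resolves to 'vit_h' via seg_model_map and the
+# checkpoint is cached under ~/.cache/SAM unless download_root is given.
+#
+# seg_model_name, checkpoint_path = prepare_segmenter('huge')
+# # the returned pair can then be fed to SAM's registry, e.g.
+# # sam = sam_model_registry[seg_model_name](checkpoint=checkpoint_path)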
+
+
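+# download_checkpoint caches the file under `folder` and skips the download when an
+# existing file's SHA256 already matches; a checksum mismatch after downloading raises
+# RuntimeError. Illustrative direct call (normally prepare_segmenter does this):
+#
+# path = download_checkpoint(ckpt_url_map['vit_b'], os.path.expanduser('~/.cache/SAM'),
+# 'sam_vit_b_01ec64.pth', expected_sha256_map['vit_b'])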
+def download_checkpoint(url, folder, filename, expected_sha256):
+ os.makedirs(folder, exist_ok=True)
+ download_target = os.path.join(folder, filename)
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+
+ print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
+ with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
+ progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
+ for data in response.iter_content(chunk_size=1024):
+ size = output.write(data)
+ progress.update(size)
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
+ return download_target
\ No newline at end of file
diff --git a/env.sh b/env.sh
deleted file mode 100644
index 5d9e5e913530ebdcbf28c48fe2c4d46250895756..0000000000000000000000000000000000000000
--- a/env.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-conda create -n caption_anything python=3.8 -y
-source activate caption_anything
-pip install -r requirements.txt
-# cd segmentertengwang@connect.hku.hk
-# wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
-
diff --git a/segmenter/__init__.py b/segmenter/__init__.py
deleted file mode 100644
index 9ce25a446923643ae2c384c2a46b30dda7713df5..0000000000000000000000000000000000000000
--- a/segmenter/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from segmenter.base_segmenter import BaseSegmenter
-
-
-def build_segmenter(type, device, args=None, model=None):
- return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
\ No newline at end of file
diff --git a/segmenter/images/truck.jpg b/segmenter/images/truck.jpg
deleted file mode 100644
index 6b98688c3c84981200c06259b8d54820ebf85660..0000000000000000000000000000000000000000
Binary files a/segmenter/images/truck.jpg and /dev/null differ
diff --git a/segmenter/sam_vit_h_4b8939.pth b/segmenter/sam_vit_h_4b8939.pth
deleted file mode 100644
index 8523acce9ddab1cf7e355628a08b1aab8ce08a72..0000000000000000000000000000000000000000
--- a/segmenter/sam_vit_h_4b8939.pth
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
-size 2564550879
diff --git a/test_img/img0.png b/test_img/img0.png
deleted file mode 100644
index 19c2a9cd512c9a798821511f3dc7e15bf43680ba..0000000000000000000000000000000000000000
Binary files a/test_img/img0.png and /dev/null differ
diff --git a/test_img/img1.jpg b/test_img/img1.jpg
deleted file mode 100644
index 83c0c9eb9f5026fdb7a7f49fba081d4764ce0515..0000000000000000000000000000000000000000
Binary files a/test_img/img1.jpg and /dev/null differ
diff --git a/test_img/img1.jpg.raw_mask.png b/test_img/img1.jpg.raw_mask.png
deleted file mode 100644
index ba811712737fa16ca0fd79aa981ff8b6f65d6d5f..0000000000000000000000000000000000000000
Binary files a/test_img/img1.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img10.jpg b/test_img/img10.jpg
deleted file mode 100644
index b51fde5fbe4d06c4295270b100f8861bbb02a870..0000000000000000000000000000000000000000
Binary files a/test_img/img10.jpg and /dev/null differ
diff --git a/test_img/img10.jpg.raw_mask.png b/test_img/img10.jpg.raw_mask.png
deleted file mode 100644
index 9f9145d5c6f0c671d2c0d44f860ccc20aaf8e33f..0000000000000000000000000000000000000000
Binary files a/test_img/img10.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img11.jpg b/test_img/img11.jpg
deleted file mode 100644
index 698333f481ea34d1ebb379f4d5802939072c83db..0000000000000000000000000000000000000000
Binary files a/test_img/img11.jpg and /dev/null differ
diff --git a/test_img/img12.jpg b/test_img/img12.jpg
deleted file mode 100644
index 20a3789bad40238cc90cca7b8e0049aaad1e1dbd..0000000000000000000000000000000000000000
Binary files a/test_img/img12.jpg and /dev/null differ
diff --git a/test_img/img12.jpg.raw_mask.png b/test_img/img12.jpg.raw_mask.png
deleted file mode 100644
index 7c857a906c303eb038ce7af6eb37d69762301871..0000000000000000000000000000000000000000
Binary files a/test_img/img12.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img13.jpg b/test_img/img13.jpg
deleted file mode 100644
index 9374e1fa87e3103869e727a8c56fb22525adb715..0000000000000000000000000000000000000000
Binary files a/test_img/img13.jpg and /dev/null differ
diff --git a/test_img/img13.jpg.raw_mask.png b/test_img/img13.jpg.raw_mask.png
deleted file mode 100644
index 23f9dcebc3b52026ab6ce27fba85b48059b1bb8c..0000000000000000000000000000000000000000
Binary files a/test_img/img13.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img14.jpg b/test_img/img14.jpg
deleted file mode 100644
index f60ad955110a5238e80ef93af7bbce03a4322e48..0000000000000000000000000000000000000000
Binary files a/test_img/img14.jpg and /dev/null differ
diff --git a/test_img/img14.jpg.raw_mask.png b/test_img/img14.jpg.raw_mask.png
deleted file mode 100644
index da46cc403fc9eedd021e728db7921c91c5f43e05..0000000000000000000000000000000000000000
Binary files a/test_img/img14.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img15.jpg b/test_img/img15.jpg
deleted file mode 100644
index ab3ef5ec0c62965253ab782d0e0dbf02929588af..0000000000000000000000000000000000000000
Binary files a/test_img/img15.jpg and /dev/null differ
diff --git a/test_img/img15.jpg.raw_mask.png b/test_img/img15.jpg.raw_mask.png
deleted file mode 100644
index e5ebf143e53a9a74f471f8ee2b2209fc72854463..0000000000000000000000000000000000000000
Binary files a/test_img/img15.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img16.jpg b/test_img/img16.jpg
deleted file mode 100644
index 4871a9f5a7b300f34e99337097d2e178ad649ed9..0000000000000000000000000000000000000000
Binary files a/test_img/img16.jpg and /dev/null differ
diff --git a/test_img/img16.jpg.raw_mask.png b/test_img/img16.jpg.raw_mask.png
deleted file mode 100644
index 62739f20714ef4b48b24e24d77311d5a60e3268d..0000000000000000000000000000000000000000
Binary files a/test_img/img16.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img17.jpg b/test_img/img17.jpg
deleted file mode 100644
index 1b5534d3978b826e88ba0e88c9a9953062cbe57a..0000000000000000000000000000000000000000
Binary files a/test_img/img17.jpg and /dev/null differ
diff --git a/test_img/img18.jpg b/test_img/img18.jpg
deleted file mode 100644
index db9215dfbaefa5f1c64c03dd1b928de1c6117ff8..0000000000000000000000000000000000000000
--- a/test_img/img18.jpg
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e02c393a23aadd1304497e3a9b41144df166d1cfda33ea3e00eed94e27da3aa4
-size 1372251
diff --git a/test_img/img19.jpg b/test_img/img19.jpg
deleted file mode 100644
index abbe797820425e442a2a6b99a22b327aee3e9961..0000000000000000000000000000000000000000
Binary files a/test_img/img19.jpg and /dev/null differ
diff --git a/test_img/img2.jpg b/test_img/img2.jpg
deleted file mode 100644
index 583f69ec771a6f562e8dd9511b61fb9034a1af64..0000000000000000000000000000000000000000
Binary files a/test_img/img2.jpg and /dev/null differ
diff --git a/test_img/img2.jpg.raw_mask.png b/test_img/img2.jpg.raw_mask.png
deleted file mode 100644
index d4d55444af67e7831aee27d3df86c5150f680e70..0000000000000000000000000000000000000000
Binary files a/test_img/img2.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img20.jpg b/test_img/img20.jpg
deleted file mode 100644
index 1c75bd821f2beb8cd72d56ad1e5f9064c96066c7..0000000000000000000000000000000000000000
Binary files a/test_img/img20.jpg and /dev/null differ
diff --git a/test_img/img21.jpg b/test_img/img21.jpg
deleted file mode 100644
index 98462cd6c0a8cfdbfe158f4484843d0a320d5dce..0000000000000000000000000000000000000000
Binary files a/test_img/img21.jpg and /dev/null differ
diff --git a/test_img/img22.jpg b/test_img/img22.jpg
deleted file mode 100644
index a6b898f4558d34b4a3fcd44dcffda58bbea2b942..0000000000000000000000000000000000000000
--- a/test_img/img22.jpg
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5c5159bf7114d08967f95475176670043115b157bf700efa34190260cd917662
-size 1025438
diff --git a/test_img/img23.jpg b/test_img/img23.jpg
deleted file mode 100644
index 8b070a469b4009a565f167784c552ce9886769e8..0000000000000000000000000000000000000000
Binary files a/test_img/img23.jpg and /dev/null differ
diff --git a/test_img/img24.jpg b/test_img/img24.jpg
deleted file mode 100644
index c90f0967fe7878cba26a72014e0bf377f0fd9c7d..0000000000000000000000000000000000000000
Binary files a/test_img/img24.jpg and /dev/null differ
diff --git a/test_img/img25.jpg b/test_img/img25.jpg
deleted file mode 100644
index ad24ad0005f04a31fef9077793768bf68fc3276c..0000000000000000000000000000000000000000
Binary files a/test_img/img25.jpg and /dev/null differ
diff --git a/test_img/img27.jpg b/test_img/img27.jpg
deleted file mode 100644
index 08cac0fa26a959dbd2a4fb33043a75cb3a1b6d06..0000000000000000000000000000000000000000
Binary files a/test_img/img27.jpg and /dev/null differ
diff --git a/test_img/img28.jpg b/test_img/img28.jpg
deleted file mode 100644
index 31c7c4a57e21b8c3cf82ee9783179c3410472ed6..0000000000000000000000000000000000000000
Binary files a/test_img/img28.jpg and /dev/null differ
diff --git a/test_img/img29.jpg b/test_img/img29.jpg
deleted file mode 100644
index 5fbab4d5eafcd33d33db45669cc1a9ce7432f111..0000000000000000000000000000000000000000
Binary files a/test_img/img29.jpg and /dev/null differ
diff --git a/test_img/img3.jpg b/test_img/img3.jpg
deleted file mode 100644
index deeafdbc1d4ac40426f75ee7395ecd82025d6e95..0000000000000000000000000000000000000000
Binary files a/test_img/img3.jpg and /dev/null differ
diff --git a/test_img/img30.jpg b/test_img/img30.jpg
deleted file mode 100644
index 060d18f3481662e618e1cf376281fc39bfd0b41d..0000000000000000000000000000000000000000
Binary files a/test_img/img30.jpg and /dev/null differ
diff --git a/test_img/img31.jpg b/test_img/img31.jpg
deleted file mode 100644
index bef87cab085b7bd4f1025090d875bc09bc2e5c96..0000000000000000000000000000000000000000
Binary files a/test_img/img31.jpg and /dev/null differ
diff --git a/test_img/img32.jpg b/test_img/img32.jpg
deleted file mode 100644
index aa916c29c7839093f4092e8c258824a771901cd6..0000000000000000000000000000000000000000
Binary files a/test_img/img32.jpg and /dev/null differ
diff --git a/test_img/img33.jpg b/test_img/img33.jpg
deleted file mode 100644
index d91468d1488d494904ea0068dc811ba78aa69339..0000000000000000000000000000000000000000
Binary files a/test_img/img33.jpg and /dev/null differ
diff --git a/test_img/img34.jpg b/test_img/img34.jpg
deleted file mode 100644
index f05ceaa1b35c6b74e196d979efadc5f4b79b6170..0000000000000000000000000000000000000000
Binary files a/test_img/img34.jpg and /dev/null differ
diff --git a/test_img/img35.webp b/test_img/img35.webp
deleted file mode 100644
index 1b934790352d86a013cee6dfc4119701ee676b1d..0000000000000000000000000000000000000000
Binary files a/test_img/img35.webp and /dev/null differ
diff --git a/test_img/img4.jpg b/test_img/img4.jpg
deleted file mode 100644
index a204f5a7567288216ec7e18a5223e677ab397b36..0000000000000000000000000000000000000000
Binary files a/test_img/img4.jpg and /dev/null differ
diff --git a/test_img/img4.jpg.raw_mask.png b/test_img/img4.jpg.raw_mask.png
deleted file mode 100644
index dba4577e87cd6e52870525af2cdbde07940f100a..0000000000000000000000000000000000000000
Binary files a/test_img/img4.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img5.jpg b/test_img/img5.jpg
deleted file mode 100644
index 80e2e7e4b9505a1528b8d319d6b2efcbde16a9cf..0000000000000000000000000000000000000000
Binary files a/test_img/img5.jpg and /dev/null differ
diff --git a/test_img/img5.jpg.raw_mask.png b/test_img/img5.jpg.raw_mask.png
deleted file mode 100644
index d1924854e703bd95f8587436bf9297ccf775a041..0000000000000000000000000000000000000000
Binary files a/test_img/img5.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img6.jpg b/test_img/img6.jpg
deleted file mode 100644
index 35d44f6a08b0fab2b38efb68c65e7528d65aca48..0000000000000000000000000000000000000000
Binary files a/test_img/img6.jpg and /dev/null differ
diff --git a/test_img/img6.jpg.raw_mask.png b/test_img/img6.jpg.raw_mask.png
deleted file mode 100644
index 42fd658b5c69644b742d56ae567207e1898f6a7e..0000000000000000000000000000000000000000
Binary files a/test_img/img6.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img7.jpg b/test_img/img7.jpg
deleted file mode 100644
index 679431b782257372a0bbe19ab701c308d114f0d7..0000000000000000000000000000000000000000
Binary files a/test_img/img7.jpg and /dev/null differ
diff --git a/test_img/img7.jpg.raw_mask.png b/test_img/img7.jpg.raw_mask.png
deleted file mode 100644
index a15527829eac9c027c42e60c4702bc70e57db460..0000000000000000000000000000000000000000
Binary files a/test_img/img7.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img8.jpg b/test_img/img8.jpg
deleted file mode 100644
index 62ef2d4a1c9fb498fc3f2e3f8928fd3626832d0b..0000000000000000000000000000000000000000
Binary files a/test_img/img8.jpg and /dev/null differ
diff --git a/test_img/img8.jpg.raw_mask.png b/test_img/img8.jpg.raw_mask.png
deleted file mode 100644
index 285410c86e57c3905b223ec392e61570a98c369b..0000000000000000000000000000000000000000
Binary files a/test_img/img8.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/img9.jpg b/test_img/img9.jpg
deleted file mode 100644
index 49acb36e2ad271dac1fc629cd10440c14954e70e..0000000000000000000000000000000000000000
Binary files a/test_img/img9.jpg and /dev/null differ
diff --git a/test_img/img9.jpg.raw_mask.png b/test_img/img9.jpg.raw_mask.png
deleted file mode 100644
index 127a194bbb9088338966e1db84fd9c4dca94afdf..0000000000000000000000000000000000000000
Binary files a/test_img/img9.jpg.raw_mask.png and /dev/null differ
diff --git a/test_img/painter_input_image.jpg b/test_img/painter_input_image.jpg
deleted file mode 100644
index deeafdbc1d4ac40426f75ee7395ecd82025d6e95..0000000000000000000000000000000000000000
Binary files a/test_img/painter_input_image.jpg and /dev/null differ
diff --git a/test_img/painter_input_mask.jpg b/test_img/painter_input_mask.jpg
deleted file mode 100644
index 0720afed9caf1e4e8b1864a86a7004c43307d845..0000000000000000000000000000000000000000
Binary files a/test_img/painter_input_mask.jpg and /dev/null differ
diff --git a/test_img/painter_output_image.png b/test_img/painter_output_image.png
deleted file mode 100644
index 40b97bb859e559e82c03fff625c29f9b391723cc..0000000000000000000000000000000000000000
Binary files a/test_img/painter_output_image.png and /dev/null differ
diff --git a/times_with_simsun.ttf b/times_with_simsun.ttf
deleted file mode 100644
index 0213c4b5dd14af52f642645aa01e4503569f11b4..0000000000000000000000000000000000000000
--- a/times_with_simsun.ttf
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0b15a12dd4bba4a48885c279a1d16590b652773f02137a7e62ede3411970c59f
-size 11066612
diff --git a/tools.py b/tools.py
deleted file mode 100644
index b16d6d84ff58334eab72af098791a3e849631b53..0000000000000000000000000000000000000000
--- a/tools.py
+++ /dev/null
@@ -1,348 +0,0 @@
-import cv2
-import torch
-import numpy as np
-from PIL import Image
-import copy
-import time
-import sys
-
-
-def is_platform_win():
- return sys.platform == "win32"
-
-
-def colormap(rgb=True):
- color_list = np.array(
- [
- 0.000, 0.000, 0.000,
- 1.000, 1.000, 1.000,
- 1.000, 0.498, 0.313,
- 0.392, 0.581, 0.929,
- 0.000, 0.447, 0.741,
- 0.850, 0.325, 0.098,
- 0.929, 0.694, 0.125,
- 0.494, 0.184, 0.556,
- 0.466, 0.674, 0.188,
- 0.301, 0.745, 0.933,
- 0.635, 0.078, 0.184,
- 0.300, 0.300, 0.300,
- 0.600, 0.600, 0.600,
- 1.000, 0.000, 0.000,
- 1.000, 0.500, 0.000,
- 0.749, 0.749, 0.000,
- 0.000, 1.000, 0.000,
- 0.000, 0.000, 1.000,
- 0.667, 0.000, 1.000,
- 0.333, 0.333, 0.000,
- 0.333, 0.667, 0.000,
- 0.333, 1.000, 0.000,
- 0.667, 0.333, 0.000,
- 0.667, 0.667, 0.000,
- 0.667, 1.000, 0.000,
- 1.000, 0.333, 0.000,
- 1.000, 0.667, 0.000,
- 1.000, 1.000, 0.000,
- 0.000, 0.333, 0.500,
- 0.000, 0.667, 0.500,
- 0.000, 1.000, 0.500,
- 0.333, 0.000, 0.500,
- 0.333, 0.333, 0.500,
- 0.333, 0.667, 0.500,
- 0.333, 1.000, 0.500,
- 0.667, 0.000, 0.500,
- 0.667, 0.333, 0.500,
- 0.667, 0.667, 0.500,
- 0.667, 1.000, 0.500,
- 1.000, 0.000, 0.500,
- 1.000, 0.333, 0.500,
- 1.000, 0.667, 0.500,
- 1.000, 1.000, 0.500,
- 0.000, 0.333, 1.000,
- 0.000, 0.667, 1.000,
- 0.000, 1.000, 1.000,
- 0.333, 0.000, 1.000,
- 0.333, 0.333, 1.000,
- 0.333, 0.667, 1.000,
- 0.333, 1.000, 1.000,
- 0.667, 0.000, 1.000,
- 0.667, 0.333, 1.000,
- 0.667, 0.667, 1.000,
- 0.667, 1.000, 1.000,
- 1.000, 0.000, 1.000,
- 1.000, 0.333, 1.000,
- 1.000, 0.667, 1.000,
- 0.167, 0.000, 0.000,
- 0.333, 0.000, 0.000,
- 0.500, 0.000, 0.000,
- 0.667, 0.000, 0.000,
- 0.833, 0.000, 0.000,
- 1.000, 0.000, 0.000,
- 0.000, 0.167, 0.000,
- 0.000, 0.333, 0.000,
- 0.000, 0.500, 0.000,
- 0.000, 0.667, 0.000,
- 0.000, 0.833, 0.000,
- 0.000, 1.000, 0.000,
- 0.000, 0.000, 0.167,
- 0.000, 0.000, 0.333,
- 0.000, 0.000, 0.500,
- 0.000, 0.000, 0.667,
- 0.000, 0.000, 0.833,
- 0.000, 0.000, 1.000,
- 0.143, 0.143, 0.143,
- 0.286, 0.286, 0.286,
- 0.429, 0.429, 0.429,
- 0.571, 0.571, 0.571,
- 0.714, 0.714, 0.714,
- 0.857, 0.857, 0.857
- ]
- ).astype(np.float32)
- color_list = color_list.reshape((-1, 3)) * 255
- if not rgb:
- color_list = color_list[:, ::-1]
- return color_list
-
-
-color_list = colormap()
-color_list = color_list.astype('uint8').tolist()
-
-
-def vis_add_mask(image, mask, color, alpha, kernel_size):
- color = np.array(color)
- mask = mask.astype('float').copy()
- mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
-
- for i in range(3):
- image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
-
- return image
-
-
-def vis_add_mask_wo_blur(image, mask, color, alpha):
- color = np.array(color)
- mask = mask.astype('float').copy()
- for i in range(3):
- image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
- return image
-
-
-def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
- background_color = np.array(background_color)
- contour_color = np.array(contour_color)
-
- # background_mask = 1 - background_mask
- # contour_mask = 1 - contour_mask
-
- for i in range(3):
- image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
- + background_color[i] * (background_alpha-background_mask*background_alpha)
-
- image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
- + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
-
- return image.astype('uint8')
-
-
-def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
- """
- Input:
- input_image: numpy array
- input_mask: numpy array
- background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
- background_blur_radius: radius of background blur, must be odd number
- contour_width: width of mask contour, must be odd number
- contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
- contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
-
- Output:
- painted_image: numpy array
- """
- assert input_image.shape[:2] == input_mask.shape, 'different shape'
- assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
-
-
- # 0: background, 1: foreground
- input_mask[input_mask>0] = 255
-
- # mask background
- painted_image = vis_add_mask(input_image, input_mask, color_list[0], background_alpha, background_blur_radius) # black for background
- # mask contour
- contour_mask = input_mask.copy()
- contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
- # widden contour
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
- contour_mask = cv2.dilate(contour_mask, kernel)
- painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
-
- # painted_image = background_dist_map
-
- return painted_image
-
-
-def mask_generator_00(mask, background_radius, contour_radius):
- # no background width when '00'
- # distance map
- dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
- dist_map = dist_transform_fore - dist_transform_back
- # ...:::!!!:::...
- contour_radius += 2
- contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
- contour_mask = contour_mask / np.max(contour_mask)
- contour_mask[contour_mask>0.5] = 1.
-
- return mask, contour_mask
-
-
-def mask_generator_01(mask, background_radius, contour_radius):
- # no background width when '00'
- # distance map
- dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
- dist_map = dist_transform_fore - dist_transform_back
- # ...:::!!!:::...
- contour_radius += 2
- contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
- contour_mask = contour_mask / np.max(contour_mask)
- return mask, contour_mask
-
-
-def mask_generator_10(mask, background_radius, contour_radius):
- # distance map
- dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
- dist_map = dist_transform_fore - dist_transform_back
- # .....:::::!!!!!
- background_mask = np.clip(dist_map, -background_radius, background_radius)
- background_mask = (background_mask - np.min(background_mask))
- background_mask = background_mask / np.max(background_mask)
- # ...:::!!!:::...
- contour_radius += 2
- contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
- contour_mask = contour_mask / np.max(contour_mask)
- contour_mask[contour_mask>0.5] = 1.
- return background_mask, contour_mask
-
-
-def mask_generator_11(mask, background_radius, contour_radius):
- # distance map
- dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
- dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
- dist_map = dist_transform_fore - dist_transform_back
- # .....:::::!!!!!
- background_mask = np.clip(dist_map, -background_radius, background_radius)
- background_mask = (background_mask - np.min(background_mask))
- background_mask = background_mask / np.max(background_mask)
- # ...:::!!!:::...
- contour_radius += 2
- contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
- contour_mask = contour_mask / np.max(contour_mask)
- return background_mask, contour_mask
-
-
-def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
- """
- Input:
- input_image: numpy array
- input_mask: numpy array
- background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
- background_blur_radius: radius of background blur, must be odd number
- contour_width: width of mask contour, must be odd number
- contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
- contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
- mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
-
- Output:
- painted_image: numpy array
- """
- assert input_image.shape[:2] == input_mask.shape, 'different shape'
- assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
- assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
-
- # downsample input image and mask
- width, height = input_image.shape[0], input_image.shape[1]
- res = 1024
- ratio = min(1.0 * res / max(width, height), 1.0)
- input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
- input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
-
- # 0: background, 1: foreground
- msk = np.clip(input_mask, 0, 1)
-
- # generate masks for background and contour pixels
- background_radius = (background_blur_radius - 1) // 2
- contour_radius = (contour_width - 1) // 2
- generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
- background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
-
- # paint
- painted_image = vis_add_mask_wo_gaussian \
- (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
-
- return painted_image
-
-
-if __name__ == '__main__':
-
- background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
- background_blur_radius = 31 # radius of background blur, must be odd number
- contour_width = 11 # contour width, must be odd number
- contour_color = 3 # id in color map, 0: black, 1: white, >1: others
- contour_alpha = 1 # transparency of background, 0: no contour highlighted
-
- # load input image and mask
- input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
- input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
-
- # paint
- overall_time_1 = 0
- overall_time_2 = 0
- overall_time_3 = 0
- overall_time_4 = 0
- overall_time_5 = 0
-
- for i in range(50):
- t2 = time.time()
- painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
- e2 = time.time()
-
- t3 = time.time()
- painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
- e3 = time.time()
-
- t1 = time.time()
- painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
- e1 = time.time()
-
- t4 = time.time()
- painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
- e4 = time.time()
-
- t5 = time.time()
- painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
- e5 = time.time()
-
- overall_time_1 += (e1 - t1)
- overall_time_2 += (e2 - t2)
- overall_time_3 += (e3 - t3)
- overall_time_4 += (e4 - t4)
- overall_time_5 += (e5 - t5)
-
- print(f'average time w gaussian: {overall_time_1/50}')
- print(f'average time w/o gaussian00: {overall_time_2/50}')
- print(f'average time w/o gaussian10: {overall_time_3/50}')
- print(f'average time w/o gaussian01: {overall_time_4/50}')
- print(f'average time w/o gaussian11: {overall_time_5/50}')
-
- # save
- painted_image_00 = Image.fromarray(painted_image_00)
- painted_image_00.save('./test_img/painter_output_image_00.png')
-
- painted_image_10 = Image.fromarray(painted_image_10)
- painted_image_10.save('./test_img/painter_output_image_10.png')
-
- painted_image_01 = Image.fromarray(painted_image_01)
- painted_image_01.save('./test_img/painter_output_image_01.png')
-
- painted_image_11 = Image.fromarray(painted_image_11)
- painted_image_11.save('./test_img/painter_output_image_11.png')