Spaces:
Runtime error
Runtime error
update lightweight code
Browse files- app.py +10 -8
- app_w_sam.py +0 -139
- models/__pycache__/blip2_model.cpython-38.pyc +0 -0
- models/__pycache__/controlnet_model.cpython-38.pyc +0 -0
- models/__pycache__/gpt_model.cpython-38.pyc +0 -0
- models/__pycache__/grit_model.cpython-38.pyc +0 -0
- models/__pycache__/image_text_transformation.cpython-38.pyc +0 -0
- models/__pycache__/region_semantic.cpython-38.pyc +0 -0
- models/blip2_model.py +15 -10
- models/controlnet_model.py +4 -12
- models/gpt_model.py +1 -1
- models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc +0 -0
- models/grit_src/image_dense_captions.py +2 -0
- models/image_text_transformation.py +6 -7
- models/region_semantic.py +32 -10
- models/segment_models/__pycache__/edit_anything_model.cpython-38.pyc +0 -0
- models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc +0 -0
- models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc +0 -0
- models/segment_models/edit_anything_model.py +62 -0
- models/segment_models/semantic_segment_anything_model.py +2 -0
- models/segment_models/semgent_anything_model.py +11 -2
- pretrained_models/sam_vit_b_01ec64.pth +3 -0
- requirements.txt +1 -0
- utils/__pycache__/util.cpython-38.pyc +0 -0
- utils/image_dense_captions.py +0 -108
- utils/util.py +14 -1
app.py
CHANGED
@@ -12,10 +12,13 @@ parser = argparse.ArgumentParser()
|
|
12 |
parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
|
13 |
parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
|
14 |
parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
|
15 |
-
parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=
|
16 |
-
parser.add_argument('--
|
17 |
-
parser.add_argument('--
|
18 |
-
parser.add_argument('--
|
|
|
|
|
|
|
19 |
parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
|
20 |
|
21 |
args = parser.parse_args()
|
@@ -49,8 +52,7 @@ def process_image(image_src, options=None, processor=None):
|
|
49 |
print(options)
|
50 |
if options is None:
|
51 |
options = []
|
52 |
-
|
53 |
-
processor.args.semantic_segment = False
|
54 |
image_generation_status = "Image Generation" in options
|
55 |
image_caption, dense_caption, region_semantic, gen_text = processor.image_to_text(image_src)
|
56 |
if image_generation_status:
|
@@ -96,7 +98,7 @@ processor = ImageTextTransformation(args)
|
|
96 |
|
97 |
# Create Gradio input and output components
|
98 |
image_input = gr.inputs.Image(type='filepath', label="Input Image")
|
99 |
-
|
100 |
image_generation_checkbox = gr.inputs.Checkbox(label="Image Generation", default=False)
|
101 |
|
102 |
|
@@ -120,7 +122,7 @@ interface = gr.Interface(
|
|
120 |
inputs=[image_input,
|
121 |
gr.CheckboxGroup(
|
122 |
label="Options",
|
123 |
-
choices=["Image Generation"],
|
124 |
),
|
125 |
],
|
126 |
outputs=gr.outputs.HTML(),
|
|
|
12 |
parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
|
13 |
parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
|
14 |
parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
|
15 |
+
parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=True, help='Set this flag to True if you want to use semantic segmentation')
|
16 |
+
parser.add_argument('--sam_arch', choices=['vit_b', 'vit_l', 'vit_h'], dest='sam_arch', default='vit_b', help='vit_b is the default model (fast but not accurate), vit_l and vit_h are larger models')
|
17 |
+
parser.add_argument('--captioner_base_model', choices=['blip', 'blip2'], dest='captioner_base_model', default='blip', help='blip2 requires 15G GPU memory, blip requires 6G GPU memory')
|
18 |
+
parser.add_argument('--region_classify_model', choices=['ssa', 'edit_anything'], dest='region_classify_model', default='edit_anything', help='Select the region classification model: edit anything is ten times faster than ssa, but less accurate.')
|
19 |
+
parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
|
20 |
+
parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
|
21 |
+
parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended. Make sue this model and image_caption model on same device.')
|
22 |
parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
|
23 |
|
24 |
args = parser.parse_args()
|
|
|
52 |
print(options)
|
53 |
if options is None:
|
54 |
options = []
|
55 |
+
processor.args.semantic_segment = "Semantic Segment" in options
|
|
|
56 |
image_generation_status = "Image Generation" in options
|
57 |
image_caption, dense_caption, region_semantic, gen_text = processor.image_to_text(image_src)
|
58 |
if image_generation_status:
|
|
|
98 |
|
99 |
# Create Gradio input and output components
|
100 |
image_input = gr.inputs.Image(type='filepath', label="Input Image")
|
101 |
+
semantic_segment_checkbox = gr.inputs.Checkbox(label="Semantic Segment", default=False)
|
102 |
image_generation_checkbox = gr.inputs.Checkbox(label="Image Generation", default=False)
|
103 |
|
104 |
|
|
|
122 |
inputs=[image_input,
|
123 |
gr.CheckboxGroup(
|
124 |
label="Options",
|
125 |
+
choices=["Image Generation", "Semantic Segment"],
|
126 |
),
|
127 |
],
|
128 |
outputs=gr.outputs.HTML(),
|
app_w_sam.py
DELETED
@@ -1,139 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import cv2
|
3 |
-
import numpy as np
|
4 |
-
from PIL import Image
|
5 |
-
import base64
|
6 |
-
from io import BytesIO
|
7 |
-
from models.image_text_transformation import ImageTextTransformation
|
8 |
-
import argparse
|
9 |
-
import torch
|
10 |
-
|
11 |
-
parser = argparse.ArgumentParser()
|
12 |
-
parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
|
13 |
-
parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
|
14 |
-
parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
|
15 |
-
parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=True, help='Set this flag to True if you want to use semantic segmentation')
|
16 |
-
parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
|
17 |
-
parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
|
18 |
-
parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
|
19 |
-
parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
|
20 |
-
|
21 |
-
args = parser.parse_args()
|
22 |
-
|
23 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
-
# device = "cpu"
|
25 |
-
|
26 |
-
if device == "cuda":
|
27 |
-
args.image_caption_device = "cpu"
|
28 |
-
args.dense_caption_device = "cuda"
|
29 |
-
args.semantic_segment_device = "cuda"
|
30 |
-
args.contolnet_device = "cuda"
|
31 |
-
else:
|
32 |
-
args.image_caption_device = "cpu"
|
33 |
-
args.dense_caption_device = "cpu"
|
34 |
-
args.semantic_segment_device = "cpu"
|
35 |
-
args.contolnet_device = "cpu"
|
36 |
-
|
37 |
-
def pil_image_to_base64(image):
|
38 |
-
buffered = BytesIO()
|
39 |
-
image.save(buffered, format="JPEG")
|
40 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
41 |
-
return img_str
|
42 |
-
|
43 |
-
def add_logo():
|
44 |
-
with open("examples/logo.png", "rb") as f:
|
45 |
-
logo_base64 = base64.b64encode(f.read()).decode()
|
46 |
-
return logo_base64
|
47 |
-
|
48 |
-
def process_image(image_src, options=None, processor=None):
|
49 |
-
print(options)
|
50 |
-
if options is None:
|
51 |
-
options = []
|
52 |
-
processor.args.semantic_segment = "Semantic Segment" in options
|
53 |
-
image_generation_status = "Image Generation" in options
|
54 |
-
image_caption, dense_caption, region_semantic, gen_text = processor.image_to_text(image_src)
|
55 |
-
if image_generation_status:
|
56 |
-
gen_image = processor.text_to_image(gen_text)
|
57 |
-
gen_image_str = pil_image_to_base64(gen_image)
|
58 |
-
# Combine the outputs into a single HTML output
|
59 |
-
custom_output = f'''
|
60 |
-
<h2>Image->Text:</h2>
|
61 |
-
<div style="display: flex; flex-wrap: wrap;">
|
62 |
-
<div style="flex: 1;">
|
63 |
-
<h3>Image Caption</h3>
|
64 |
-
<p>{image_caption}</p>
|
65 |
-
</div>
|
66 |
-
<div style="flex: 1;">
|
67 |
-
<h3>Dense Caption</h3>
|
68 |
-
<p>{dense_caption}</p>
|
69 |
-
</div>
|
70 |
-
<div style="flex: 1;">
|
71 |
-
<h3>Region Semantic</h3>
|
72 |
-
<p>{region_semantic}</p>
|
73 |
-
</div>
|
74 |
-
</div>
|
75 |
-
<div style="display: flex; flex-wrap: wrap;">
|
76 |
-
<div style="flex: 1;">
|
77 |
-
<h3>GPT4 Reasoning:</h3>
|
78 |
-
<p>{gen_text}</p>
|
79 |
-
</div>
|
80 |
-
</div>
|
81 |
-
'''
|
82 |
-
if image_generation_status:
|
83 |
-
custom_output += f'''
|
84 |
-
<h2>Text->Image:</h2>
|
85 |
-
<div style="display: flex; flex-wrap: wrap;">
|
86 |
-
<div style="flex: 1;">
|
87 |
-
<h3>Generated Image</h3>
|
88 |
-
<img src="data:image/jpeg;base64,{gen_image_str}" width="400" style="vertical-align: middle;">
|
89 |
-
</div>
|
90 |
-
</div>
|
91 |
-
'''
|
92 |
-
return custom_output
|
93 |
-
|
94 |
-
processor = ImageTextTransformation(args)
|
95 |
-
|
96 |
-
# Create Gradio input and output components
|
97 |
-
image_input = gr.inputs.Image(type='filepath', label="Input Image")
|
98 |
-
semantic_segment_checkbox = gr.inputs.Checkbox(label="Semantic Segment", default=False)
|
99 |
-
image_generation_checkbox = gr.inputs.Checkbox(label="Image Generation", default=False)
|
100 |
-
|
101 |
-
|
102 |
-
extra_title = r'![vistors](https://visitor-badge.glitch.me/badge?page_id=fingerrec.Image2Paragraph)' + '\n' + \
|
103 |
-
r'[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg)](https://huggingface.co/spaces/Awiny/Image2Paragraph?duplicate=true)' + '\n\n'
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
logo_base64 = add_logo()
|
108 |
-
# Create the title with the logo
|
109 |
-
title_with_logo = \
|
110 |
-
f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'
|
111 |
-
|
112 |
-
examples = [
|
113 |
-
["examples/test_4.jpg"],
|
114 |
-
]
|
115 |
-
|
116 |
-
# Create Gradio interface
|
117 |
-
interface = gr.Interface(
|
118 |
-
fn=lambda image, options: process_image(image, options, processor),
|
119 |
-
inputs=[image_input,
|
120 |
-
gr.CheckboxGroup(
|
121 |
-
label="Options",
|
122 |
-
choices=["Image Generation", "Semantic Segment"],
|
123 |
-
),
|
124 |
-
],
|
125 |
-
outputs=gr.outputs.HTML(),
|
126 |
-
title=title_with_logo,
|
127 |
-
examples=examples,
|
128 |
-
description=extra_title +"""
|
129 |
-
Image.txt. This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
|
130 |
-
\n Github: https://github.com/showlab/Image2Paragraph
|
131 |
-
\n Twitter: https://twitter.com/awinyimgprocess/status/1646225454599372800?s=46&t=HvOe9T2n35iFuCHP5aIHpQ
|
132 |
-
\n Since GPU is expensive, we use CPU for demo and not include semantic segment anything. Run code local with gpu or google colab we provided for fast speed.
|
133 |
-
\n Ttext2image model is controlnet ( very slow in cpu(~2m)), which used canny edge as reference.
|
134 |
-
\n To speed up, we generate image with small size 384, run the code local for high-quality sample.
|
135 |
-
"""
|
136 |
-
)
|
137 |
-
|
138 |
-
# Launch the interface
|
139 |
-
interface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__pycache__/blip2_model.cpython-38.pyc
CHANGED
Binary files a/models/__pycache__/blip2_model.cpython-38.pyc and b/models/__pycache__/blip2_model.cpython-38.pyc differ
|
|
models/__pycache__/controlnet_model.cpython-38.pyc
CHANGED
Binary files a/models/__pycache__/controlnet_model.cpython-38.pyc and b/models/__pycache__/controlnet_model.cpython-38.pyc differ
|
|
models/__pycache__/gpt_model.cpython-38.pyc
CHANGED
Binary files a/models/__pycache__/gpt_model.cpython-38.pyc and b/models/__pycache__/gpt_model.cpython-38.pyc differ
|
|
models/__pycache__/grit_model.cpython-38.pyc
CHANGED
Binary files a/models/__pycache__/grit_model.cpython-38.pyc and b/models/__pycache__/grit_model.cpython-38.pyc differ
|
|
models/__pycache__/image_text_transformation.cpython-38.pyc
CHANGED
Binary files a/models/__pycache__/image_text_transformation.cpython-38.pyc and b/models/__pycache__/image_text_transformation.cpython-38.pyc differ
|
|
models/__pycache__/region_semantic.cpython-38.pyc
CHANGED
Binary files a/models/__pycache__/region_semantic.cpython-38.pyc and b/models/__pycache__/region_semantic.cpython-38.pyc differ
|
|
models/blip2_model.py
CHANGED
@@ -6,28 +6,33 @@ from utils.util import resize_long_edge
|
|
6 |
|
7 |
|
8 |
class ImageCaptioning:
|
9 |
-
def __init__(self, device):
|
10 |
self.device = device
|
|
|
11 |
self.processor, self.model = self.initialize_model()
|
12 |
|
13 |
-
def initialize_model(self):
|
14 |
if self.device == 'cpu':
|
15 |
self.data_type = torch.float32
|
16 |
else:
|
17 |
self.data_type = torch.float16
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
model.to(self.device)
|
26 |
return processor, model
|
27 |
|
28 |
def image_caption(self, image_src):
|
29 |
image = Image.open(image_src)
|
30 |
-
image = resize_long_edge(image)
|
31 |
inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
|
32 |
generated_ids = self.model.generate(**inputs)
|
33 |
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
|
|
6 |
|
7 |
|
8 |
class ImageCaptioning:
|
9 |
+
def __init__(self, device, captioner_base_model='blip'):
|
10 |
self.device = device
|
11 |
+
self.captioner_base_model = captioner_base_model
|
12 |
self.processor, self.model = self.initialize_model()
|
13 |
|
14 |
+
def initialize_model(self,):
|
15 |
if self.device == 'cpu':
|
16 |
self.data_type = torch.float32
|
17 |
else:
|
18 |
self.data_type = torch.float16
|
19 |
+
if self.captioner_base_model == 'blip2':
|
20 |
+
processor = Blip2Processor.from_pretrained("pretrained_models/blip2-opt-2.7b")
|
21 |
+
model = Blip2ForConditionalGeneration.from_pretrained(
|
22 |
+
"pretrained_models/blip2-opt-2.7b", torch_dtype=self.data_type
|
23 |
+
)
|
24 |
+
# for gpu with small memory
|
25 |
+
elif self.captioner_base_model == 'blip':
|
26 |
+
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
27 |
+
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=self.data_type)
|
28 |
+
else:
|
29 |
+
raise ValueError('arch not supported')
|
30 |
model.to(self.device)
|
31 |
return processor, model
|
32 |
|
33 |
def image_caption(self, image_src):
|
34 |
image = Image.open(image_src)
|
35 |
+
image = resize_long_edge(image, 384)
|
36 |
inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
|
37 |
generated_ids = self.model.generate(**inputs)
|
38 |
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
models/controlnet_model.py
CHANGED
@@ -15,29 +15,21 @@ class TextToImage:
|
|
15 |
self.model = self.initialize_model()
|
16 |
|
17 |
def initialize_model(self):
|
18 |
-
if self.device == 'cpu':
|
19 |
-
self.data_type = torch.float32
|
20 |
-
else:
|
21 |
-
self.data_type = torch.float16
|
22 |
controlnet = ControlNetModel.from_pretrained(
|
23 |
"fusing/stable-diffusion-v1-5-controlnet-canny",
|
24 |
-
torch_dtype=
|
25 |
-
|
26 |
-
).to(self.device)
|
27 |
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
28 |
-
# "pretrained_models/stable-diffusion-v1-5",
|
29 |
"runwayml/stable-diffusion-v1-5",
|
30 |
controlnet=controlnet,
|
31 |
safety_checker=None,
|
32 |
-
torch_dtype=
|
33 |
-
map_location=self.device, # Add this line
|
34 |
)
|
35 |
pipeline.scheduler = UniPCMultistepScheduler.from_config(
|
36 |
pipeline.scheduler.config
|
37 |
)
|
|
|
38 |
pipeline.to(self.device)
|
39 |
-
if self.device != 'cpu':
|
40 |
-
pipeline.enable_model_cpu_offload()
|
41 |
return pipeline
|
42 |
|
43 |
@staticmethod
|
|
|
15 |
self.model = self.initialize_model()
|
16 |
|
17 |
def initialize_model(self):
|
|
|
|
|
|
|
|
|
18 |
controlnet = ControlNetModel.from_pretrained(
|
19 |
"fusing/stable-diffusion-v1-5-controlnet-canny",
|
20 |
+
torch_dtype=torch.float16,
|
21 |
+
)
|
|
|
22 |
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
|
|
23 |
"runwayml/stable-diffusion-v1-5",
|
24 |
controlnet=controlnet,
|
25 |
safety_checker=None,
|
26 |
+
torch_dtype=torch.float16,
|
|
|
27 |
)
|
28 |
pipeline.scheduler = UniPCMultistepScheduler.from_config(
|
29 |
pipeline.scheduler.config
|
30 |
)
|
31 |
+
pipeline.enable_model_cpu_offload()
|
32 |
pipeline.to(self.device)
|
|
|
|
|
33 |
return pipeline
|
34 |
|
35 |
@staticmethod
|
models/gpt_model.py
CHANGED
@@ -17,7 +17,7 @@ class ImageToText:
|
|
17 |
Use nouns rather than coordinates to show position information of each object.
|
18 |
No more than 7 sentences.
|
19 |
Only use one paragraph.
|
20 |
-
Describe position
|
21 |
Do not appear number.
|
22 |
"""
|
23 |
template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
|
|
|
17 |
Use nouns rather than coordinates to show position information of each object.
|
18 |
No more than 7 sentences.
|
19 |
Only use one paragraph.
|
20 |
+
Describe position of each object.
|
21 |
Do not appear number.
|
22 |
"""
|
23 |
template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
|
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc
CHANGED
Binary files a/models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc and b/models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc differ
|
|
models/grit_src/image_dense_captions.py
CHANGED
@@ -16,6 +16,7 @@ from models.grit_src.grit.config import add_grit_config
|
|
16 |
|
17 |
from models.grit_src.grit.predictor import VisualizationDemo
|
18 |
import json
|
|
|
19 |
|
20 |
|
21 |
# constants
|
@@ -62,6 +63,7 @@ def image_caption_api(image_src, device):
|
|
62 |
demo = VisualizationDemo(cfg)
|
63 |
if image_src:
|
64 |
img = read_image(image_src, format="BGR")
|
|
|
65 |
predictions, visualized_output = demo.run_on_image(img)
|
66 |
new_caption = dense_pred_to_caption(predictions)
|
67 |
return new_caption
|
|
|
16 |
|
17 |
from models.grit_src.grit.predictor import VisualizationDemo
|
18 |
import json
|
19 |
+
from utils.util import resize_long_edge_cv2
|
20 |
|
21 |
|
22 |
# constants
|
|
|
63 |
demo = VisualizationDemo(cfg)
|
64 |
if image_src:
|
65 |
img = read_image(image_src, format="BGR")
|
66 |
+
img = resize_long_edge_cv2(img, 384)
|
67 |
predictions, visualized_output = demo.run_on_image(img)
|
68 |
new_caption = dense_pred_to_caption(predictions)
|
69 |
return new_caption
|
models/image_text_transformation.py
CHANGED
@@ -3,13 +3,12 @@ from models.grit_model import DenseCaptioning
|
|
3 |
from models.gpt_model import ImageToText
|
4 |
from models.controlnet_model import TextToImage
|
5 |
from models.region_semantic import RegionSemantic
|
6 |
-
from utils.util import read_image_width_height, display_images_and_text
|
7 |
import argparse
|
8 |
from PIL import Image
|
9 |
import base64
|
10 |
from io import BytesIO
|
11 |
import os
|
12 |
-
from utils.util import resize_long_edge
|
13 |
|
14 |
def pil_image_to_base64(image):
|
15 |
buffered = BytesIO()
|
@@ -27,23 +26,23 @@ class ImageTextTransformation:
|
|
27 |
|
28 |
def init_models(self):
|
29 |
openai_key = os.environ['OPENAI_KEY']
|
|
|
30 |
print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
|
31 |
print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
|
32 |
print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
|
33 |
-
self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device)
|
34 |
self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
|
35 |
self.gpt_model = ImageToText(openai_key)
|
36 |
self.controlnet_model = TextToImage(device=self.args.contolnet_device)
|
37 |
-
|
38 |
-
if self.args.semantic_segment:
|
39 |
-
self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device)
|
40 |
print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
|
41 |
|
42 |
|
43 |
def image_to_text(self, img_src):
|
44 |
# the information to generate paragraph based on the context
|
45 |
self.ref_image = Image.open(img_src)
|
46 |
-
|
|
|
47 |
width, height = read_image_width_height(img_src)
|
48 |
print(self.args)
|
49 |
if self.args.image_caption:
|
|
|
3 |
from models.gpt_model import ImageToText
|
4 |
from models.controlnet_model import TextToImage
|
5 |
from models.region_semantic import RegionSemantic
|
6 |
+
from utils.util import read_image_width_height, display_images_and_text, resize_long_edge
|
7 |
import argparse
|
8 |
from PIL import Image
|
9 |
import base64
|
10 |
from io import BytesIO
|
11 |
import os
|
|
|
12 |
|
13 |
def pil_image_to_base64(image):
|
14 |
buffered = BytesIO()
|
|
|
26 |
|
27 |
def init_models(self):
|
28 |
openai_key = os.environ['OPENAI_KEY']
|
29 |
+
print(self.args)
|
30 |
print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
|
31 |
print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
|
32 |
print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
|
33 |
+
self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device, captioner_base_model=self.args.captioner_base_model)
|
34 |
self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
|
35 |
self.gpt_model = ImageToText(openai_key)
|
36 |
self.controlnet_model = TextToImage(device=self.args.contolnet_device)
|
37 |
+
self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device, image_caption_model=self.image_caption_model, region_classify_model=self.args.region_classify_model, sam_arch=self.args.sam_arch)
|
|
|
|
|
38 |
print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
|
39 |
|
40 |
|
41 |
def image_to_text(self, img_src):
|
42 |
# the information to generate paragraph based on the context
|
43 |
self.ref_image = Image.open(img_src)
|
44 |
+
# resize image to long edge 384
|
45 |
+
self.ref_image = resize_long_edge(self.ref_image, 384)
|
46 |
width, height = read_image_width_height(img_src)
|
47 |
print(self.args)
|
48 |
if self.args.image_caption:
|
models/region_semantic.py
CHANGED
@@ -1,17 +1,27 @@
|
|
1 |
from models.segment_models.semgent_anything_model import SegmentAnything
|
2 |
from models.segment_models.semantic_segment_anything_model import SemanticSegment
|
|
|
3 |
|
4 |
|
5 |
class RegionSemantic():
|
6 |
-
def __init__(self, device):
|
7 |
self.device = device
|
|
|
|
|
|
|
8 |
self.init_models()
|
9 |
|
10 |
def init_models(self):
|
11 |
-
self.segment_model = SegmentAnything(self.device)
|
12 |
-
self.
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
"""
|
16 |
fliter too small objects and objects with low stability score
|
17 |
anns: [{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'size': [0, 0], 'stability_score': 0.0}, ...]
|
@@ -19,20 +29,32 @@ class RegionSemantic():
|
|
19 |
"""
|
20 |
# Sort annotations by area in descending order
|
21 |
sorted_annotations = sorted(anns, key=lambda x: x['area'], reverse=True)
|
|
|
22 |
# Select the top 10 largest regions
|
23 |
-
top_10_largest_regions = sorted_annotations[:
|
24 |
semantic_prompt = ""
|
25 |
-
print('\033[1;35m' + '*' * 100 + '\033[0m')
|
26 |
-
print("\nStep3, Semantic Prompt:")
|
27 |
for region in top_10_largest_regions:
|
28 |
semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
|
29 |
print(semantic_prompt)
|
30 |
print('\033[1;35m' + '*' * 100 + '\033[0m')
|
31 |
return semantic_prompt
|
32 |
|
33 |
-
def region_semantic(self, img_src):
|
|
|
|
|
|
|
34 |
anns = self.segment_model.generate_mask(img_src)
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
return self.semantic_prompt_gen(anns_w_class)
|
37 |
|
38 |
def region_semantic_debug(self, img_src):
|
|
|
1 |
from models.segment_models.semgent_anything_model import SegmentAnything
|
2 |
from models.segment_models.semantic_segment_anything_model import SemanticSegment
|
3 |
+
from models.segment_models.edit_anything_model import EditAnything
|
4 |
|
5 |
|
6 |
class RegionSemantic():
|
7 |
+
def __init__(self, device, image_caption_model, region_classify_model='edit_anything', sam_arch='vit_b'):
|
8 |
self.device = device
|
9 |
+
self.sam_arch = sam_arch
|
10 |
+
self.image_caption_model = image_caption_model
|
11 |
+
self.region_classify_model = region_classify_model
|
12 |
self.init_models()
|
13 |
|
14 |
def init_models(self):
|
15 |
+
self.segment_model = SegmentAnything(self.device, arch=self.sam_arch)
|
16 |
+
if self.region_classify_model == 'ssa':
|
17 |
+
self.semantic_segment_model = SemanticSegment(self.device)
|
18 |
+
elif self.region_classify_model == 'edit_anything':
|
19 |
+
self.edit_anything_model = EditAnything(self.image_caption_model)
|
20 |
+
print('initalize edit anything model')
|
21 |
+
else:
|
22 |
+
raise ValueError("semantic_class_model must be 'ssa' or 'edit_anything'")
|
23 |
+
|
24 |
+
def semantic_prompt_gen(self, anns, topk=5):
|
25 |
"""
|
26 |
fliter too small objects and objects with low stability score
|
27 |
anns: [{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'size': [0, 0], 'stability_score': 0.0}, ...]
|
|
|
29 |
"""
|
30 |
# Sort annotations by area in descending order
|
31 |
sorted_annotations = sorted(anns, key=lambda x: x['area'], reverse=True)
|
32 |
+
anns_len = len(sorted_annotations)
|
33 |
# Select the top 10 largest regions
|
34 |
+
top_10_largest_regions = sorted_annotations[:min(anns_len, topk)]
|
35 |
semantic_prompt = ""
|
|
|
|
|
36 |
for region in top_10_largest_regions:
|
37 |
semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
|
38 |
print(semantic_prompt)
|
39 |
print('\033[1;35m' + '*' * 100 + '\033[0m')
|
40 |
return semantic_prompt
|
41 |
|
42 |
+
def region_semantic(self, img_src, region_classify_model='edit_anything'):
|
43 |
+
print('\033[1;35m' + '*' * 100 + '\033[0m')
|
44 |
+
print("\nStep3, Semantic Prompt:")
|
45 |
+
print('extract region segmentation with SAM model....\n')
|
46 |
anns = self.segment_model.generate_mask(img_src)
|
47 |
+
print('finished...\n')
|
48 |
+
if region_classify_model == 'ssa':
|
49 |
+
print('generate region supervision with blip2 model....\n')
|
50 |
+
anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
|
51 |
+
print('finished...\n')
|
52 |
+
elif region_classify_model == 'edit_anything':
|
53 |
+
print('generate region supervision with edit anything model....\n')
|
54 |
+
anns_w_class = self.edit_anything_model.semantic_class_w_mask(img_src, anns)
|
55 |
+
print('finished...\n')
|
56 |
+
else:
|
57 |
+
raise ValueError("semantic_class_model must be 'ssa' or 'edit_anything'")
|
58 |
return self.semantic_prompt_gen(anns_w_class)
|
59 |
|
60 |
def region_semantic_debug(self, img_src):
|
models/segment_models/__pycache__/edit_anything_model.cpython-38.pyc
ADDED
Binary file (3.62 kB). View file
|
|
models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc
CHANGED
Binary files a/models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc and b/models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc differ
|
|
models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc
CHANGED
Binary files a/models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc and b/models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc differ
|
|
models/segment_models/edit_anything_model.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import mmcv
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from utils.util import resize_long_edge
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
import time
|
9 |
+
|
10 |
+
class EditAnything:
|
11 |
+
def __init__(self, image_caption_model):
|
12 |
+
self.device = image_caption_model.device
|
13 |
+
self.data_type = image_caption_model.data_type
|
14 |
+
self.image_caption_model = image_caption_model
|
15 |
+
|
16 |
+
def region_classify_w_blip2(self, images):
|
17 |
+
inputs = self.image_caption_model.processor(images=images, return_tensors="pt").to(self.device, self.data_type)
|
18 |
+
generated_ids = self.image_caption_model.model.generate(**inputs)
|
19 |
+
generated_texts = self.image_caption_model.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
20 |
+
return [text.strip() for text in generated_texts]
|
21 |
+
|
22 |
+
def process_ann(self, ann, image, target_size=(224, 224)):
|
23 |
+
start_time = time.time()
|
24 |
+
m = ann['segmentation']
|
25 |
+
m_3c = m[:, :, np.newaxis]
|
26 |
+
m_3c = np.concatenate((m_3c, m_3c, m_3c), axis=2)
|
27 |
+
bbox = ann['bbox']
|
28 |
+
region = mmcv.imcrop(image * m_3c, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=1)
|
29 |
+
resized_region = mmcv.imresize(region, target_size)
|
30 |
+
end_time = time.time()
|
31 |
+
print("process_ann took {:.2f} seconds".format(end_time - start_time))
|
32 |
+
return resized_region, ann
|
33 |
+
|
34 |
+
def region_level_semantic_api(self, image, anns, topk=5):
|
35 |
+
"""
|
36 |
+
rank regions by area, and classify each region with blip2, parallel processing for speed up
|
37 |
+
Args:
|
38 |
+
image: numpy array
|
39 |
+
topk: int
|
40 |
+
Returns:
|
41 |
+
topk_region_w_class_label: list of dict with key 'class_label'
|
42 |
+
"""
|
43 |
+
start_time = time.time()
|
44 |
+
if len(anns) == 0:
|
45 |
+
return []
|
46 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
47 |
+
topk_anns = sorted_anns[:min(topk, len(sorted_anns))]
|
48 |
+
with ThreadPoolExecutor() as executor:
|
49 |
+
regions_and_anns = list(executor.map(lambda ann: self.process_ann(ann, image), topk_anns))
|
50 |
+
regions = [region for region, _ in regions_and_anns]
|
51 |
+
region_class_labels = self.region_classify_w_blip2(regions)
|
52 |
+
for (region, ann), class_label in zip(regions_and_anns, region_class_labels):
|
53 |
+
ann['class_name'] = class_label
|
54 |
+
end_time = time.time()
|
55 |
+
print("region_level_semantic_api took {:.2f} seconds".format(end_time - start_time))
|
56 |
+
|
57 |
+
return [ann for _, ann in regions_and_anns]
|
58 |
+
|
59 |
+
def semantic_class_w_mask(self, img_src, anns):
|
60 |
+
image = Image.open(img_src)
|
61 |
+
image = resize_long_edge(image, 384)
|
62 |
+
return self.region_level_semantic_api(image, anns)
|
models/segment_models/semantic_segment_anything_model.py
CHANGED
@@ -10,6 +10,7 @@ from PIL import Image
|
|
10 |
import pycocotools.mask as maskUtils
|
11 |
from models.segment_models.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
|
12 |
from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
|
|
|
13 |
# from mmdet.core.visualization.image import imshow_det_bboxes # comment this line if you don't use mmdet
|
14 |
|
15 |
nlp = spacy.load('en_core_web_sm')
|
@@ -113,6 +114,7 @@ class SemanticSegment():
|
|
113 |
:return: dict('segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box', "class_name", "class_proposals"})
|
114 |
"""
|
115 |
img = mmcv.imread(img_src)
|
|
|
116 |
oneformer_coco_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_coco_processor, self.oneformer_coco_model)
|
117 |
oneformer_ade20k_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_ade20k_processor, self.oneformer_ade20k_model)
|
118 |
bitmasks, class_names = [], []
|
|
|
10 |
import pycocotools.mask as maskUtils
|
11 |
from models.segment_models.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
|
12 |
from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
|
13 |
+
from utils.util import resize_long_edge, resize_long_edge_cv2
|
14 |
# from mmdet.core.visualization.image import imshow_det_bboxes # comment this line if you don't use mmdet
|
15 |
|
16 |
nlp = spacy.load('en_core_web_sm')
|
|
|
114 |
:return: dict('segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box', "class_name", "class_proposals"})
|
115 |
"""
|
116 |
img = mmcv.imread(img_src)
|
117 |
+
img = resize_long_edge_cv2(img, 384)
|
118 |
oneformer_coco_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_coco_processor, self.oneformer_coco_model)
|
119 |
oneformer_ade20k_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_ade20k_processor, self.oneformer_ade20k_model)
|
120 |
bitmasks, class_names = [], []
|
models/segment_models/semgent_anything_model.py
CHANGED
@@ -1,10 +1,18 @@
|
|
1 |
import cv2
|
2 |
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
3 |
-
import
|
4 |
|
5 |
class SegmentAnything:
|
6 |
-
def __init__(self, device, arch="
|
7 |
self.device = device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
self.model = self.initialize_model(arch, pretrained_weights)
|
9 |
|
10 |
def initialize_model(self, arch, pretrained_weights):
|
@@ -16,5 +24,6 @@ class SegmentAnything:
|
|
16 |
def generate_mask(self, img_src):
|
17 |
image = cv2.imread(img_src)
|
18 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
19 |
anns = self.model.generate(image)
|
20 |
return anns
|
|
|
1 |
import cv2
|
2 |
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
3 |
+
from utils.util import resize_long_edge_cv2
|
4 |
|
5 |
class SegmentAnything:
|
6 |
+
def __init__(self, device, arch="vit_b"):
|
7 |
self.device = device
|
8 |
+
if arch=='vit_b':
|
9 |
+
pretrained_weights="pretrained_models/sam_vit_b_01ec64.pth"
|
10 |
+
elif arch=='vit_l':
|
11 |
+
pretrained_weights="pretrained_models/sam_vit_l_0e2f7b.pth"
|
12 |
+
elif arch=='vit_h':
|
13 |
+
pretrained_weights="pretrained_models/sam_vit_h_0e2f7b.pth"
|
14 |
+
else:
|
15 |
+
raise ValueError(f"arch {arch} not supported")
|
16 |
self.model = self.initialize_model(arch, pretrained_weights)
|
17 |
|
18 |
def initialize_model(self, arch, pretrained_weights):
|
|
|
24 |
def generate_mask(self, img_src):
|
25 |
image = cv2.imread(img_src)
|
26 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
27 |
+
image = resize_long_edge_cv2(image, 384)
|
28 |
anns = self.model.generate(image)
|
29 |
return anns
|
pretrained_models/sam_vit_b_01ec64.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
|
3 |
+
size 375042383
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
--extra-index-url https://download.pytorch.org/whl
|
2 |
torch==1.9.0+cu111
|
3 |
torchvision==0.10.0+cu111
|
|
|
1 |
+
# This file only test on Linux
|
2 |
--extra-index-url https://download.pytorch.org/whl
|
3 |
torch==1.9.0+cu111
|
4 |
torchvision==0.10.0+cu111
|
utils/__pycache__/util.cpython-38.pyc
CHANGED
Binary files a/utils/__pycache__/util.cpython-38.pyc and b/utils/__pycache__/util.cpython-38.pyc differ
|
|
utils/image_dense_captions.py
DELETED
@@ -1,108 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import multiprocessing as mp
|
3 |
-
import os
|
4 |
-
import time
|
5 |
-
import cv2
|
6 |
-
import tqdm
|
7 |
-
import sys
|
8 |
-
|
9 |
-
from detectron2.config import get_cfg
|
10 |
-
from detectron2.data.detection_utils import read_image
|
11 |
-
from detectron2.utils.logger import setup_logger
|
12 |
-
|
13 |
-
sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
|
14 |
-
from centernet.config import add_centernet_config
|
15 |
-
from grit.config import add_grit_config
|
16 |
-
|
17 |
-
from grit.predictor import VisualizationDemo
|
18 |
-
import json
|
19 |
-
|
20 |
-
|
21 |
-
# constants
|
22 |
-
WINDOW_NAME = "GRiT"
|
23 |
-
|
24 |
-
|
25 |
-
def dense_pred_to_caption(predictions):
|
26 |
-
boxes = predictions["instances"].pred_boxes if predictions["instances"].has("pred_boxes") else None
|
27 |
-
object_description = predictions["instances"].pred_object_descriptions.data
|
28 |
-
new_caption = ""
|
29 |
-
for i in range(len(object_description)):
|
30 |
-
new_caption += (object_description[i] + ": " + str([int(a) for a in boxes[i].tensor.cpu().detach().numpy()[0]])) + "; "
|
31 |
-
return new_caption
|
32 |
-
|
33 |
-
def setup_cfg(args):
|
34 |
-
cfg = get_cfg()
|
35 |
-
if args.cpu:
|
36 |
-
cfg.MODEL.DEVICE="cpu"
|
37 |
-
add_centernet_config(cfg)
|
38 |
-
add_grit_config(cfg)
|
39 |
-
cfg.merge_from_file(args.config_file)
|
40 |
-
cfg.merge_from_list(args.opts)
|
41 |
-
# Set score_threshold for builtin models
|
42 |
-
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
|
43 |
-
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
|
44 |
-
if args.test_task:
|
45 |
-
cfg.MODEL.TEST_TASK = args.test_task
|
46 |
-
cfg.MODEL.BEAM_SIZE = 1
|
47 |
-
cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
|
48 |
-
cfg.USE_ACT_CHECKPOINT = False
|
49 |
-
cfg.freeze()
|
50 |
-
return cfg
|
51 |
-
|
52 |
-
|
53 |
-
def get_parser():
|
54 |
-
parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
|
55 |
-
parser.add_argument(
|
56 |
-
"--config-file",
|
57 |
-
default="",
|
58 |
-
metavar="FILE",
|
59 |
-
help="path to config file",
|
60 |
-
)
|
61 |
-
parser.add_argument("--cpu", action='store_true', help="Use CPU only.")
|
62 |
-
parser.add_argument(
|
63 |
-
"--image_src",
|
64 |
-
default="../examples/1.jpg",
|
65 |
-
help="Input json file include 'image' and 'caption'; "
|
66 |
-
)
|
67 |
-
# "/home/aiops/wangjp/Code/LLP/annotation/coco_karpathy_test_dense_caption.json", "/home/aiops/wangjp/Code/LLP/annotation/coco_karpathy_train_dense_caption.json"
|
68 |
-
parser.add_argument(
|
69 |
-
"--confidence-threshold",
|
70 |
-
type=float,
|
71 |
-
default=0.5,
|
72 |
-
help="Minimum score for instance predictions to be shown",
|
73 |
-
)
|
74 |
-
parser.add_argument(
|
75 |
-
"--test-task",
|
76 |
-
type=str,
|
77 |
-
default='',
|
78 |
-
help="Choose a task to have GRiT perform",
|
79 |
-
)
|
80 |
-
parser.add_argument(
|
81 |
-
"--opts",
|
82 |
-
help="Modify config options using the command-line 'KEY VALUE' pairs",
|
83 |
-
default=[],
|
84 |
-
nargs=argparse.REMAINDER,
|
85 |
-
)
|
86 |
-
return parser
|
87 |
-
|
88 |
-
|
89 |
-
if __name__ == "__main__":
|
90 |
-
mp.set_start_method("spawn", force=True)
|
91 |
-
args = get_parser().parse_args()
|
92 |
-
setup_logger(name="fvcore")
|
93 |
-
logger = setup_logger()
|
94 |
-
logger.info("Arguments: " + str(args))
|
95 |
-
|
96 |
-
cfg = setup_cfg(args)
|
97 |
-
demo = VisualizationDemo(cfg)
|
98 |
-
if args.image_src:
|
99 |
-
img = read_image(args.image_src, format="BGR")
|
100 |
-
start_time = time.time()
|
101 |
-
predictions, visualized_output = demo.run_on_image(img)
|
102 |
-
new_caption = dense_pred_to_caption(predictions)
|
103 |
-
print(new_caption)
|
104 |
-
|
105 |
-
output_file = os.path.expanduser("~/grit_output.txt")
|
106 |
-
with open(output_file, 'w') as f:
|
107 |
-
f.write(new_caption)
|
108 |
-
# sys.exit(new_caption)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/util.py
CHANGED
@@ -14,7 +14,6 @@ def read_image_width_height(image_path):
|
|
14 |
width, height = image.size
|
15 |
return width, height
|
16 |
|
17 |
-
|
18 |
def resize_long_edge(image, target_size=384):
|
19 |
# Calculate the aspect ratio
|
20 |
width, height = image.size
|
@@ -32,6 +31,20 @@ def resize_long_edge(image, target_size=384):
|
|
32 |
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
|
33 |
return resized_image
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def display_images_and_text(source_image_path, generated_image, generated_paragraph, outfile_name):
|
36 |
source_image = Image.open(source_image_path)
|
37 |
# Create a new image that can fit the images and the text
|
|
|
14 |
width, height = image.size
|
15 |
return width, height
|
16 |
|
|
|
17 |
def resize_long_edge(image, target_size=384):
|
18 |
# Calculate the aspect ratio
|
19 |
width, height = image.size
|
|
|
31 |
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
|
32 |
return resized_image
|
33 |
|
34 |
+
def resize_long_edge_cv2(image, target_size=384):
|
35 |
+
height, width = image.shape[:2]
|
36 |
+
aspect_ratio = float(width) / float(height)
|
37 |
+
|
38 |
+
if height > width:
|
39 |
+
new_height = target_size
|
40 |
+
new_width = int(target_size * aspect_ratio)
|
41 |
+
else:
|
42 |
+
new_width = target_size
|
43 |
+
new_height = int(target_size / aspect_ratio)
|
44 |
+
|
45 |
+
resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
46 |
+
return resized_image
|
47 |
+
|
48 |
def display_images_and_text(source_image_path, generated_image, generated_paragraph, outfile_name):
|
49 |
source_image = Image.open(source_image_path)
|
50 |
# Create a new image that can fit the images and the text
|