add_confidence_score (#56)
Browse files- add confidence score parsing (2ce7cd7837dfc8d93bf1d77dae95669ef1bcf0b3)
- README.md +38 -0
- processing_florence2.py +83 -24
README.md
CHANGED
@@ -190,6 +190,44 @@ prompt = "<OCR_WITH_REGION>"
|
|
190 |
run_example(prompt)
|
191 |
```
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
for More detailed examples, please refer to [notebook](https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb)
|
194 |
</details>
|
195 |
|
|
|
190 |
run_example(prompt)
|
191 |
```
|
192 |
|
193 |
+
### Output confidence score with Object Detection
|
194 |
+
```python
|
195 |
+
|
196 |
+
def run_example_with_score(task_prompt, text_input=None):
|
197 |
+
if text_input is None:
|
198 |
+
prompt = task_prompt
|
199 |
+
else:
|
200 |
+
prompt = task_prompt + text_input
|
201 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
202 |
+
generated_ids = model.generate(
|
203 |
+
input_ids=inputs["input_ids"],
|
204 |
+
pixel_values=inputs["pixel_values"],
|
205 |
+
max_new_tokens=1024,
|
206 |
+
num_beams=3,
|
207 |
+
return_dict_in_generate=True,
|
208 |
+
output_scores=True,
|
209 |
+
)
|
210 |
+
generated_text = processor.batch_decode(generated_ids.sequences, skip_special_tokens=False)[0]
|
211 |
+
|
212 |
+
prediction, scores, beam_indices = generated_ids.sequences, generated_ids.scores, generated_ids.beam_indices
|
213 |
+
transition_beam_scores = model.compute_transition_scores(
|
214 |
+
sequences=prediction,
|
215 |
+
scores=scores,
|
216 |
+
beam_indices=beam_indices,
|
217 |
+
)
|
218 |
+
|
219 |
+
parsed_answer = processor.post_process_generation(sequence=generated_ids.sequences[0],
|
220 |
+
transition_beam_score=transition_beam_scores[0],
|
221 |
+
task=task_prompt, image_size=(image.width, image.height)
|
222 |
+
)
|
223 |
+
|
224 |
+
print(parsed_answer)
|
225 |
+
|
226 |
+
prompt = "<OD>"
|
227 |
+
run_example_with_score(prompt)
|
228 |
+
|
229 |
+
```
|
230 |
+
|
231 |
for More detailed examples, please refer to [notebook](https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb)
|
232 |
</details>
|
233 |
|
processing_florence2.py
CHANGED
@@ -20,6 +20,7 @@ import re
|
|
20 |
import logging
|
21 |
from typing import List, Optional, Union
|
22 |
import numpy as np
|
|
|
23 |
|
24 |
import torch
|
25 |
|
@@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import (
|
|
32 |
TextInput,
|
33 |
TruncationStrategy,
|
34 |
)
|
|
|
35 |
from transformers.utils import TensorType
|
36 |
|
37 |
|
@@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin):
|
|
304 |
image_processor_input_names = self.image_processor.model_input_names
|
305 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
306 |
|
307 |
-
def post_process_generation(self, text, task, image_size):
|
308 |
"""
|
309 |
Post-process the output of the model to each of the task outputs.
|
310 |
|
@@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin):
|
|
317 |
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
318 |
task_answer = self.post_processor(
|
319 |
text=text,
|
|
|
|
|
320 |
image_size=image_size,
|
321 |
parse_tasks=task_answer_post_processing_type,
|
322 |
)[task_answer_post_processing_type]
|
@@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin):
|
|
330 |
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
331 |
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
332 |
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
|
|
|
|
|
|
333 |
elif task_answer_post_processing_type in ['ocr']:
|
334 |
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
335 |
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
@@ -591,7 +598,8 @@ class Florence2PostProcesser(object):
|
|
591 |
'PARSE_TASKS': [
|
592 |
{
|
593 |
'TASK_NAME': 'od',
|
594 |
-
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
|
|
|
595 |
},
|
596 |
{
|
597 |
'TASK_NAME': 'ocr',
|
@@ -607,6 +615,7 @@ class Florence2PostProcesser(object):
|
|
607 |
},
|
608 |
{
|
609 |
'TASK_NAME': 'description_with_bboxes',
|
|
|
610 |
},
|
611 |
{
|
612 |
'TASK_NAME': 'description_with_polygons',
|
@@ -647,10 +656,6 @@ class Florence2PostProcesser(object):
|
|
647 |
filtered_tokens = tokenizer.convert_ids_to_tokens(
|
648 |
token_ids, skip_special_tokens=False)
|
649 |
assert len(filtered_tokens) == len(token_ids)
|
650 |
-
|
651 |
-
# To avoid mixing byte-level and unicode for byte-level BPT
|
652 |
-
# we need to build string separately for added tokens and byte-level tokens
|
653 |
-
# cf. https://github.com/huggingface/transformers/issues/1133
|
654 |
sub_texts = []
|
655 |
for token in filtered_tokens:
|
656 |
if token in self.all_special_tokens:
|
@@ -658,10 +663,6 @@ class Florence2PostProcesser(object):
|
|
658 |
else:
|
659 |
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
660 |
sub_text = tokenizer.convert_tokens_to_string([token])
|
661 |
-
elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
|
662 |
-
# Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
|
663 |
-
# Note: Do not strip sub_text as it may have functional whitespace
|
664 |
-
sub_text = token.replace('▁', ' ')
|
665 |
else:
|
666 |
raise ValueError(f'type {type(tokenizer)} not supported')
|
667 |
sub_texts.append(sub_text)
|
@@ -672,14 +673,6 @@ class Florence2PostProcesser(object):
|
|
672 |
span = (len(text), len(text) + len(sub_text)) # [start index, end index).
|
673 |
text += sub_text
|
674 |
spans.append(span)
|
675 |
-
|
676 |
-
# Text format:
|
677 |
-
# 1. T5Tokenizer/T5TokenizerFast:
|
678 |
-
# "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
|
679 |
-
# Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
680 |
-
# 2. BartTokenizer (need to double check):
|
681 |
-
# "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
|
682 |
-
# Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
683 |
return text, spans
|
684 |
|
685 |
def parse_od_from_text_and_spans(
|
@@ -714,7 +707,7 @@ class Florence2PostProcesser(object):
|
|
714 |
return instances
|
715 |
|
716 |
def parse_ocr_from_text_and_spans(self,
|
717 |
-
|
718 |
pattern,
|
719 |
image_size,
|
720 |
area_threshold=-1.0,
|
@@ -818,9 +811,26 @@ class Florence2PostProcesser(object):
|
|
818 |
|
819 |
return instances
|
820 |
|
821 |
-
def parse_description_with_bboxes_from_text_and_spans(
|
822 |
-
|
823 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
824 |
|
825 |
text = text.replace('<s>', '')
|
826 |
text = text.replace('</s>', '')
|
@@ -842,13 +852,16 @@ class Florence2PostProcesser(object):
|
|
842 |
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
843 |
|
844 |
if phrase_text_strip == '' and not allow_empty_phrase:
|
|
|
845 |
continue
|
846 |
|
847 |
# parse phrase, get string
|
848 |
phrase = re.search(pattern, phrase_text_strip)
|
849 |
if phrase is None:
|
|
|
850 |
continue
|
851 |
|
|
|
852 |
phrase = phrase.group()
|
853 |
# remove leading and trailing spaces
|
854 |
phrase = phrase.strip()
|
@@ -856,6 +869,7 @@ class Florence2PostProcesser(object):
|
|
856 |
# parse bboxes by box_pattern
|
857 |
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
858 |
if len(bboxes_parsed) == 0:
|
|
|
859 |
continue
|
860 |
|
861 |
# a list of list
|
@@ -866,14 +880,42 @@ class Florence2PostProcesser(object):
|
|
866 |
size=image_size
|
867 |
).tolist()
|
868 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
869 |
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
870 |
-
for _bboxes in bboxes:
|
871 |
# Prepare instance.
|
872 |
instance = {}
|
873 |
instance['bbox'] = _bboxes
|
874 |
# exclude non-ascii characters
|
875 |
instance['cat_name'] = phrase
|
|
|
|
|
876 |
instances.append(instance)
|
|
|
|
|
877 |
|
878 |
return instances
|
879 |
|
@@ -991,6 +1033,8 @@ class Florence2PostProcesser(object):
|
|
991 |
def __call__(
|
992 |
self,
|
993 |
text=None,
|
|
|
|
|
994 |
image_size=None,
|
995 |
parse_tasks=None,
|
996 |
):
|
@@ -1008,7 +1052,18 @@ class Florence2PostProcesser(object):
|
|
1008 |
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
1009 |
|
1010 |
# sequence or text should be provided
|
1011 |
-
assert text is not None, 'text should be provided'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1012 |
|
1013 |
parsed_dict = {
|
1014 |
'text': text
|
@@ -1019,6 +1074,7 @@ class Florence2PostProcesser(object):
|
|
1019 |
continue
|
1020 |
|
1021 |
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
|
|
1022 |
|
1023 |
if task == 'ocr':
|
1024 |
instances = self.parse_ocr_from_text_and_spans(
|
@@ -1040,6 +1096,9 @@ class Florence2PostProcesser(object):
|
|
1040 |
elif task == 'description_with_bboxes':
|
1041 |
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1042 |
text,
|
|
|
|
|
|
|
1043 |
pattern=pattern,
|
1044 |
image_size=image_size,
|
1045 |
)
|
|
|
20 |
import logging
|
21 |
from typing import List, Optional, Union
|
22 |
import numpy as np
|
23 |
+
import math
|
24 |
|
25 |
import torch
|
26 |
|
|
|
33 |
TextInput,
|
34 |
TruncationStrategy,
|
35 |
)
|
36 |
+
from transformers import BartTokenizer, BartTokenizerFast
|
37 |
from transformers.utils import TensorType
|
38 |
|
39 |
|
|
|
306 |
image_processor_input_names = self.image_processor.model_input_names
|
307 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
308 |
|
309 |
+
def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
|
310 |
"""
|
311 |
Post-process the output of the model to each of the task outputs.
|
312 |
|
|
|
319 |
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
320 |
task_answer = self.post_processor(
|
321 |
text=text,
|
322 |
+
sequence=sequence,
|
323 |
+
transition_beam_score=transition_beam_score,
|
324 |
image_size=image_size,
|
325 |
parse_tasks=task_answer_post_processing_type,
|
326 |
)[task_answer_post_processing_type]
|
|
|
334 |
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
335 |
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
336 |
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
337 |
+
if len(od_instances) and 'score' in od_instances[0]:
|
338 |
+
scores_od = [_od_instance['score'] for _od_instance in od_instances]
|
339 |
+
final_answer['scores'] = scores_od
|
340 |
elif task_answer_post_processing_type in ['ocr']:
|
341 |
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
342 |
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
|
|
598 |
'PARSE_TASKS': [
|
599 |
{
|
600 |
'TASK_NAME': 'od',
|
601 |
+
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
|
602 |
+
'SCORE_MODE': 'avg_loc_scores'
|
603 |
},
|
604 |
{
|
605 |
'TASK_NAME': 'ocr',
|
|
|
615 |
},
|
616 |
{
|
617 |
'TASK_NAME': 'description_with_bboxes',
|
618 |
+
'SCORE_MODE': 'avg_loc_scores'
|
619 |
},
|
620 |
{
|
621 |
'TASK_NAME': 'description_with_polygons',
|
|
|
656 |
filtered_tokens = tokenizer.convert_ids_to_tokens(
|
657 |
token_ids, skip_special_tokens=False)
|
658 |
assert len(filtered_tokens) == len(token_ids)
|
|
|
|
|
|
|
|
|
659 |
sub_texts = []
|
660 |
for token in filtered_tokens:
|
661 |
if token in self.all_special_tokens:
|
|
|
663 |
else:
|
664 |
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
665 |
sub_text = tokenizer.convert_tokens_to_string([token])
|
|
|
|
|
|
|
|
|
666 |
else:
|
667 |
raise ValueError(f'type {type(tokenizer)} not supported')
|
668 |
sub_texts.append(sub_text)
|
|
|
673 |
span = (len(text), len(text) + len(sub_text)) # [start index, end index).
|
674 |
text += sub_text
|
675 |
spans.append(span)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
return text, spans
|
677 |
|
678 |
def parse_od_from_text_and_spans(
|
|
|
707 |
return instances
|
708 |
|
709 |
def parse_ocr_from_text_and_spans(self,
|
710 |
+
text,
|
711 |
pattern,
|
712 |
image_size,
|
713 |
area_threshold=-1.0,
|
|
|
811 |
|
812 |
return instances
|
813 |
|
814 |
+
def parse_description_with_bboxes_from_text_and_spans(
|
815 |
+
self,
|
816 |
+
text,
|
817 |
+
spans=None,
|
818 |
+
scores=None,
|
819 |
+
score_mode=None,
|
820 |
+
pattern=None,
|
821 |
+
image_size=None,
|
822 |
+
allow_empty_phrase=False
|
823 |
+
):
|
824 |
+
def find_matched_token_indices(cur_span, token_spans):
|
825 |
+
inds = []
|
826 |
+
for i, token_span in enumerate(token_spans):
|
827 |
+
if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
|
828 |
+
inds.append(i)
|
829 |
+
return inds
|
830 |
+
|
831 |
+
cur_span = 0
|
832 |
+
if text.startswith('<s>'):
|
833 |
+
cur_span += 3
|
834 |
|
835 |
text = text.replace('<s>', '')
|
836 |
text = text.replace('</s>', '')
|
|
|
852 |
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
853 |
|
854 |
if phrase_text_strip == '' and not allow_empty_phrase:
|
855 |
+
cur_span += len(pharse_text)
|
856 |
continue
|
857 |
|
858 |
# parse phrase, get string
|
859 |
phrase = re.search(pattern, phrase_text_strip)
|
860 |
if phrase is None:
|
861 |
+
cur_span += len(pharse_text)
|
862 |
continue
|
863 |
|
864 |
+
phrase_span = phrase.span()
|
865 |
phrase = phrase.group()
|
866 |
# remove leading and trailing spaces
|
867 |
phrase = phrase.strip()
|
|
|
869 |
# parse bboxes by box_pattern
|
870 |
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
871 |
if len(bboxes_parsed) == 0:
|
872 |
+
cur_span += len(pharse_text)
|
873 |
continue
|
874 |
|
875 |
# a list of list
|
|
|
880 |
size=image_size
|
881 |
).tolist()
|
882 |
|
883 |
+
if score_mode == 'avg_loc_scores':
|
884 |
+
if spans is None or scores is None:
|
885 |
+
all_scores = None
|
886 |
+
else:
|
887 |
+
bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
|
888 |
+
all_scores = []
|
889 |
+
for _spans in bbox_end_spans:
|
890 |
+
token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
|
891 |
+
loc_scores = [scores[token_i] for token_i in token_inds]
|
892 |
+
score = sum(loc_scores) / len(loc_scores)
|
893 |
+
all_scores.append(score)
|
894 |
+
elif score_mode == 'avg_cat_name_scores':
|
895 |
+
if spans is None or scores is None:
|
896 |
+
all_scores = None
|
897 |
+
else:
|
898 |
+
cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
|
899 |
+
cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
|
900 |
+
score = sum(cat_name_scores) / len(cat_name_scores)
|
901 |
+
all_scores = [score] * len(bboxes)
|
902 |
+
elif score_mode is None:
|
903 |
+
all_scores = None
|
904 |
+
else:
|
905 |
+
raise ValueError('Unknown score mode: {}'.format(score_mode))
|
906 |
+
|
907 |
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
908 |
+
for _idx, _bboxes in enumerate(bboxes):
|
909 |
# Prepare instance.
|
910 |
instance = {}
|
911 |
instance['bbox'] = _bboxes
|
912 |
# exclude non-ascii characters
|
913 |
instance['cat_name'] = phrase
|
914 |
+
if all_scores is not None:
|
915 |
+
instance['score'] = math.exp(all_scores[_idx])
|
916 |
instances.append(instance)
|
917 |
+
|
918 |
+
cur_span += len(pharse_text)
|
919 |
|
920 |
return instances
|
921 |
|
|
|
1033 |
def __call__(
|
1034 |
self,
|
1035 |
text=None,
|
1036 |
+
sequence=None,
|
1037 |
+
transition_beam_score=None,
|
1038 |
image_size=None,
|
1039 |
parse_tasks=None,
|
1040 |
):
|
|
|
1052 |
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
1053 |
|
1054 |
# sequence or text should be provided
|
1055 |
+
assert sequence is not None or text is not None, 'sequence or text should be provided'
|
1056 |
+
assert sequence is None or text is None, 'only one of sequence and text should be provided'
|
1057 |
+
|
1058 |
+
if sequence is not None:
|
1059 |
+
sequence = sequence.tolist()[1:]
|
1060 |
+
text, spans = self.decode_with_spans(self.tokenizer, sequence)
|
1061 |
+
if transition_beam_score is not None:
|
1062 |
+
transition_beam_score = transition_beam_score.tolist()
|
1063 |
+
assert len(sequence) == len(transition_beam_score)
|
1064 |
+
else:
|
1065 |
+
spans = None
|
1066 |
+
transition_beam_score = None
|
1067 |
|
1068 |
parsed_dict = {
|
1069 |
'text': text
|
|
|
1074 |
continue
|
1075 |
|
1076 |
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
1077 |
+
score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
|
1078 |
|
1079 |
if task == 'ocr':
|
1080 |
instances = self.parse_ocr_from_text_and_spans(
|
|
|
1096 |
elif task == 'description_with_bboxes':
|
1097 |
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1098 |
text,
|
1099 |
+
spans=spans,
|
1100 |
+
scores=transition_beam_score,
|
1101 |
+
score_mode=score_mode,
|
1102 |
pattern=pattern,
|
1103 |
image_size=image_size,
|
1104 |
)
|