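# Residue from the Hugging Face Hub file page this script was copied from,
# kept as comments so the module is valid Python:
# nathbotbol's picture
# Upload folder using huggingface_hub
# d7aea57 verified
# raw
# history blame contribute delete
# No virus
# 2.97 kB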
import json
import os.path
from PIL import Image
import tqdm
from sacred import Experiment
from engie_pipeline.pipeline import run, pipeline, set_up_pipeline, pipeline_experiment
# Sacred experiment used to evaluate the detection pipeline; the
# pipeline_experiment ingredient is attached so its configuration is part of
# this experiment.
eval_experiment = Experiment('eval', ingredients=[pipeline_experiment])


@eval_experiment.config
def config():
    # Default configuration consumed by evaluate(). This is a Sacred config
    # scope: each local variable assigned here becomes a config entry with the
    # same name, so these names must match evaluate()'s parameters.
    # Test-set annotation JSON; evaluate() reads its 'images' and
    # 'annotations' lists.
    # NOTE(review): machine-specific absolute path - presumably meant to be
    # overridden when launching the experiment; confirm before relying on it.
    label_path = "/Users/benoit/Projects/EngieTableauElectrique/data/dataset/test.json"
    # Folder joined with each image's 'filename'; it is a relative path, so it
    # is resolved against the current working directory.
    image_folder = "data/test_data"
    # Detections are kept only when score >= this value (inclusive).
    conformity_threshold = 0.9
def intersection_over_union(boxA, boxB):
    """Return the intersection-over-union (Jaccard index) of two boxes.

    Both boxes are ``(x_min, y_min, x_max, y_max)`` corner coordinates.
    Coordinates are treated as continuous, so a box's area is
    ``(x_max - x_min) * (y_max - y_min)``. The same convention is used for the
    intersection, which means identical boxes score exactly 1.0. (Adding
    ``+ 1`` to the areas but not to the intersection would cap that score well
    below 1.)

    Args:
        boxA: First box. Any indexable of at least four numbers works: a list,
            a tuple, a NumPy array, or a tensor row whose elements convert with
            ``float``.
        boxB: Second box, in the same format as ``boxA``.

    Returns:
        The IoU as a Python ``float`` in ``[0.0, 1.0]``. Returns ``0.0`` when
        the union has no area (for example, both boxes are degenerate) instead
        of dividing by zero.
    """
    # Convert to plain floats so array/tensor inputs yield a plain float result.
    a_x1, a_y1, a_x2, a_y2 = (float(boxA[i]) for i in range(4))
    b_x1, b_y1, b_x2, b_y2 = (float(boxB[i]) for i in range(4))

    # Overlap extents; clamped at 0 so disjoint or merely touching boxes give
    # zero intersection.
    inter_w = max(0.0, min(a_x2, b_x2) - max(a_x1, b_x1))
    inter_h = max(0.0, min(a_y2, b_y2) - max(a_y1, b_y1))
    inter_area = inter_w * inter_h

    # Clamp so an inverted box (max < min) counts as empty rather than
    # contributing a negative area.
    area_a = max(0.0, a_x2 - a_x1) * max(0.0, a_y2 - a_y1)
    area_b = max(0.0, b_x2 - b_x1) * max(0.0, b_y2 - b_y1)

    union_area = area_a + area_b - inter_area
    if union_area <= 0.0:
        return 0.0
    return inter_area / union_area
# NOTE: these helpers must stay above the automain-decorated evaluate():
# Sacred's automain runs the experiment as soon as the decorator executes when
# this file is run as a script, so anything defined after it would not exist.
def _group_annotations_by_image(annotations, category_id):
    """Return a dict mapping image id to that image's annotations of ``category_id``.

    The dict is built once, so each image's ground truth becomes a dict lookup
    instead of a scan over every annotation. The original annotation order is
    preserved within each image.
    """
    grouped = {}
    for annotation in annotations:
        if annotation['category_id'] == category_id:
            grouped.setdefault(annotation['image_id'], []).append(annotation)
    return grouped


def _count_matches(predicted_boxes, annotations):
    """Return ``(tp, fn, fp)`` for one image under an any-overlap rule.

    A predicted box and an annotation match when their IoU is > 0, i.e. they
    share any positive area.

    - tp: annotations overlapped by at least one predicted box.
    - fn: annotations overlapped by no predicted box.
    - fp: predicted boxes overlapping no annotation.

    Matching is not one-to-one. One box covering several annotations makes
    each of them a TP, and extra boxes on an already-covered annotation are
    not counted as FP.
    """
    predicted = list(predicted_boxes)
    # Annotation bboxes are [x, y, w, h]; convert to corner form for IoU.
    ground_truth = []
    for annotation in annotations:
        x, y, w, h = annotation['bbox']
        ground_truth.append([x, y, x + w, y + h])

    tp = 0
    # box_matched[j] becomes True once predicted box j overlaps any annotation.
    box_matched = [False] * len(predicted)
    for gt_box in ground_truth:
        hits = [intersection_over_union(box, gt_box) > 0 for box in predicted]
        if any(hits):
            tp += 1
        box_matched = [seen or hit for seen, hit in zip(box_matched, hits)]

    fn = len(ground_truth) - tp
    fp = box_matched.count(False)
    return tp, fn, fp


@eval_experiment.automain
def evaluate(label_path, image_folder, conformity_threshold):
    """Run the pipeline on every test image and count detection matches.

    For each image listed in the label file:

    1. Run the pipeline on the image (with ``force_detr=True``).
    2. Keep only predictions labelled with the target category whose score is
       at least ``conformity_threshold``.
    3. Compare the kept predictions with that image's annotations of the same
       category, using the any-overlap rule in ``_count_matches``.

    Args:
        label_path: Path to a JSON file with an ``images`` list (entries have
            ``id`` and ``filename``) and an ``annotations`` list (entries have
            ``image_id``, ``category_id`` and an ``[x, y, w, h]`` ``bbox``).
        image_folder: Directory that each image ``filename`` is joined onto.
        conformity_threshold: Minimum score (inclusive) for a prediction to be
            kept.

    Returns:
        dict with integer counts under ``"tp"``, ``"fp"`` and ``"fn"``.
    """
    # NOTE(review): the meaning of category id 2 is not visible here -
    # presumably the single class of interest; confirm against the dataset.
    target_category = 2

    models = set_up_pipeline()
    confusion_matrix = {
        "tp": 0,
        "fp": 0,
        "fn": 0,
    }
    with open(label_path, encoding="utf-8") as file:
        test_data = json.load(file)

    annotations_by_image = _group_annotations_by_image(
        test_data['annotations'], target_category
    )

    for image in tqdm.tqdm(test_data['images']):
        image_path = os.path.join(image_folder, image['filename'])
        # Context manager closes the image file once the pipeline is done with it.
        with Image.open(image_path) as pil_image:
            _, boxes, labels, scores, _ = pipeline(image=pil_image, **models, force_detr=True)

        # Assumes pipeline boxes are (x_min, y_min, x_max, y_max) in the same
        # pixel space as the annotations - TODO confirm.
        keep = labels == target_category
        boxes = boxes[keep]
        scores = scores[keep]
        boxes = boxes[scores >= conformity_threshold]

        tp, fn, fp = _count_matches(boxes, annotations_by_image.get(image['id'], []))
        confusion_matrix['tp'] += tp
        confusion_matrix['fn'] += fn
        confusion_matrix['fp'] += fp

    return confusion_matrix