File size: 3,412 Bytes
b6ad7e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import cv2
import matplotlib.pyplot as plt
from PIL import ImageColor
from pathlib import Path
import os

def _class_color_map(class_dic, hex_class_colors):
    """Map every class index of ``class_dic`` to an RGB tuple usable by OpenCV.

    Color strings (e.g. '#FF0000') are parsed with PIL.ImageColor; tuples/lists
    are taken as-is. Labels without an entry are drawn in red.
    """
    hex_class_colors = hex_class_colors or {}
    color_map = {}
    for class_idx, class_name in class_dic.items():
        color = hex_class_colors.get(class_name, (255, 0, 0))
        if isinstance(color, str):
            color = ImageColor.getcolor(color, 'RGB')
        color_map[class_idx] = tuple(color)
    return color_map


def _prediction_title(predicted_count, true_count):
    """Build the title text; a ``true_count`` of False/None means "not provided"."""
    # bool is a subclass of int, so exclude it explicitly: the default False means
    # "no true count", while a genuine count of 0 must still be displayed.
    if true_count is None or isinstance(true_count, bool):
        return f'Predicted count: {predicted_count}'
    return (
        f'Predicted count: {predicted_count}, true count: {true_count}, '
        f'delta: {predicted_count - true_count}'
    )


def annotate_image_prediction(image_path, yolo_boxes, class_dic, saving_folder, hex_class_colors=None, show=False, true_count=False, saving_image_name=None, put_title=True, box_thickness=3, font_scale=1, font_thickness=5):
    """
    Function to label individual images with YOLO predictions.

    Draws one rectangle per predicted box, optionally shows the boxed image,
    optionally writes a count title, and saves the result with matplotlib.

    Args:
        image_path (str or os.PathLike): path to the image to label
        yolo_boxes (sized iterable): YOLO predicted boxes; each item must expose
            ``xyxy[0]`` (x1, y1, x2, y2 pixel coordinates) and ``cls[0]`` (class index).
            ``len(yolo_boxes)`` is used as the predicted count.
        class_dic (dict): dictionary with predicted class index as key and corresponding label as value
        saving_folder (str): folder where to save the annotated image (created if missing)
        hex_class_colors (dict, optional): class label -> color. Strings are parsed by
            PIL.ImageColor (e.g. '#FF0000'); (R, G, B) tuples are used as-is. Labels
            missing from the dict are drawn in red. Defaults to None (all red).
        show (bool, optional): If you want a window of the annotated image to pop up
            (shown before the title is added). Defaults to False.
        true_count (int, optional): True total count to display next to the predicted
            count, with their delta. 0 is a valid count; False/None disable it. Defaults to False.
        saving_image_name (str, optional): Name of the annotated image to save.
            Defaults to None ('annotated_<original file name>').
        put_title (bool, optional): If you want a title to show in the image. Defaults to True.
        box_thickness (int, optional): Thickness of the bounding boxes to plot. Defaults to 3.
        font_scale (int, optional): Font scale of the text of counts to be displayed. Defaults to 1.
        font_thickness (int, optional): Font thickness of the text of counts to be displayed. Defaults to 5.

    Returns:
        str or None: saving path of the annotated image, or None if the image
        does not exist or could not be decoded.
    """
    if not os.path.isfile(image_path):
        print(f'WARNING: {image_path} does not exists')
        return None

    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread does not raise on unreadable/corrupt files; it returns None.
        print(f'WARNING: {image_path} could not be read as an image')
        return None
    # OpenCV loads BGR; convert so matplotlib (imshow/imsave) sees true colors.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    dh, dw = img.shape[:2]

    color_map = _class_color_map(class_dic, hex_class_colors)
    for yolo_box in yolo_boxes:
        x1, y1, x2, y2 = (int(coord) for coord in yolo_box.xyxy[0])
        class_idx = int(yolo_box.cls[0])
        box_color = color_map.get(class_idx, (255, 0, 0))
        cv2.rectangle(img, (x1, y1), (x2, y2), box_color, box_thickness)

    if show:
        plt.imshow(img)
        plt.show()

    # The title is drawn on a copy so the displayed image above stays title-free.
    img_copy = img.copy()
    if put_title:
        cv2.putText(
            img=img_copy,
            text=_prediction_title(len(yolo_boxes), true_count),
            org=(int(0.1 * dw), int(0.1 * dh)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=font_scale,
            thickness=font_thickness,
            color=(255, 251, 5),
        )

    if not saving_image_name:
        saving_image_name = f'annotated_{Path(image_path).name}'
    Path(saving_folder).mkdir(parents=True, exist_ok=True)
    full_saving_path = os.path.join(saving_folder, saving_image_name)
    plt.imsave(full_saving_path, img_copy)
    return full_saving_path