import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
from PIL import ImageColor


def _to_rgb(color):
    """Return an (R, G, B) tuple from a colour string (e.g. '#FF0000') or an RGB sequence."""
    if isinstance(color, str):
        return ImageColor.getcolor(color, 'RGB')
    return tuple(int(channel) for channel in color)


def annotate_image_prediction(image_path, yolo_boxes, class_dic, saving_folder, hex_class_colors=None, show=False,
                              true_count=False, saving_image_name=None, put_title=True, box_thickness=3,
                              font_scale=1, font_thickness=5):
    """
    Function to label individual images with YOLO predictions.

    Args:
        image_path (str or Path): path to the image to label
        yolo_boxes (iterable): YOLO predicted boxes; each item must expose ``xyxy[0]`` (x1, y1, x2, y2)
            and ``cls[0]`` (predicted class id). ``len(yolo_boxes)`` is used as the predicted count.
        class_dic (dict): dictionary with predicted class id as key and corresponding label as value
        saving_folder (str): folder where to save the annotated image (created if missing)
        hex_class_colors (dict, optional): colour per class label, either a colour string such as
            '#FF0000' or an (R, G, B) tuple. Labels missing from the dict are drawn in red.
            Defaults to None (every class drawn in red).
        show (bool, optional): If you want a window of the annotated image (boxes only, without the
            title) to pop up. Defaults to False.
        true_count (int or bool, optional): True total count to display alongside the predicted count,
            together with the delta. ``False``/``None`` disables it; ``0`` is a valid count.
            Defaults to False.
        saving_image_name (str, optional): Name of the annotated image to save.
            Defaults to None ('annotated_<original file name>').
        put_title (bool, optional): If you want a title to show in the plot. Defaults to True.
        box_thickness (int, optional): Thickness of the bounding boxes to plot. Defaults to 3.
        font_scale (int, optional): Font scale of the text of counts to be displayed. Defaults to 1.
        font_thickness (int, optional): Font thickness of the text of counts to be displayed. Defaults to 5.

    Returns:
        string: saving path of the annotated image, or None if the image does not exist or cannot be read
    """
    if not os.path.isfile(image_path):
        print(f'WARNING: {image_path} does not exists')
        return None

    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread returns None (rather than raising) for files it cannot decode.
        print(f'WARNING: {image_path} could not be read as an image')
        return None
    # OpenCV loads BGR; matplotlib (show/imsave) expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    dh, dw, _ = img.shape

    default_color = (255, 0, 0)
    if not hex_class_colors:
        hex_class_colors = {}
    color_map = {
        key: _to_rgb(hex_class_colors.get(label, default_color))
        for key, label in class_dic.items()
    }

    for yolo_box in yolo_boxes:
        x1, y1, x2, y2 = (int(v) for v in yolo_box.xyxy[0])
        c = int(yolo_box.cls[0])
        cv2.rectangle(img, (x1, y1), (x2, y2), color_map[c], box_thickness)

    if show:
        plt.imshow(img)
        plt.show()

    # Draw the title on a copy so the array handed to matplotlib above is not mutated afterwards.
    img_copy = img.copy()

    if put_title:
        predicted = len(yolo_boxes)
        # Explicit check so that a true count of 0 is still displayed.
        if true_count is not False and true_count is not None:
            title = f'Predicted count: {predicted}, true count: {true_count}, delta: {predicted - true_count}'
        else:
            title = f'Predicted count: {predicted}'
        cv2.putText(
            img=img_copy,
            text=title,
            org=(int(0.1 * dw), int(0.1 * dh)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=font_scale,
            thickness=font_thickness,
            color=(255, 251, 5),
        )

    if not saving_image_name:
        saving_image_name = f'annotated_{Path(image_path).name}'
    Path(saving_folder).mkdir(parents=True, exist_ok=True)
    full_saving_path = os.path.join(saving_folder, saving_image_name)
    plt.imsave(full_saving_path, img_copy)
    return full_saving_path