File size: 3,532 Bytes
3a0062c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import pytorch_lightning as pl

from . import config
from .utils import (
    check_class_accuracy,
    get_evaluation_bboxes,
    mean_average_precision,
    plot_couple_examples,
)

class PlotTestExamplesCallback(pl.Callback):
    """Periodically plot model predictions on test-set examples.

    At the end of every ``every_n_epochs``-th training epoch (epochs counted
    from 1), calls ``plot_couple_examples`` with the module and a freshly
    built test dataloader from the trainer's datamodule.

    Args:
        every_n_epochs: Plotting interval in epochs; must be >= 1.
        thresh: Forwarded to ``plot_couple_examples(thresh=...)``
            (presumably the confidence threshold for drawn boxes).
        iou_thresh: Forwarded to ``plot_couple_examples(iou_thresh=...)``
            (presumably the NMS IoU threshold).

    Raises:
        ValueError: If ``every_n_epochs`` is less than 1.
    """

    def __init__(
        self,
        every_n_epochs: int = 1,
        thresh: float = 0.6,
        iou_thresh: float = 0.5,
    ) -> None:
        super().__init__()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError after the first full training epoch.
        if every_n_epochs < 1:
            raise ValueError(
                f"every_n_epochs must be a positive integer, got {every_n_epochs!r}"
            )
        self.every_n_epochs = every_n_epochs
        self.thresh = thresh
        self.iou_thresh = iou_thresh

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        # current_epoch is 0-based; +1 makes the interval count whole epochs.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        plot_couple_examples(
            model=pl_module,
            loader=trainer.datamodule.test_dataloader(),
            thresh=self.thresh,
            iou_thresh=self.iou_thresh,
            # Assumes the module exposes grid-scaled anchors as
            # `scaled_anchors` — TODO confirm against the LightningModule.
            anchors=pl_module.scaled_anchors,
        )


class CheckClassAccuracyCallback(pl.Callback):
    """Periodically compute and log class/objectness accuracies.

    At the end of training epochs (counted from 1) whose index is a multiple
    of ``train_every_n_epochs`` the train split is evaluated. Likewise, a
    multiple of ``test_every_n_epochs`` triggers evaluation of the test
    split. Each evaluation calls ``check_class_accuracy`` with
    ``config.CONF_THRESHOLD``. It then logs the three returned values as
    ``{split}_class_acc``, ``{split}_no_obj_acc`` and ``{split}_obj_acc``.

    Args:
        train_every_n_epochs: Train-split evaluation interval; must be >= 1.
        test_every_n_epochs: Test-split evaluation interval; must be >= 1.

    Raises:
        ValueError: If either interval is less than 1.
    """

    def __init__(
        self, train_every_n_epochs: int = 1, test_every_n_epochs: int = 3
    ) -> None:
        super().__init__()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError after the first full training epoch.
        for name, value in (
            ("train_every_n_epochs", train_every_n_epochs),
            ("test_every_n_epochs", test_every_n_epochs),
        ):
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        self.train_every_n_epochs = train_every_n_epochs
        self.test_every_n_epochs = test_every_n_epochs

    @staticmethod
    def _evaluate_split(
        pl_module: pl.LightningModule, split: str, make_loader
    ) -> None:
        """Print a banner, compute accuracies on one split and log them.

        ``make_loader`` is a zero-argument callable (e.g. the datamodule's
        ``train_dataloader``). It is invoked only after the banner is printed,
        so the loader is built lazily.
        """
        print(f"+++ {split.upper()} ACCURACIES")
        class_acc, no_obj_acc, obj_acc = check_class_accuracy(
            model=pl_module,
            loader=make_loader(),
            threshold=config.CONF_THRESHOLD,
        )
        pl_module.log_dict(
            {
                f"{split}_class_acc": class_acc,
                f"{split}_no_obj_acc": no_obj_acc,
                f"{split}_obj_acc": obj_acc,
            },
            logger=True,
        )

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        # current_epoch is 0-based; +1 makes the intervals count whole epochs.
        epoch = trainer.current_epoch + 1
        if epoch % self.train_every_n_epochs == 0:
            self._evaluate_split(
                pl_module, "train", trainer.datamodule.train_dataloader
            )
        if epoch % self.test_every_n_epochs == 0:
            self._evaluate_split(
                pl_module, "test", trainer.datamodule.test_dataloader
            )


class MAPCallback(pl.Callback):
    """Periodically compute and log test-set mean average precision (mAP).

    Runs at the end of every ``every_n_epochs``-th training epoch (epochs
    counted from 1). It collects predicted and ground-truth boxes over the
    datamodule's test dataloader with ``get_evaluation_bboxes``. It then
    reduces them with ``mean_average_precision`` (midpoint box format, IoU
    threshold ``config.MAP_IOU_THRESH``). Finally it prints the result and
    logs it under ``"MAP"``.

    Args:
        every_n_epochs: Evaluation interval in epochs; must be >= 1.

    Raises:
        ValueError: If ``every_n_epochs`` is less than 1.
    """

    def __init__(self, every_n_epochs: int = 3) -> None:
        super().__init__()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError after the first full training epoch.
        if every_n_epochs < 1:
            raise ValueError(
                f"every_n_epochs must be a positive integer, got {every_n_epochs!r}"
            )
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        # current_epoch is 0-based; +1 makes the interval count whole epochs.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return

        # Remember the current mode so it can be restored afterwards; the
        # evaluation helper may leave the module in eval mode.
        was_training = pl_module.training

        pred_boxes, true_boxes = get_evaluation_bboxes(
            loader=trainer.datamodule.test_dataloader(),
            model=pl_module,
            iou_threshold=config.NMS_IOU_THRESH,
            # Unscaled anchors (unlike the plotting callback's scaled_anchors);
            # presumably get_evaluation_bboxes scales them per grid — confirm.
            anchors=config.ANCHORS,
            threshold=config.CONF_THRESHOLD,
            # Use the device Lightning actually placed the module on; a fixed
            # config.DEVICE can disagree with the trainer's accelerator or
            # device index and cause a tensor device-mismatch error.
            device=pl_module.device,
        )

        map_val = mean_average_precision(
            pred_boxes=pred_boxes,
            true_boxes=true_boxes,
            iou_threshold=config.MAP_IOU_THRESH,
            box_format="midpoint",
            num_classes=config.NUM_CLASSES,
        ).item()
        print("+++ MAP: ", map_val)
        pl_module.log(
            "MAP",
            map_val,
            logger=True,
        )
        # Equivalent to the previous unconditional train() during training,
        # but does not force train mode if the module was in eval mode.
        pl_module.train(was_training)