# -*- coding: utf-8 -*-
# WSI Model
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen


import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import yaml
from PIL import Image
from torchvision.transforms.functional import to_tensor


@dataclass
class WSI:
    """WSI object

    Args:
        name (str): WSI name, without the file suffix (e.g., slide1 instead of slide1.svs)
        patient (str): Patient name
        slide_path (Union[str, Path]): Full path to the WSI file.
        patched_slide_path (Union[str, Path], optional): Full path to preprocessed WSI files (patches). Defaults to None.
        embedding_name (Union[str, Path], optional): Name of the precomputed embedding file (without the .pt suffix). Defaults to None.
        label (Union[str, int, float, np.ndarray], optional): Label of the WSI. Defaults to None.
        logger (logging.Logger, optional): Logger for logging information. Defaults to None.
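
    Example:
        A minimal instantiation sketch; the paths and label below are
        hypothetical and only illustrate the expected arguments::

            wsi = WSI(
                name="slide1",
                patient="patient_1",
                slide_path="/data/wsi/slide1.svs",
                patched_slide_path="/data/wsi_patched/slide1",
                label=1,
            )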
    """

    name: str
    patient: str
    slide_path: Union[str, Path]
    patched_slide_path: Optional[Union[str, Path]] = None
    embedding_name: Optional[Union[str, Path]] = None
    label: Optional[Union[str, int, float, np.ndarray]] = None
    logger: Optional[logging.Logger] = None

    # attributes that are set after initialization, not via the constructor
    metadata: dict = field(init=False, repr=False)
    all_patch_metadata: dict = field(init=False, repr=False)  # patch name -> metadata entry
    patches_list: List[str] = field(init=False, repr=False)
    patch_transform: Optional[Callable] = field(init=False, repr=False)

    def __post_init__(self):
        """Post-initialization: resolve paths and load metadata if a patched slide path is given"""

        # convert string to path
        self.slide_path = Path(self.slide_path).resolve()
        if self.patched_slide_path is not None:
            self.patched_slide_path = Path(self.patched_slide_path).resolve()
            # load metadata
            self._get_metadata()
            self._get_wsi_patch_metadata()
            self.patch_transform = None  # set later via set_patch_transform, not a constructor parameter

        if self.logger is not None:
            self.logger.debug(self.__repr__())

    def _get_metadata(self) -> None:
        """Load the metadata.yaml file of the patched slide"""
        self.metadata_path = self.patched_slide_path / "metadata.yaml"
        with open(self.metadata_path.resolve(), "r") as metadata_yaml:
            try:
                self.metadata = yaml.safe_load(metadata_yaml)
            except yaml.YAMLError as exc:
                # fail loudly: without metadata the WSI object is unusable
                raise RuntimeError(
                    f"Could not parse metadata file {self.metadata_path}"
                ) from exc
        self.metadata["label_map_inverse"] = {
            v: k for k, v in self.metadata["label_map"].items()
        }
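
    # Illustrative metadata.yaml layout, inferred from the parsing above: only
    # the "label_map" key (class name -> integer id) is required by this class;
    # the concrete class names below are hypothetical:
    #
    #   label_map:
    #     tumor: 0
    #     stroma: 1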

    def _get_wsi_patch_metadata(self) -> None:
        """Load the patch_metadata.json file and build the patch list and the patch-metadata dict"""
        with open(self.patched_slide_path / "patch_metadata.json", "r") as json_file:
            metadata = json.load(json_file)
        # each list entry is a single-key dict: {patch_name: patch_metadata}
        self.all_patch_metadata = {
            name: meta for elem in metadata for name, meta in elem.items()
        }
        self.patches_list = list(self.all_patch_metadata)
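
    # Expected patch_metadata.json layout, inferred from the parsing above
    # (the patch names and relative metadata paths are illustrative; only the
    # "metadata_path" key is actually read by load_patch_metadata):
    #
    #   [
    #       {"wsi_1_0.png": {"metadata_path": "metadata/wsi_1_0.yaml"}},
    #       {"wsi_1_1.png": {"metadata_path": "metadata/wsi_1_1.yaml"}},
    #   ]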

    def load_patch_metadata(self, patch_name: str) -> dict:
        """Return the metadata of a patch with given name (including patch suffix, e.g., wsi_1_1.png)

        This function assumes that metadata path is a subpath of the patches dataset path

        Args:
            patch_name (str): Name of patch

        Returns:
            dict: metadata
        """
        patch_metadata_path = self.all_patch_metadata[patch_name]["metadata_path"]
        patch_metadata_path = self.patched_slide_path / patch_metadata_path

        # open
        with open(patch_metadata_path, "r") as metadata_yaml:
            patch_metadata = yaml.safe_load(metadata_yaml)
        patch_metadata["name"] = patch_name

        return patch_metadata

    def set_patch_transform(self, transform: Callable) -> None:
        """Set the transformation function to process a patch

        Args:
            transform (Callable): Transformation function
        """
        self.patch_transform = transform

    # patch processing
    def process_patch_image(
        self, patch_name: str, transform: Optional[Callable] = None
    ) -> Tuple[torch.Tensor, dict]:
        """Process one patch: load from disk and apply the transformation. If no transform is given, ToTensor is applied as a fallback so a tensor is always returned

        Args:
            patch_name (str): Name of the patch to load, including the patch suffix, e.g., wsi_1_1.png
            transform (Callable, optional): Optional patch transformation. Defaults to None.
        Returns:
            Tuple[torch.Tensor, dict]:

            * torch.Tensor: patch as tensor with shape (3, H, W)
            * dict: patch metadata as dictionary
        """
        patch = Image.open(self.patched_slide_path / "patches" / patch_name)
        if transform is not None:
            patch = transform(patch)
        else:
            # fallback so that get_patches can stack the results into one tensor
            patch = to_tensor(patch)

        metadata = self.load_patch_metadata(patch_name)
        return patch, metadata

    def get_number_patches(self) -> int:
        """Return the number of patches for this WSI

        Returns:
            int: number of patches
        """
        return len(self.patches_list)

    def get_patches(
        self, transform: Optional[Callable] = None
    ) -> Tuple[torch.Tensor, list]:
        """Get all patches for one WSI

        Args:
            transform (Callable, optional): Optional patch transformation. Defaults to None.

        Returns:
            Tuple[torch.Tensor, list]:

            * torch.Tensor: patched image with shape (num_patches, 3, H, W)
            * list: patch metadata dictionaries, one per patch

        """
        if self.logger is not None:
            self.logger.warning(f"Loading {self.get_number_patches()} patches!")
        patches = []
        metadata = []
        for patch in self.patches_list:
            transformed_patch, meta = self.process_patch_image(patch, transform)
            patches.append(transformed_patch)
            metadata.append(meta)
        patches = torch.stack(patches)

        return patches, metadata

    def load_embedding(self) -> torch.Tensor:
        """Load the WSI embedding from the subfolder patched_slide_path/embeddings/

        Raises:
            FileNotFoundError: If the embedding file cannot be found

        Returns:
            torch.Tensor: WSI embedding
        """
        embedding_path = (
            self.patched_slide_path / "embeddings" / f"{self.embedding_name}.pt"
        )
        if not embedding_path.is_file():
            raise FileNotFoundError(
                f"Embedding for WSI {self.slide_path} cannot be found at {embedding_path}"
            )
        return torch.load(embedding_path)
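

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the paths below
    # are hypothetical and assume a slide that has already been preprocessed
    # into the expected folder layout (metadata.yaml, patch_metadata.json,
    # patches/).
    from torchvision import transforms

    example_wsi = WSI(
        name="slide1",
        patient="patient_1",
        slide_path="/data/wsi/slide1.svs",
        patched_slide_path="/data/wsi_patched/slide1",
    )
    example_wsi.set_patch_transform(transforms.ToTensor())

    # load all patches as one (num_patches, 3, H, W) tensor plus their metadata
    patches, metadata = example_wsi.get_patches(
        transform=example_wsi.patch_transform
    )
    print(patches.shape, len(metadata))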