# Biomap / biomap /utils copy.py
# (Residue from the hosting web page, kept as comments so the module parses:)
# jeremyLE-Ekimetrics's picture
# streamlit
# 9fcd62f
# raw
# history blame
# 26.6 kB
import collections
import os
from os.path import join
import io
import matplotlib.pyplot as plt
import numpy as np
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
import datetime
from dateutil.relativedelta import relativedelta
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T
from torch.utils.tensorboard.summary import hparams
import matplotlib as mpl
from PIL import Image
import matplotlib as mpl
import torch.multiprocessing
import torchvision.transforms as T
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots
import os
# Allow duplicate Intel OpenMP runtimes in one process (KMP_DUPLICATE_LIB_OK).
# NOTE(review): presumably a workaround for an OpenMP "duplicate library"
# start-up error when several packages ship their own runtime — confirm it is
# still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Display colour per class, in the same order as `class_names`.
# Entry i (0..5) is the colour of class id i + 1 in `mapping_class`; the last
# entry is Background (id 0). transform_to_pil indexes the colour map with
# `class_id - 1`, so Background reaches the last entry through index -1 —
# keep this ordering in sync with `mapping_class`.
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
# Class name -> integer class id stored in label maps.
mapping_class = {
"Buildings": 1,
"Cultivation": 2,
"Natural green": 3,
"Wetland": 4,
"Water": 5,
"Infrastructure": 6,
"Background": 0,
}
# Class name -> per-pixel weight; compute_biodiv_score averages these weights
# over all pixels of a label map to obtain the biodiversity score.
score_attribution = {
"Buildings" : 0.,
"Cultivation": 0.3,
"Natural green": 1.,
"Wetland": 0.9,
"Water": 0.9,
"Infrastructure": 0.,
"Background": 0.
}
# Colour-map boundaries [1, 2, ..., 8]: with BoundaryNorm a value in [i, i + 1)
# gets colors[i - 1], so ids 1..6 get their class colour and 7 gets lightgrey.
# NOTE(review): Background's id 0 lies below the first boundary, so `norm`
# would treat it as out of range; `norm` is not used in the code visible here.
bounds = list(np.arange(len(mapping_class.keys()) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
def compute_biodiv_score(class_image):
    """Compute the biodiversity score of a label map.

    Every pixel's class id is replaced by the weight of its class in
    ``score_attribution`` (looked up through ``mapping_class``), and the score
    is the mean of those weights over all pixels.

    Args:
        class_image (np.ndarray): array of class ids (values of
            ``mapping_class``). Values are cast to ``int`` first, so
            non-integer ids are truncated.

    Returns:
        tuple: ``(score, score_details)``. ``score`` is the mean per-pixel
        weight (NaN for an empty image). ``score_details`` maps every class
        name in ``mapping_class`` -- including "Background", which
        ``plot_imgs_labels`` indexes -- to its pixel count.
    """
    class_ids = np.asarray(class_image).astype(int)
    # NOTE(review): pixels whose id is not in `mapping_class` contribute their
    # raw id value to the mean; confirm whether they should score 0 instead.
    score_matrix = class_ids.astype(float)
    for name, class_id in mapping_class.items():
        # Compare against the untouched ids, never against the partially
        # substituted matrix, so a weight that happens to equal another class
        # id cannot be re-mapped by a later iteration.
        score_matrix[class_ids == class_id] = score_attribution[name]
    score = score_matrix.mean()
    score_details = {
        name: int(np.sum(class_ids == class_id))
        for name, class_id in mapping_class.items()
    }
    return score, score_details
def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build an animated plotly figure with one frame per acquisition date.

    The four panels show, for the current frame: the image, its labelled
    version, a pie chart of the class pixel counts, and the biodiversity
    scores from the first frame up to the current one.

    Args:
        months: date strings formatted "%Y-%m-%d", one per frame; also the x
            values of the score curve.
        imgs: images, one per frame. They are stacked with ``np.array``, so
            they must share a shape; ``imgs[0].shape[:2]`` sets the range of
            the first panel.
        imgs_label: labelled images, one per frame, same constraint as imgs.
        nb_values: one dict per frame mapping every key of ``mapping_class``
            to a pixel count (e.g. the second item returned by
            ``compute_biodiv_score``).
        scores: biodiversity score of each frame. ``len(scores)`` is the
            number of frames, so the other per-frame sequences must be at
            least that long.

    Returns:
        plotly.graph_objects.Figure: the animated figure, with frames named
        "0", "1", ...
    """
    n_frames = len(scores)
    frame_names = [str(k) for k in range(n_frames)]

    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
        subplot_titles=("Localisation visualization", "Labeled visualisation",
                        "Segments repartition", "Biodiversity scores"),
    )
    # px.imshow with animation_frame=0 makes one frame per stacked image; the
    # image trace of each frame is reused as that frame's panel data below.
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
    pie_charts = [
        go.Pie(labels=class_names,
               values=[nb_values[k][key] for key in mapping_class.keys()],
               marker_colors=colors,
               name="Segment repartition",
               textposition='inside',
               texttemplate="%{percent:.0%}",
               textfont_size=14,
               )
        for k in range(n_frames)
    ]
    # Frame i shows the score curve from the first date up to date i.
    scatters = [
        go.Scatter(
            x=months[:i + 1],
            y=scores[:i + 1],
            mode="lines+markers+text",
            marker_color="black",
            text=[f"{score:.4f}" for score in scores[:i + 1]],
            textposition="top center",
        )
        for i in range(n_frames)
    ]
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(pie_charts[0], row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)

    # Date axis spans one month before the first date to one month after the last.
    start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1)
    end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1)
    interval = [start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")]
    # NOTE(review): plotly presumably treats "xaxis1"/"yaxis1" as the same axes
    # as "xaxis"/"yaxis" (first image panel); the second image panel
    # (xaxis2/yaxis2) is left at its defaults — confirm that is intended.
    # The pie subplot has no axes, so the score panel (column 4) uses axis 3.
    fig.update_layout({
        "xaxis": {
            "autorange": True,
            'showgrid': False,
            'zeroline': False,  # thick line at x=0
            'visible': False,  # numbers below
        },
        "yaxis": {
            "autorange": True,
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "xaxis1": {
            "range": [0, imgs[0].shape[1]],
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "yaxis1": {
            # Reversed so that image row 0 is drawn at the top.
            "range": [imgs[0].shape[0], 0],
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        },
        "xaxis3": {
            "dtick": "M3",
            "range": interval,
        },
        "yaxis3": {
            'range': [min(scores) * 0.9, max(scores) * 1.1],
            'showgrid': False,
            'zeroline': False,
            'visible': True,
        },
    })

    frames = [
        dict(
            name=frame_names[k],
            data=[fig2["frames"][k]["data"][0],
                  fig3["frames"][k]["data"][0],
                  pie_charts[k],
                  scatters[k]],
            traces=[0, 1, 2, 3],
        )
        for k in range(n_frames)
    ]
    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[frame_names,
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0))])],
                        direction='left',
                        pad=dict(r=10, t=85),
                        showactive=True, x=0.1, y=0, xanchor='right', yanchor='top')]
    sliders = [{
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {
            'font': {'size': 16},
            'visible': True,
            'xanchor': 'right'},
        'transition': {
            'duration': 500.0,
            'easing': 'linear'},
        'pad': {'b': 10, 't': 50},
        'len': 0.9, 'x': 0.1, 'y': 0,
        # Each step animates to its own frame only; passing None here would
        # replay every frame instead of jumping to the selected one.
        'steps': [{'args': [[frame_names[k]],
                            {'frame': {'duration': 500.0, 'redraw': False},
                             'mode': 'immediate',
                             'transition': {'duration': 0}}],
                   'label': k, 'method': 'animate'}
                  for k in range(n_frames)],
    }]
    fig.update_layout(updatemenus=updatemenus, sliders=sliders)
    fig.update(frames=frames)
    return fig
def transform_to_pil(output, alpha=0.3):
    """Render a model output as three PIL images.

    Args:
        output: mapping with an ``'img'`` entry (normalized image tensor, as
            accepted by ``prep_for_plot``) and a ``'linear_preds'`` entry
            (per-pixel class ids).
        alpha: blend weight of the label overlay in ``labeled_img``.

    Returns:
        (img, label, labeled_img): the de-normalized image, the class ids
        painted with the module colour map, and the two blended in RGBA.
    """
    # prep_for_plot returns channels-last data; ToPILImage wants channels first.
    channels_first = torch.moveaxis(prep_for_plot(output['img']), -1, 0)
    img = T.ToPILImage()(channels_first)

    # One RGBA row (values in [0, 1]) per colour of the colour map.
    palette = np.array([np.array(cmap(idx)) for idx in range(cmap.N)])
    # Ids are shifted down by one to index the palette, so id 0 becomes -1,
    # i.e. the palette's last colour.
    class_index = np.array(output['linear_preds']) - 1
    rgba = (palette[class_index] * 255).astype(np.uint8)
    label = T.ToPILImage()(rgba)

    # Overlay the labels on the image with the requested opacity.
    labeled_img = Image.blend(img.convert("RGBA"), label.convert("RGBA"), alpha)
    return img, label, labeled_img
def prep_for_plot(img, rescale=True, resize=None):
    """Turn a normalized (C, H, W) image tensor into an (H, W, C) tensor for display.

    Args:
        img: normalized image tensor of shape (C, H, W); the final
            ``permute(1, 2, 0)`` requires exactly three dimensions.
        rescale: if True, min-max rescale the result to [0, 1]. A constant
            image becomes all zeros instead of NaN.
        resize: optional spatial size forwarded to ``F.interpolate``
            (bilinear).

    Returns:
        CPU tensor of shape (H, W, C).
    """
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
    else:
        img = img.unsqueeze(0)
    # Undo the normalization on the (C, H, W) image rather than on the batch of
    # one: `unnorm` pairs its per-channel statistics with the tensor's first
    # dimension, which for the batched tensor is the batch axis, so the
    # channel-0 statistics would otherwise be applied to every channel.
    plot_img = unnorm(img.squeeze(0)).cpu().permute(1, 2, 0)
    if rescale:
        lo = plot_img.min()
        span = plot_img.max() - lo
        plot_img = plot_img - lo
        # Skip the division for a constant image (span 0) to avoid 0/0 = NaN.
        if span != 0:
            plot_img = plot_img / span
    return plot_img
def add_plot(writer, name, step):
    """Log the current matplotlib figure to ``writer`` as an image, then close it.

    The figure is rendered to an in-memory JPEG at dpi=100, decoded with PIL,
    converted to a tensor and logged with ``writer.add_image(name, ..., step)``.
    """
    jpeg_buffer = io.BytesIO()
    plt.savefig(jpeg_buffer, format='jpeg', dpi=100)
    jpeg_buffer.seek(0)
    # The buffer must stay open until ToTensor has read the decoded pixels.
    as_tensor = T.ToTensor()(Image.open(jpeg_buffer))
    writer.add_image(name, as_tensor, step)
    # Clear and close the current figure once it has been logged.
    plt.clf()
    plt.close()
@torch.jit.script
def shuffle(x):
    """Return the entries of ``x`` along dim 0 in a random order."""
    order = torch.randperm(x.shape[0])
    return x[order]
def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    """Log hyper-parameters and metrics through ``writer``.

    The three summaries produced by ``hparams(hparam_dict, metric_dict)`` are
    written via ``writer.file_writer``, then every metric is also logged as a
    scalar at ``global_step``.
    """
    experiment, session_start, session_end = hparams(hparam_dict, metric_dict)
    for summary in (experiment, session_start, session_end):
        writer.file_writer.add_summary(summary)
    for metric_name, metric_value in metric_dict.items():
        writer.add_scalar(metric_name, metric_value, global_step)
@torch.jit.script
def resize(classes: torch.Tensor, size: int):
    """Bilinearly resize a 4-D tensor's spatial dims to (size, size)."""
    return F.interpolate(classes, size=(size, size), mode="bilinear", align_corners=False)
def one_hot_feats(labels, n_classes):
    """One-hot encode an (N, H, W) label tensor as (N, n_classes, H, W) float32."""
    encoded = F.one_hot(labels, n_classes)  # (N, H, W, n_classes)
    channels_first = encoded.permute(0, 3, 1, 2)
    return channels_first.to(torch.float32)
def _download_if_missing(url, model_file):
    """Download ``url`` to ``model_file`` unless that file already exists; return the path."""
    if not os.path.exists(model_file):
        wget.download(url, model_file)
    return model_file


def load_model(model_type, data_dir):
    """Load a backbone network with its final (classification) child removed.

    Args:
        model_type: one of "robust_resnet50", "densecl", "resnet50", "mocov2",
            "densenet121", "vgg11".
        data_dir: directory in which downloaded checkpoints are cached.

    Returns:
        nn.Sequential backbone, switched to eval mode and moved to CUDA.

    Raises:
        ValueError: if ``model_type`` is unknown.
        RuntimeError: if loading the MoCo v2 weights leaves keys other than
            ``fc.weight``/``fc.bias`` missing.
    """
    # NOTE(security): torch.load unpickles checkpoint files, which can run
    # arbitrary code -- only load checkpoints from trusted sources.
    # Checkpoints are mapped to the CPU while loading so they do not depend on
    # the device they were saved from; the model is moved to CUDA at the end.
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = _download_if_missing(
            "http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
            join(data_dir, 'imagenet_l2_3_0.pt'))
        model_weights = torch.load(model_file, map_location="cpu")
        # Keep the entries whose name contains 'model', renamed to the part
        # after the first 'model.'.
        model_weights_modified = {name.split('model.')[1]: value
                                  for name, value in model_weights['model'].items()
                                  if 'model' in name}
        model.load_state_dict(model_weights_modified)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = _download_if_missing(
            "https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
            join(data_dir, 'densecl_r50_coco_1600ep.pth'))
        model_weights = torch.load(model_file, map_location="cpu")
        # strict=False: keys that don't match the ResNet-50 layout are ignored.
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = _download_if_missing(
            "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
            "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar",
            join(data_dir, 'moco_v2_800ep_pretrain.pth.tar'))
        checkpoint = torch.load(model_file, map_location="cpu")
        state_dict = checkpoint['state_dict']
        # Rename MoCo keys: keep encoder_q up to (not including) its fc layer,
        # strip the "module.encoder_q." prefix, and drop every other key.
        for k in list(state_dict.keys()):
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        # Only the fc head may remain unloaded. Check explicitly: an `assert`
        # would be skipped under `python -O`.
        if set(msg.missing_keys) != {"fc.weight", "fc.bias"}:
            raise RuntimeError(
                "Unexpected missing keys when loading MoCo v2 weights: {}".format(
                    sorted(msg.missing_keys)))
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("No model: {} found".format(model_type))
    model.eval()
    model.cuda()
    return model
class UnNormalize:
    """Invert a per-channel normalization: each slice ``t`` becomes ``t * std + mean``.

    Slices are taken along the input's first dimension and paired in order
    with the entries of ``mean`` and ``std``; pairing stops at the shortest of
    the three. The input is cloned first and is not modified.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        restored = torch.clone(image)
        for plane, (mu, sigma) in zip(restored, zip(self.mean, self.std)):
            plane.mul_(sigma).add_(mu)
        return restored
# The standard ImageNet per-channel mean / std, shared by the forward
# normalization and its inverse (each transform receives its own list).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
normalize = T.Normalize(list(_IMAGENET_MEAN), list(_IMAGENET_STD))
unnorm = UnNormalize(list(_IMAGENET_MEAN), list(_IMAGENET_STD))
class ToTargetTensor:
    """Convert an array-like label image into an int64 tensor with a new leading axis of size 1."""

    def __call__(self, target):
        as_array = np.array(target)
        return torch.as_tensor(as_array, dtype=torch.int64)[None]
def prep_args():
    """Rewrite ``sys.argv`` so ``--key value`` pairs become ``key=value``.

    The program name (first entry) is kept. Arguments that already contain
    "=" are kept verbatim, and ``--key value`` pairs are merged into
    ``key=value``.

    Raises:
        ValueError: for an argument that is neither of those forms, or for a
            trailing ``--key`` with no value. ``sys.argv`` is left unchanged
            when this is raised.
    """
    import sys

    # Parse a copy so that sys.argv is only replaced once parsing succeeded.
    remaining = list(sys.argv)
    new_args = [remaining.pop(0)]
    while remaining:
        arg = remaining.pop(0)
        if "=" in arg:
            new_args.append(arg)
        elif arg.startswith("--"):
            if not remaining:
                raise ValueError("Missing value for argument {}".format(arg))
            new_args.append(arg[2:] + "=" + remaining.pop(0))
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args
def get_transform(res, is_label, crop_type):
    """Build the preprocessing pipeline for an image or for its label map.

    Args:
        res: target resolution (int).
        is_label: if True the pipeline ends with ``ToTargetTensor`` (int64
            target); otherwise with ``T.ToTensor()`` followed by ``normalize``.
        crop_type: "center", "random", or None. With None the input is
            resized straight to (res, res) and not cropped; otherwise
            ``T.Resize(res)`` is followed by a crop of size ``res``.

    Raises:
        ValueError: for any other ``crop_type``.
    """
    if crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    elif crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))

    # Nearest-neighbour resizing is used for images and labels alike.
    pipeline = [T.Resize(res, Image.NEAREST), cropper]
    if is_label:
        pipeline.append(ToTargetTensor())
    else:
        pipeline.extend([T.ToTensor(), normalize])
    return T.Compose(pipeline)
def _remove_axes(ax):
    """Hide tick labels and remove all ticks from a single matplotlib Axes."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])
def remove_axes(axes):
    """Apply ``_remove_axes`` to every Axes of a 1-D or 2-D array of Axes (row by row)."""
    if len(axes.shape) == 2:
        targets = (ax for row in axes for ax in row)
    else:
        targets = axes
    for ax in targets:
        _remove_axes(ax)
class UnsupervisedMetrics(Metric):
    """mIoU and pixel accuracy from a cluster-vs-class count matrix.

    ``stats[c, k]`` counts pixels predicted as cluster ``c`` whose target class
    is ``k``; it has ``n_classes + extra_clusters`` rows and ``n_classes``
    columns. With ``compute_hungarian`` the clusters are matched one-to-one to
    classes with ``scipy.optimize.linear_sum_assignment`` (maximizing matched
    counts) before scoring; with extra clusters, the unmatched ones are pooled
    into one extra row whose pixels all count as errors.

    Args:
        prefix: prepended to the returned metric names.
        n_classes: number of ground-truth classes.
        extra_clusters: number of predicted clusters beyond ``n_classes``.
        compute_hungarian: match clusters to classes before scoring.
        dist_sync_on_step: forwarded to ``Metric.__init__``.
    """

    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        # Every state registered with `add_state` is reduced across processes
        # with its `dist_reduce_fx`.
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate counts from predicted cluster ids and target class ids.

        Pixels whose target lies outside ``[0, n_classes)`` or whose prediction
        lies outside ``[0, n_classes + extra_clusters)`` are ignored.
        """
        n_clusters = self.n_classes + self.extra_clusters
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            # Predictions in the extra clusters are valid rows of `stats`; they
            # must be counted so compute() can pool the unmatched clusters.
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < n_clusters)
            actual = actual[mask]
            preds = preds[mask]
            # Flat index actual * n_clusters + pred -> (class, cluster) counts,
            # transposed to the (cluster, class) layout of `stats`.
            self.stats += torch.bincount(
                n_clusters * actual + preds,
                minlength=self.n_classes * n_clusters) \
                .reshape(self.n_classes, n_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        """Translate cluster ids to class ids using the assignment from compute().

        Clusters that the assignment left unmatched map to -1. compute() must
        have been called first, since it sets ``self.assignments``.
        """
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        # Put each matched cluster's class at that cluster's own index; every
        # other cluster keeps -1.
        cluster_to_class = np.full(self.n_classes + self.extra_clusters, -1, dtype=np.int64)
        cluster_to_class[np.asarray(self.assignments[0])] = np.asarray(self.assignments[1])
        return torch.tensor(cluster_to_class)[clusters]

    def compute(self):
        """Return ``{prefix + "mIoU": ..., prefix + "Accuracy": ...}`` in percent.

        Also stores ``self.assignments`` (used by ``map_clusters``) and
        ``self.histogram`` (the class-aligned count matrix that was scored).
        """
        n_clusters = self.n_classes + self.extra_clusters
        if self.compute_hungarian:
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            if self.extra_clusters == 0:
                # Reorder cluster rows so that row i is the cluster matched to class i.
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            elif self.extra_clusters > 0:
                # Solve on the transpose to get, for each class, its matched cluster.
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                matched = self.assignments_t[1]
                histogram = self.stats[matched, :]
                # Pool every unmatched cluster into one extra row. They are taken
                # from the same assignment as the rows above so that, under ties,
                # no cluster is counted twice or dropped.
                missing = sorted(set(range(n_clusters)) - set(matched.tolist()))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], dim=0)
                # Zero column for the pooled row: none of its pixels are correct.
                new_col = torch.zeros(self.n_classes + 1, 1, dtype=histogram.dtype, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], dim=1)
        else:
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats

        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp
        iou = tp / (tp + fp + fn)
        opc = torch.sum(tp) / torch.sum(self.histogram)

        # Classes that are absent and never predicted give NaN IoU and are skipped.
        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}
def flexible_collate(batch):
    r"""Collate a list of samples like PyTorch's ``default_collate``, tolerating ragged tensors.

    Tensors are stacked along a new first dimension. If stacking raises a
    RuntimeError (e.g. the tensors have different shapes), the list of tensors
    is returned unchanged. NumPy arrays/scalars, floats, ints, strings,
    mappings, namedtuples and sequences are handled recursively.

    Raises:
        TypeError: for element types (or NumPy dtypes) that cannot be collated.
        RuntimeError: if the samples of a sequence field differ in length.
    """
    import collections.abc

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
            if all(x.shape == elem.shape for x in batch):
                # Give `out` the stacked shape up front (same element count)
                # rather than letting torch.stack resize a 1-D output.
                out = out.resize_(len(batch), *elem.size())
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            # Ragged batch: hand back the individual tensors.
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # Arrays of strings or Python objects cannot become tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # NumPy scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # Check that all elements of the batch have the same length.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
if __name__ == "__main__":
    # Self-contained demo on synthetic data: four monthly frames of random RGB
    # images and random label maps, scored and plotted.
    rng = np.random.default_rng(0)
    months = ["2021-01-01", "2021-02-01", "2021-03-01", "2021-04-01"]
    imgs = [rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8) for _ in months]
    imgs_label = [rng.integers(0, len(mapping_class), size=(64, 64)) for _ in months]
    results = [compute_biodiv_score(label) for label in imgs_label]
    scores = [score for score, _ in results]
    nb_values = [details for _, details in results]
    fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
    fig.show()