Biomap / biomap /helper.py
jeremyLE-Ekimetrics's picture
streamlit
9fcd62f
raw
history blame
5.78 kB
import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image
from utils_gee import get_image
from dateutil.relativedelta import relativedelta
from model import LitUnsupervisedSegmenter
import datetime
import matplotlib as mpl
from joblib import Parallel, cpu_count, delayed
import logging
from inference import inference
import streamlit as st
import cv2
@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location(model, longitude=2.98, latitude=48.81, start_date=2020, end_date=2022, how="year"):
    """Perform an inference on a location over consecutive time intervals.

    Interval boundaries start on January 1st of ``start_date`` and advance by
    the step selected by ``how`` until January 1st of ``end_date`` is reached
    or passed. One image is fetched per consecutive pair of boundaries with
    ``get_image`` and the whole batch is passed to ``inference``.

    Args:
        model (LitUnsupervisedSegmenter): the segmentation model.
        longitude (float): the longitude of the landscape.
        latitude (float): the latitude of the landscape.
        start_date (int | str): the first year of the inference.
        end_date (int | str): the year at which boundary generation stops;
            must be strictly greater than ``start_date``.
        how (str): interval step, one of ``"month"``, ``"2months"`` or
            ``"year"``.

    Returns:
        The figure built by ``plot_imgs_labels`` from the boundary dates, the
        original images, the labeled images and the biodiversity scores.

    Raises:
        ValueError: if ``how`` is not a supported interval.
        AssertionError: if ``end_date`` is not strictly greater than
            ``start_date``.
    """
    logging.info("Running Inference on location")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    logging.info(f"start date : {start_date} & end_date : {end_date}")
    logging.info(f"Prediction on intervale : {how}")
    if how == "month":
        delta_month = 1
    elif how == "2months":
        delta_month = 2
    elif how == "year":
        # NOTE(review): "year" steps by 11 months, not 12, so each yearly
        # interval starts one month earlier than the previous — confirm intent.
        delta_month = 11
    else:
        raise ValueError("Wrong interval")
    # Normalise both years once; previously only the checks used int(), so a
    # string/float start year passed the assert and then broke datetime().
    start_year = int(start_date)
    end_year = int(end_date)
    assert end_year > start_year, "end date must be stricly higher than start date"
    location = [float(latitude), float(longitude)]
    # Build the interval boundaries (the last one may overshoot end_year).
    stop = datetime.datetime(end_year, 1, 1, 0, 0, 0)
    boundaries = [datetime.datetime(start_year, 1, 1, 0, 0, 0)]
    while boundaries[-1] < stop:
        boundaries.append(boundaries[-1] + relativedelta(months=delta_month))
    dates = [d.strftime("%Y-%m-%d") for d in boundaries]
    # Extract one image per consecutive (d1, d2) interval from Earth Engine.
    # Threads are used, presumably because get_image is network-bound.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    outputs = inference(np.array(all_image), model)
    logging.info("Calculating Biodiversity Scores...")
    scores, scores_details = map(
        list,
        zip(*[compute_biodiv_score(output["linear_preds"].detach().numpy()) for output in outputs]),
    )
    logging.info(f"Calculated Biodiversity Score : {scores}")
    imgs, _labels, labeled_imgs = map(list, zip(*[transform_to_pil(output) for output in outputs]))
    images = [np.asarray(img) for img in imgs]
    labeled_imgs = [np.asarray(img) for img in labeled_imgs]
    fig = plot_imgs_labels(dates, images, labeled_imgs, scores_details, scores)
    return fig
@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location_and_month(model, longitude=2.98, latitude=48.81, start_date='2020-03-20'):
    """Perform an inference on a location over the month following a date.

    The image window is ``[start_date, start_date + 1 month]``; a single
    image is fetched with ``get_image`` and passed to ``inference``.

    Args:
        model (LitUnsupervisedSegmenter): the segmentation model.
        longitude (float): the longitude of the landscape.
        latitude (float): the latitude of the landscape.
        start_date (str | datetime.date): first day of the window, either as
            a ``"%Y-%m-%d"`` string or as a date/datetime object.

    Returns:
        The figure built by ``plot_image`` from the start date, the original
        image, the labeled image and the biodiversity score.
    """
    logging.info("Running Inference on location and month")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    location = [float(latitude), float(longitude)]
    # Accept date objects too (datetime.datetime is a date subclass) by
    # normalising them to the string format used downstream.
    if isinstance(start_date, datetime.date):
        start_date = start_date.strftime("%Y-%m-%d")
    # The window ends exactly one calendar month after start_date.
    window_end = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
    end_date = window_end.strftime("%Y-%m-%d")
    # Extract img numpy from earth engine and run the model on a batch of one.
    img_test = get_image(location, start_date, end_date)
    outputs = inference(np.array([img_test]), model)
    logging.info("Calculating Biodiversity Score...")
    score, score_details = compute_biodiv_score(outputs[0]["linear_preds"].detach().numpy())
    logging.info(f"Calculated Biodiversity Score : {score}")
    img, _label, labeled_img = transform_to_pil(outputs[0])
    fig = plot_image([start_date], [np.asarray(img)], [np.asarray(labeled_img)], [score_details], [score])
    return fig
if __name__ == "__main__":
    import sys

    import hydra

    # Log to both a UTF-8 file and stdout. The encoding must be set on the
    # FileHandler itself: logging.basicConfig ignores `encoding` whenever
    # explicit `handlers` are supplied.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Load the model on CPU.
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    logging.info("Model Initialiazed")
    model_path = "biomap/checkpoint/model/model.pt"
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    logging.info("Model Loaded")

    # Alternative single-month run: inference_on_location_and_month(model)
    inference_on_location(model)