Spaces:
Runtime error
Runtime error
import torch.multiprocessing | |
import torchvision.transforms as T | |
import numpy as np | |
from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image | |
from utils_gee import get_image | |
from dateutil.relativedelta import relativedelta | |
from model import LitUnsupervisedSegmenter | |
import datetime | |
import matplotlib as mpl | |
from joblib import Parallel, cpu_count, delayed | |
import logging | |
from inference import inference | |
import streamlit as st | |
import cv2 | |
def inference_on_location(model, longitude=2.98, latitude=48.81, start_date=2020, end_date=2022, how="year"):
    """Run the segmentation model on one location over a range of years and plot the results.

    Interval boundaries start on January 1st of ``start_date`` and advance by
    the step selected with ``how`` until the first boundary that is on or
    after January 1st of ``end_date``. The last boundary can therefore fall
    after ``end_date``. One image is fetched per pair of consecutive
    boundaries. The stacked images go to ``inference`` together with
    ``model``, and a biodiversity score is computed from each output's
    ``"linear_preds"``.

    Args:
        model: the segmentation model forwarded to ``inference``.
        longitude (float): the longitude of the landscape.
        latitude (float): the latitude of the landscape.
        start_date (int | str): the start year (anything ``int()`` accepts).
        end_date (int | str): the end year; must be strictly greater than
            ``start_date``.
        how (str): step between two boundaries: ``"month"`` (1 month),
            ``"2months"`` (2 months) or ``"year"`` (11 months, see note below).

    Returns:
        The figure returned by ``plot_imgs_labels`` for the boundary dates,
        the original images, the labeled images and the biodiversity scores.

    Raises:
        ValueError: if ``how`` is not one of the supported intervals.
        AssertionError: if ``end_date`` is not strictly greater than
            ``start_date``.
    """
    logging.info("Running Inference on location")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    logging.info(f"start date : {start_date} & end_date : {end_date}")
    logging.info(f"Prediction on intervale : {how}")

    if how == "month":
        delta_month = 1
    elif how == "2months":
        delta_month = 2
    elif how == "year":
        # NOTE(review): a "year" step of 11 months (not 12) makes the yearly
        # dates drift one month earlier each year - confirm this is intended.
        delta_month = 11
    else:
        raise ValueError("Wrong interval")

    # Convert once so that string years (as documented) work both in the
    # check and in the datetime construction below.
    start_year = int(start_date)
    end_year = int(end_date)
    assert end_year > start_year, "end date must be stricly higher than start date"
    location = [float(latitude), float(longitude)]

    # Build the interval boundaries; the loop stops at the first boundary
    # that is not before January 1st of end_year.
    end_boundary = datetime.datetime(end_year, 1, 1, 0, 0, 0)
    dates = [datetime.datetime(start_year, 1, 1, 0, 0, 0)]
    while dates[-1] < end_boundary:
        dates.append(dates[-1] + relativedelta(months=delta_month))
    dates = [d.strftime("%Y-%m-%d") for d in dates]

    # Fetch one image per consecutive (d1, d2) pair, one thread-backed job per CPU.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    outputs = inference(np.array(all_image), model)

    logging.info("Calculating Biodiversity Scores...")
    # .cpu() is a no-op for CPU tensors and lets .numpy() succeed if the
    # predictions live on an accelerator.
    scores, scores_details = map(list, zip(*[
        compute_biodiv_score(output["linear_preds"].detach().cpu().numpy())
        for output in outputs
    ]))
    logging.info(f"Calculated Biodiversity Score : {scores}")

    imgs, _labels, labeled_imgs = map(list, zip(*[transform_to_pil(output) for output in outputs]))
    images = [np.asarray(img) for img in imgs]
    labeled_imgs = [np.asarray(img) for img in labeled_imgs]
    return plot_imgs_labels(dates, images, labeled_imgs, scores_details, scores)
def inference_on_location_and_month(model, longitude=2.98, latitude=48.81, start_date='2020-03-20'):
    """Run the segmentation model on one location for a one-month window and plot the result.

    The window starts at ``start_date`` and ends exactly one calendar month
    later. The single fetched image goes to ``inference`` together with
    ``model``, and a biodiversity score is computed from its
    ``"linear_preds"``.

    Args:
        model: the segmentation model forwarded to ``inference``.
        longitude (float): the longitude of the landscape.
        latitude (float): the latitude of the landscape.
        start_date (str): first day of the window, formatted ``%Y-%m-%d``.

    Returns:
        The figure returned by ``plot_image`` for the start date, the
        original image, the labeled image and the biodiversity score.

    Raises:
        ValueError: if ``start_date`` does not match ``%Y-%m-%d``.
    """
    logging.info("Running Inference on location and month")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    location = [float(latitude), float(longitude)]

    # The window closes one calendar month after start_date.
    window_start = datetime.datetime.strptime(start_date, "%Y-%m-%d")
    end_date = (window_start + relativedelta(months=1)).strftime("%Y-%m-%d")

    img_test = get_image(location, start_date, end_date)
    output = inference(np.array([img_test]), model)[0]

    logging.info("Calculating Biodiversity Score...")
    # .cpu() is a no-op for CPU tensors and lets .numpy() succeed if the
    # predictions live on an accelerator.
    score, score_details = compute_biodiv_score(output["linear_preds"].detach().cpu().numpy())
    logging.info(f"Calculated Biodiversity Score : {score}")

    img, _label, labeled_img = transform_to_pil(output)
    return plot_image([start_date], [np.asarray(img)], [np.asarray(labeled_img)], [score_details], [score])
if __name__ == "__main__":
    import sys

    import hydra

    # Log both to biomap.log and to stdout. The UTF-8 encoding has to be set
    # on the FileHandler itself: logging.basicConfig only uses ``encoding``
    # when it creates the file handler from ``filename``, and ignores it when
    # explicit ``handlers`` are passed.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Load the model on CPU from the saved state dict.
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    logging.info("Model Initialiazed")
    model_path = "biomap/checkpoint/model/model.pt"
    # torch.load unpickles the file: only point this at trusted checkpoints.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    logging.info("Model Loaded")

    # inference_on_location_and_month(model)
    inference_on_location(model)