"""Streamlit front end for the Biomap land-use segmentation model.

The app loads a ``LitUnsupervisedSegmenter`` checkpoint (cached with
``st.cache_resource``), lets the user pick a location on a Folium map, and
displays the Plotly figure returned by the ``helper`` inference functions,
either as a time lapse between two years or for a single date.
"""
import logging
import sys

import folium
import hydra
import streamlit as st
import torch
from streamlit_folium import st_folium

# Wildcard import kept from the original, placed before the explicit local
# imports so it can never shadow them. NOTE(review): no name used in this
# file visibly comes from it — confirm it is still needed before removing.
from plot_functions import *  # noqa: F401,F403
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location

DEFAULT_LATITUDE = 48.81
DEFAULT_LONGITUDE = 2.98
DEFAULT_ZOOM = 5

# Years offered in the TimeLapse selectors are range(MIN_YEAR, MAX_YEAR),
# i.e. MAX_YEAR itself is excluded.
MIN_YEAR = 2018
MAX_YEAR = 2024

FOLIUM_WIDTH = 925
FOLIUM_HEIGHT = 300

LOG_FILE = "biomap.log"
MODEL_PATH = "biomap/checkpoint/model/model.pt"

st.set_page_config(layout="wide")


@st.cache_resource
def init_cfg(cfg_name):
    """Initialise Hydra on the ``configs`` directory and compose ``cfg_name``.

    Cached with ``st.cache_resource`` so ``hydra.initialize`` is not called
    again on every Streamlit rerun (Hydra raises if it is initialised twice
    without its global state being cleared).
    """
    hydra.initialize(config_path="configs", job_name="corine")
    return hydra.compose(config_name=cfg_name)


@st.cache_resource
def init_app(cfg_name) -> LitUnsupervisedSegmenter:
    """Configure logging and build the segmentation model on CPU.

    Args:
        cfg_name: Name of the Hydra config to compose from ``configs``.

    Returns:
        A ``LitUnsupervisedSegmenter`` moved to CPU, with the state dict
        loaded from ``MODEL_PATH``.
    """
    # logging.basicConfig ignores its ``encoding`` argument when ``handlers``
    # is supplied, so UTF-8 must be requested on the FileHandler itself.
    file_handler = logging.FileHandler(filename=LOG_FILE, encoding="utf-8")
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    cfg = init_cfg(cfg_name)
    logging.info("config : %s", cfg)

    n_classes = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(n_classes, cfg).cpu()
    logging.info("Model Initialiazed")

    saved_state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    return model


def _init_session_state():
    """Create the boolean flags that ``app`` reads before any widget runs."""
    for key in ("infered", "submit", "submit2"):
        if key not in st.session_state:
            st.session_state[key] = False


def _render_description():
    """Show the static explanation of the model and of the GBS scores."""
    st.markdown(
        "The segmentation model is an association of UNet and DinoV1 trained "
        "on the dataset CORINE. Land use is divided into 6 different classes. "
        "Each class is assigned a GBS score from 0 to 1:",
        unsafe_allow_html=True,
    )
    st.markdown(
        "Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | "
        "Wetland : 0.9 | Water : 0.9 | Natural green : 1",
        unsafe_allow_html=True,
    )
    st.markdown("The score is then averaged on the full image.", unsafe_allow_html=True)


def _run_pending_inference(model):
    """Run the prediction whose flag is set and show the latest figure.

    The inputs are read from ``st.session_state``, where the widget helpers
    store them. The last figure is kept under ``"previous_fig"`` so it stays
    visible on later reruns.
    """
    if st.session_state["submit"]:
        fig = inference_on_location(
            model,
            st.session_state["lat"],
            st.session_state["long"],
            st.session_state["start_date"],
            st.session_state["end_date"],
            st.session_state["segment_interval"],
        )
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    if st.session_state["submit2"]:
        fig = inference_on_location_and_month(
            model,
            st.session_state["lat_2"],
            st.session_state["long_2"],
            st.session_state["date_2"],
        )
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    if st.session_state["infered"]:
        st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)


def _render_map():
    """Draw the Folium map and return the (latitude, longitude) to prefill.

    Returns the last clicked point when there is one, else the defaults.
    """
    folium_map = folium.Map(
        location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM
    )
    folium_map.add_child(folium.LatLngPopup())
    map_state = st_folium(folium_map, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)

    clicked = map_state.get("last_clicked")
    if clicked:
        return clicked["lat"], clicked["lng"]
    return DEFAULT_LATITUDE, DEFAULT_LONGITUDE


def _render_timelapse_tab(latitude, longitude):
    """Render the TimeLapse inputs and store them in ``st.session_state``."""
    st.session_state["submit"] = st.button(
        "Predict TimeLapse", use_container_width=True, type="primary"
    )

    col_lat, col_long = st.columns(2)
    with col_lat:
        st.session_state["lat"] = st.text_input("latitude", value=latitude)
    with col_long:
        st.session_state["long"] = st.text_input("longitude", value=longitude)

    col_start, col_end = st.columns(2)
    years = list(range(MIN_YEAR, MAX_YEAR))
    with col_start:
        # The last year has no later year to end on; offering it as a start
        # year would leave "End date" with no options, and st.selectbox then
        # returns None, which would be passed on to the inference call.
        start_date = st.selectbox("Start date", years[:-1])
        st.session_state["start_date"] = start_date
    end_years = [year for year in years if year > start_date]
    with col_end:
        st.session_state["end_date"] = st.selectbox("End date", end_years)

    st.session_state["segment_interval"] = st.radio(
        "Interval of time between two segmentation",
        options=["month", "2months", "year"],
        horizontal=True,
    )


def _render_single_image_tab(latitude, longitude):
    """Render the Single Image inputs and store them in ``st.session_state``."""
    st.session_state["submit2"] = st.button(
        "Predict Single Image", use_container_width=True, type="primary"
    )

    col_lat, col_long = st.columns(2)
    with col_lat:
        st.session_state["lat_2"] = st.text_input("lat.", value=latitude)
    with col_long:
        st.session_state["long_2"] = st.text_input("long.", value=longitude)

    st.session_state["date_2"] = st.text_input(
        "date", "2021-01-01", placeholder="2021-01-01"
    )


def app(model):
    """Render one Streamlit run of the Biomap page.

    Args:
        model: The segmentation model passed to the ``helper`` inference
            functions (as returned by ``init_app``).

    NOTE(review): the button results are written to ``st.session_state``
    after ``_run_pending_inference`` has already checked them, so a click
    only triggers a prediction on a later script run. Confirm this delay is
    intended (it may rely on reruns triggered by the map component).
    """
    _init_session_state()
    _render_description()
    _run_pending_inference(model)

    col_map, col_controls = st.columns([0.5, 0.5])
    with col_map:
        latitude, longitude = _render_map()
    with col_controls:
        tab_timelapse, tab_single = st.tabs(["TimeLapse", "Single Image"])
        with tab_timelapse:
            _render_timelapse_tab(latitude, longitude)
        with tab_single:
            _render_single_image_tab(latitude, longitude)


if __name__ == "__main__":
    model = init_app("my_train_config.yml")
    app(model)