"""Run the segmentation model on images and remap the predicted labels.

When executed as a script, extracts one Earth Engine image for a fixed
location and date range, loads a checkpoint and runs ``inference`` on it.
"""
import logging

import numpy as np
import torch
import torch.multiprocessing
import torchvision.transforms as T

from utils import transform_to_pil

logger = logging.getLogger(__name__)

# Predicted label that is merged into another one ("water to natural green").
WATER_LABEL = 5
NATURAL_GREEN_LABEL = 3

# Resize to a fixed 320x320 and normalize with the ImageNet channel mean/std.
preprocess = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((320, 320)),
        # T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def inference(images, model):
    """Segment a batch of images and return one prediction per image.

    Each image goes through ``preprocess``. That pipeline starts with
    ``T.ToPILImage``, so inputs are presumably tensors or numpy arrays;
    confirm this for each caller. The stacked batch is run through
    ``model.net`` and ``model.linear_probe``. The arg-max over dimension 1
    (presumably the class axis) is the predicted label map.

    Args:
        images: Iterable of images accepted by ``T.ToPILImage``.
        model: Object exposing ``net(x) -> (_, code)`` and
            ``linear_probe(x, code)``. Callers should put it in eval mode
            first. This function does not change the model's mode.

    Returns:
        A list with one dict per input image, in input order:

        - ``"img"``: the preprocessed (resized, normalized) CPU tensor.
        - ``"linear_preds"``: the CPU label map, with label 5 (water)
          remapped to 3 (natural green).

        An empty input yields ``[]``.
    """
    logger.info("Inference on Images")
    tensors = [preprocess(image) for image in images]
    if not tensors:
        # torch.stack raises on an empty list, and there is nothing to predict.
        return []
    x = torch.stack(tensors).cpu()

    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)
        # Merge water into natural green once for the whole batch,
        # rather than once per image.
        linear_pred = torch.where(
            linear_pred == WATER_LABEL, NATURAL_GREEN_LABEL, linear_pred
        )

    return [
        {"img": img.detach().cpu(), "linear_preds": pred.detach().cpu()}
        for img, pred in zip(x, linear_pred)
    ]


def main():
    """Demo: extract one Earth Engine image, load the model and predict."""
    # Imported here so that importing this module just for ``inference`` does
    # not require hydra, the model definition or the Earth Engine helpers.
    import os

    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img

    # NOTE(review): the names look swapped. (lat 48.81, lon 2.98) is a
    # plausible point near Paris. The list passed on is [2.98, 48.81]; confirm
    # the coordinate order extract_img expects before renaming anything.
    latitude = 2.98
    longitude = 48.81
    start_date = "2020-03-20"
    end_date = "2020-04-20"
    location = [float(latitude), float(longitude)]

    # Extract img numpy from earth engine and transform it to an image.
    img = extract_img(location, start_date, end_date)
    # max is the value from the numpy file that will be mapped to 255.
    # NOTE(review): preprocess starts with T.ToPILImage, which accepts tensors
    # or ndarrays. Confirm that transform_ee_img returns one of those.
    image = transform_ee_img(img, max=0.3)
    print("image loaded")

    # Initialize hydra with configs.
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model.
    model_path = os.path.join(os.path.dirname(__file__), "checkpoint/model/model.pt")
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    # Switch dropout and batch-norm layers (if any) to inference behaviour.
    model.eval()
    print("model loaded")

    # img.save("output/image.png")
    inference([image], model)
    inference([image, image], model)


if __name__ == "__main__":
    main()