import os

import torch
import torch.multiprocessing
import torchvision.transforms as T

from utils import transform_to_pil

# Preprocessing applied to every input before inference: convert to a PIL
# image, resize to 320x320, convert to a float tensor and normalise with the
# standard ImageNet mean/std. Built once at import time instead of per call.
_PREPROCESS = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((320, 320)),
        # T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def inference(image, model):
    """Segment a single image with ``model`` on the CPU.

    Args:
        image: Any input accepted by ``torchvision.transforms.ToPILImage``
            (a ``C x H x W`` tensor or an ``H x W x C`` numpy array). It is
            resized to 320x320 and normalised before being fed to the model.
        model: Segmenter exposing ``net``, ``linear_probe`` and
            ``cfg.n_images`` (e.g. ``LitUnsupervisedSegmenter``). It is moved
            to the CPU and switched to eval mode in place.

    Returns:
        tuple: ``(img, labeled_img, label)`` built by ``transform_to_pil``.
        Note that ``transform_to_pil`` yields them in ``img, label,
        labeled_img`` order; this function reorders them.
    """
    x = _PREPROCESS(image)

    # Add a batch dimension of size 1 and run everything on the CPU.
    x = torch.unsqueeze(x, dim=0).cpu()
    model = model.cpu()
    # no_grad() only disables autograd; eval() is still required so that
    # dropout / batch-norm layers (if the model has any) use inference
    # behaviour and the prediction is deterministic.
    model.eval()

    with torch.no_grad():
        _feats, code = model.net(x)
        # argmax over dim 1 — presumably the class channel of the probe
        # output; TODO confirm against linear_probe's output layout.
        linear_pred = model.linear_probe(x, code).argmax(1)
        output = {
            "img": x[: model.cfg.n_images].detach().cpu(),
            "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
        }

    img, label, labeled_img = transform_to_pil(output)
    return img, labeled_img, label


def main():
    """Fetch one Earth Engine image, segment it and save PNGs under ``output/``."""
    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img

    # NOTE(review): the names look swapped — 48.81 is a plausible latitude and
    # 2.98 a plausible longitude. `location` is therefore [2.98, 48.81]; confirm
    # which order extract_img expects before renaming or reordering anything.
    latitude = 2.98
    longitude = 48.81
    start_date = '2020-03-20'
    end_date = '2020-04-20'

    location = [float(latitude), float(longitude)]

    # Extract img numpy from earth engine and transform it for preprocessing.
    img = extract_img(location, start_date, end_date)
    image = transform_ee_img(
        img, max=0.3
    )  # max value is the value from numpy file that will be equal to 255
    print("image loaded")

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model. torch.load unpickles the file, so only point this at a
    # trusted checkpoint.
    model_path = "checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    nbclasses = cfg.dir_dataset_n_classes

    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    print("model loaded")
    # img.save("output/image.png")
    img, labeled_img, label = inference(image, model)

    # The save calls below fail if the directory is missing.
    os.makedirs("output", exist_ok=True)
    img.save("output/img.png")
    label.save("output/label.png")
    labeled_img.save("output/labeled_img.png")


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# Legacy helpers, kept for reference only (not executed).
# NOTE(review) before reviving any of them:
#   - they use names this module does not import or define: datetime, cv2, np,
#     compute_biodiv_score, extract_img, transform_ee_img, and a module-level
#     `model` used as a default argument;
#   - get_list_date appends end_date twice (the loop already appends it);
#   - get_img_array appends an undefined `img` instead of running inference;
#   - infer_unique_date duplicates the preprocessing/inference of `inference`;
#   - gif_maker writes np.array(image.getdata()), a flattened (H*W, C) array,
#     rather than an H x W x 3 frame.
# ---------------------------------------------------------------------------

# def get_list_date(start_date, end_date):
#     """Get all the date between the start date and the end date
#
#     Args:
#         start_date (str): start date at the format '%Y-%m-%d'
#         end_date (str): end date at the format '%Y-%m-%d'
#
#     Returns:
#         list[str]: all the date between the start date and the end date
#     """
#     start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
#     end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
#     list_date = [start_date]
#     date = start_date
#     while date < end_date:
#         date = date + datetime.timedelta(days=1)
#         list_date.append(date)
#     list_date.append(end_date)
#     list_date2 = [x.strftime("%Y-%m-%d") for x in list_date]
#     return list_date2


# def get_length_interval(start_date, end_date):
#     """Return how many days there is between the start date and the end date
#
#     Args:
#         start_date (str): start date at the format '%Y-%m-%d'
#         end_date (str): end date at the format '%Y-%m-%d'
#
#     Returns:
#         int : number of days between start date and the end date
#     """
#     try:
#         return len(get_list_date(start_date, end_date))
#     except ValueError:
#         return 0


# def infer_unique_date(latitude, longitude, date, model=model):
#     """Perform an inference on a latitude and a longitude at a specific date
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         date (str): date for the inference at the format '%Y-%m-%d'
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
#     """
#     start_date = date
#     end_date = date
#     location = [float(latitude), float(longitude)]
#     # Extract img numpy from earth engine and transform it to PIL img
#     img = extract_img(location, start_date, end_date)
#     img_test = transform_ee_img(
#         img, max=0.3
#     )  # max value is the value from numpy file that will be equal to 255
#
#     # tensorize & normalize img
#     preprocess = T.Compose(
#         [
#             T.ToPILImage(),
#             T.Resize((320, 320)),
#             # T.CenterCrop(224),
#             T.ToTensor(),
#             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#         ]
#     )
#
#     # Preprocess opened img
#     x = preprocess(img_test)
#
#     # launch inference on cpu
#     x = torch.unsqueeze(x, dim=0).cpu()
#     model = model.cpu()
#
#     with torch.no_grad():
#         feats, code = model.net(x)
#         linear_pred = model.linear_probe(x, code)
#         linear_pred = linear_pred.argmax(1)
#         output = {
#             "img": x[: model.cfg.n_images].detach().cpu(),
#             "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
#         }
#
#     img, label, labeled_img = transform_to_pil(output)
#     biodiv_score = compute_biodiv_score(labeled_img)
#     return img, labeled_img, biodiv_score


# def get_img_array(start_date, end_date, latitude, longitude, model=model):
#     list_date = get_list_date(start_date, end_date)
#     list_img = []
#     for date in list_date:
#         list_img.append(img)
#     return list_img


# def variable_outputs(start_date, end_date, latitude, longitude, day, model=model):
#     """Perform an inference on the day number day starting from the start at the latitude and longitude selected
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         start_date (str): the start date for our inference
#         end_date (str): the end date for our inference
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         img,labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
#     """
#     list_date = get_list_date(start_date, end_date)
#     k = int(day)
#     date = list_date[k]
#     img, labeled_img, biodiv_score = infer_unique_date(
#         latitude, longitude, date, model=model
#     )
#     return img, labeled_img, biodiv_score


# def variable_outputs2(
#     start_date, end_date, latitude, longitude, day_number, model=model
# ):
#     """Perform an inference on the day number day starting from the start at the latitude and longitude selected
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         start_date (str): the start date for our inference
#         end_date (str): the end date for our inference
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         list[img,labeled_img,biodiv_score]: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
#     """
#     list_date = get_list_date(start_date, end_date)
#     k = int(day_number)
#     date = list_date[k]
#     img, labeled_img, biodiv_score = infer_unique_date(
#         latitude, longitude, date, model=model
#     )
#     return [img, labeled_img, biodiv_score]


# def gif_maker(img_array):
#     output_file = "test2.mkv"
#     image_test = img_array[0]
#     size = (320, 320)
#     print(size)
#     out = cv2.VideoWriter(
#         output_file, cv2.VideoWriter_fourcc(*"avc1"), 15, frameSize=size
#     )
#     for i in range(len(img_array)):
#         image = img_array[i]
#         pix = np.array(image.getdata())
#         out.write(pix)
#     out.release()
#     return output_file


# def infer_multiple_date(start_date, end_date, latitude, longitude, model=model):
#     """Perform an inference on all the dates between the start date and the end date at the latitude and longitude
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         start_date (str): the start date for our inference
#         end_date (str): the end date for our inference
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         list_img,list_labeled_img,list_biodiv_score: list of the original landscape, the labeled landscape and the biodiversity score and the landscape
#     """
#     list_date = get_list_date(start_date, end_date)
#     list_img = []
#     list_labeled_img = []
#     list_biodiv_score = []
#     for date in list_date:
#         img, labeled_img, biodiv_score = infer_unique_date(
#             latitude, longitude, date, model=model
#         )
#         list_img.append(img)
#         list_labeled_img.append(labeled_img)
#         list_biodiv_score.append(biodiv_score)
#     return gif_maker(list_img), gif_maker(list_labeled_img), list_biodiv_score[0]