from datasets import load_dataset as _load_dataset
from os import environ
from PIL import Image
import numpy as np
import json
from pyarrow.parquet import ParquetFile
from pyarrow import Table as pa_Table
from datasets import Dataset

DATASET = "satellogic/EarthView"

# Number of parquet shards per subset, plus optional overrides for the
# dataset config name and the data path on the Hub.
sets = {
    "satellogic": {
        "shards": 7863,
    },
    "sentinel_1": {
        "shards": 1763,
    },
    "neon": {
        "config": "default",
        "shards": 607,
        "path": "data",
    },
    "sentinel_2": {
        "shards": 19997,
    },
}

def get_subsets():
    return sets.keys()

def get_nshards(subset):
    return sets[subset]["shards"]

def get_path(subset):
    return sets[subset].get("path", subset)

def get_config(subset):
    return sets[subset].get("config", subset)

def load_dataset(subset, dataset=DATASET, split="train", shards=None, streaming=True, **kwargs):
    config = get_config(subset)
    nshards = get_nshards(subset)
    path = get_path(subset)

    if shards is None:
        data_files = None
    else:
        if subset == "sentinel_2":
            # sentinel_2 shards are grouped in directories of 10 parquet files each
            data_files = [
                f"{path}/sentinel_2-{shard // 10}/{split}-{shard % 10:05d}-of-00010.parquet"
                for shard in shards
            ]
        else:
            data_files = [
                f"{path}/{split}-{shard:05d}-of-{nshards:05d}.parquet"
                for shard in shards
            ]
        data_files = {split: data_files}

    ds = _load_dataset(
        path=dataset,
        name=config,
        save_infos=True,
        split=split,
        data_files=data_files,
        streaming=streaming,
        token=environ.get("HF_TOKEN", None),
        **kwargs)

    return ds

def load_parquet(subset_or_filename, batch_size=100):
    # Accepts either a subset name (loads its local sample parquet)
    # or a path to an arbitrary parquet file.
    if subset_or_filename in get_subsets():
        filename = f"dataset/{subset_or_filename}/sample.parquet"
    else:
        filename = subset_or_filename
    pqfile = ParquetFile(filename)
    batch = pqfile.iter_batches(batch_size=batch_size)
    return Dataset(pa_Table.from_batches(batch))

def item_to_images(subset, item):
    """
    Converts the images within an item (arrays), as retrieved from the dataset,
    to proper PIL.Image objects.

    subset: The name of the subset, one of "satellogic", "sentinel_1",
            "sentinel_2" or "neon"
    item:   The item as retrieved from the subset

    Returns the item, with arrays converted to PIL.Image
    """
    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    item = {
        k: np.asarray(v).astype("uint8")
        for k, v in item.items() if k != "metadata"
    }
    item["metadata"] = metadata

    if subset == "satellogic":
        # item["rgb"] = [
        #     Image.fromarray(np.average(image.transpose(1,2,0), 2).astype("uint8"))
        #     for image in item["rgb"]
        # ]
        rgbs = []
        for rgb in item["rgb"]:
            rgbs.append(Image.fromarray(rgb.transpose(1, 2, 0)))
            # rgbs.append(Image.fromarray(rgb[0,:,:]))  # Red
            # rgbs.append(Image.fromarray(rgb[1,:,:]))  # Green
            # rgbs.append(Image.fromarray(rgb[2,:,:]))  # Blue
        item["rgb"] = rgbs
        item["1m"] = [
            Image.fromarray(image[0, :, :])
            for image in item["1m"]
        ]
        count = len(item["1m"])
    elif subset == "sentinel_1":
        # Mapping of the V and H polarizations to RGB; may not be correct.
        # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
        # The third channel is synthesized as the ratio of the first band to the second.
        i10m = item["10m"]
        i10m = np.concatenate(
            (
                i10m,
                np.expand_dims(
                    i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256,
                    1
                ).astype("uint8")
            ),
            1
        )
        item["10m"] = [
            Image.fromarray(image.transpose(1, 2, 0))
            for image in i10m
        ]
        count = len(item["10m"])
    elif subset == "sentinel_2":
        for channel in ['10m', '20m', 'rgb', 'scl']:  # , '40m']:
            data = item[channel]
            count = len(data)
            data = np.asarray(data).astype("uint8").transpose(0, 2, 3, 1)
            if channel == "20m":
                # keep three of the 20m bands so the result can be rendered as RGB
                data = data[:, :, :, [0, 2, 4]]
            mode = "L" if channel in ["10m", "scl"] else "RGB"
            images = [Image.fromarray(data[i].squeeze(), mode=mode) for i in range(count)]
            item[channel] = images
        for field in ["solarAngles", "tileGeometry", "viewIncidenceAngles"]:
            item["metadata"][field] = [json.loads(s) for s in item["metadata"][field]]
    elif subset == "neon":
        item["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0))
            for image in item["rgb"]
        ]
        item["chm"] = [
            Image.fromarray(image[0])
            for image in item["chm"]
        ]
        # The next is a very arbitrary conversion of the 369 hyperspectral bands to RGB:
        # it averages each third of the bands and assigns the result to one channel.
        item["1m"] = [
            Image.fromarray(
                np.concatenate((
                    np.expand_dims(np.average(image[:124], 0), 2),
                    np.expand_dims(np.average(image[124:247], 0), 2),
                    np.expand_dims(np.average(image[247:], 0), 2)),
                    2).astype("uint8"))
            for image in item["1m"]
        ]
        count = len(item["rgb"])
        bounds = item["metadata"]["bounds"]
        # swap pairs
        item["metadata"]["bounds"] = [bounds[i + 1 - l] for i in range(0, len(bounds), 2) for l in range(2)]
        # fix CRS
        item["metadata"]["epsg"] = "EPSG:4326"

    item["metadata"]["count"] = count

    return item
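
# Minimal usage sketch, not a definitive recipe: it assumes network access to
# the Hugging Face Hub and, for gated access, an HF_TOKEN environment variable.
# It streams a single shard of the "satellogic" subset and converts the first
# item's arrays to PIL images with item_to_images().
if __name__ == "__main__":
    ds = load_dataset("satellogic", shards=[0], streaming=True)
    item = next(iter(ds))
    item = item_to_images("satellogic", item)
    print(item["metadata"]["count"], "images in the first item")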