Spaces: Sleeping
| from datasets import load_dataset, get_dataset_config_names | |
| from functools import partial | |
| from pandas import DataFrame | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| import tqdm | |
| import json | |
| import os | |
# Hugging Face dataset repository browsed by this viewer.
DATASET = "satellogic/EarthView"
# When True, load_dataset is bypassed and placeholder data is used instead.
DEBUG = False

# Per-subset loading parameters, keyed by the name offered in the "Config"
# dropdown:
#   shards -- number of parquet shards in the subset
#   config -- config name passed to load_dataset (defaults to the key)
#   path   -- directory holding the parquet files (defaults to the key)
sets = dict(
    satellogic=dict(shards=3676),
    sentinel_1=dict(shards=1763),
    neon=dict(config="default", shards=607, path="data"),
)
def open_dataset(dataset, set_name, split, batch_size, state, shard=-1):
    """Open a streaming iterator over one shard (or all) of a dataset subset.

    Args:
        dataset: Hugging Face dataset repository id (e.g. ``DATASET``).
        set_name: Key into ``sets`` selecting the subset to browse.
        split: Split to load; also used as the parquet file-name prefix.
        batch_size: Number of dataset items fetched for the first batch.
        state: Per-session dict; ``"config"`` and ``"dsi"`` (the item
            iterator) are stored in it for later ``get_images`` calls.
        shard: Index of the parquet shard to open, or -1 to stream the
            whole subset.

    Returns:
        ``(slider_update, images, metadata_frame, state)``, matching the
        ``[shard, gallery, table, state]`` outputs wired up in the UI.
    """
    # UI number/slider widgets may deliver floats; ":05d" below needs an int.
    shard = int(shard)
    # Needed by load_dataset and by get_images whichever branch is taken
    # (previously it was only set for single-shard loads -> NameError).
    config = sets[set_name].get("config", set_name)
    shards = sets[set_name]["shards"]
    if shard == -1:
        # Trick to open the whole dataset: let load_dataset resolve all files.
        data_files = None
    else:
        path = sets[set_name].get("path", set_name)
        # The key must be the requested split, otherwise load_dataset
        # cannot find `split` among the splits defined by data_files.
        data_files = {
            split: [f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]
        }
    if DEBUG:
        # Placeholder iterator; get_images does not read from it in DEBUG.
        dsi = iter(range(100))
    else:
        ds = load_dataset(
            dataset,
            config,
            split=split,
            cache_dir="dataset",
            data_files=data_files,
            streaming=True,
            token=os.environ.get("HF_TOKEN", None),
        )
        dsi = iter(ds)
    state["config"] = config
    state["dsi"] = dsi
    # Shard files are numbered 0 .. shards-1, so that is the last valid index.
    last_shard = shards - 1
    return (
        gr.update(
            label=f"Shards (max {last_shard})", value=shard, maximum=last_shard
        ),
        *get_images(batch_size, state),
        state,
    )
def item_to_images(config, item):
    """Convert one raw dataset item into images for display.

    Every field except ``"metadata"`` is converted to a ``uint8`` numpy
    array, and ``"metadata"`` is decoded from JSON when it arrives as a
    string. Depending on ``config``, image fields are then replaced by lists
    of ``PIL.Image`` objects, one per entry along the field's first axis:

    * ``"satellogic"``: ``"rgb"`` (channel-first, transposed to HWC) and
      ``"1m"`` (first channel only).
    * ``"sentinel_1"``: ``"10m"`` gains a synthetic third channel (channel 0
      divided by channel 1) so it can be shown as an RGB composite.
    * ``"default"`` (the NEON subset): ``"rgb"``, ``"chm"`` (first channel)
      and ``"1m"`` (bands averaged into three groups).

    Any other ``config`` only gets the uint8/metadata conversion.

    Args:
        config: Dataset config name selecting the conversion.
        item: Mapping of field name to array-like values plus ``"metadata"``.

    Returns:
        A new dict with the converted fields; ``item`` itself is not modified.
    """
    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    item = {
        k: np.asarray(v).astype("uint8")
        for k, v in item.items()
        if k != "metadata"
    }
    item["metadata"] = metadata

    if config == "satellogic":
        item["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]
        ]
        item["1m"] = [Image.fromarray(image[0, :, :]) for image in item["1m"]]
    elif config == "sentinel_1":
        # Mapping of V and H to RGB. May not be correct
        # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
        # NOTE(review): assumes "10m" is (N, 2, H, W) so the result has
        # exactly 3 channels -- confirm against the dataset card.
        i10m = item["10m"]
        # The scaled ratio is unbounded (up to ~255 * 256 / 0.01), and casting
        # out-of-range floats to uint8 is undefined in NumPy, so clip first.
        ratio = np.clip(i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256, 0, 255)
        i10m = np.concatenate((i10m, ratio[:, np.newaxis].astype("uint8")), axis=1)
        item["10m"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in i10m]
    elif config == "default":
        item["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]
        ]
        item["chm"] = [Image.fromarray(image[0]) for image in item["chm"]]
        # A very arbitrary conversion from the (reportedly 369-band)
        # hyperspectral data to RGB: it averages roughly each third of the
        # bands into one channel.
        item["1m"] = [
            Image.fromarray(
                np.stack(
                    (
                        np.average(image[:124], 0),
                        np.average(image[124:247], 0),
                        np.average(image[247:], 0),
                    ),
                    axis=2,
                ).astype("uint8")
            )
            for image in item["1m"]
        ]
    return item
def get_images(batch_size, state):
    """Fetch the next ``batch_size`` items from the open dataset iterator.

    Args:
        batch_size: Number of dataset items to pull from ``state["dsi"]``.
        state: Session dict filled in by ``open_dataset`` (``"config"`` and
            ``"dsi"``). If no dataset has been opened yet, an empty batch
            is returned.

    Returns:
        ``(images, metadata_frame)``: a flat list of images for the gallery
        and a ``DataFrame`` with one metadata row per item. Fewer items are
        returned once the iterator is exhausted.
    """
    if "dsi" not in state:
        # "Next Batch" pressed before "Load": there is nothing to iterate.
        return [], DataFrame()
    config = state["config"]
    # Converted fields shown in the gallery for each config, in display order.
    fields = {
        "satellogic": ("rgb", "1m"),
        "sentinel_1": ("10m",),
        "default": ("rgb", "chm", "1m"),
    }.get(config, ())
    images = []
    metadatas = []
    # The UI number widget may deliver a float, which range()/trange reject.
    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG:
            # Random placeholder instead of a dataset item.
            images.append(np.random.randint(0, 255, (384, 384, 3)).astype("uint8"))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue
        try:
            item = next(state["dsi"])
        except StopIteration:
            break
        item = item_to_images(config, item)
        for field in fields:
            images.extend(item[field])
        metadatas.append(item["metadata"])
    return images, DataFrame(metadatas)
def update_shape(rows, columns):
    """Build a Gradio update that resizes the gallery grid.

    Args:
        rows: Number of gallery rows to display.
        columns: Number of gallery columns to display.
    """
    grid = {"rows": rows, "columns": columns}
    return gr.update(**grid)
def new_state():
    """Create a Gradio state component whose initial value is an empty dict."""
    initial_value = {}
    return gr.State(initial_value)
if __name__ == "__main__":
    with gr.Blocks(title="Dataset Explorer", fill_height=True) as demo:
        # Per-session dict; open_dataset stores the config and iterator in it.
        state = new_state()
        gr.Markdown(
            f"# Viewer for [{DATASET}](https://huggingface.co/datasets/{DATASET}) Dataset"
        )
        # precision=0 makes Gradio deliver ints: these values are used as a
        # range() bound, a "{shard:05d}" file-name field and grid sizes.
        batch_size = gr.Number(10, label="Batch Size", precision=0, render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render=False)
        gallery = gr.Gallery(
            label=DATASET,
            interactive=False,
            columns=5,
            rows=2,
            render=False,
        )

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=DATASET, interactive=False)
            # Gradio expects a list of choices, not a dict_keys view.
            config = gr.Dropdown(choices=list(sets), label="Config", value="satellogic")
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(
                label="Initial shard",
                value=0,
                precision=0,
                info="-1 for whole dataset",
            )
            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, initial_shard],
                outputs=[shard, gallery, table, state],
            )

        gallery.render()

        with gr.Row():
            batch_size.render()
            rows = gr.Number(2, label="Rows", precision=0)
            columns = gr.Number(5, label="Columns", precision=0)
            rows.change(update_shape, [rows, columns], [gallery])
            columns.change(update_shape, [rows, columns], [gallery])

        with gr.Row():
            shard.render()
            # Jump to the selected shard once the slider is released.
            shard.release(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, shard],
                outputs=[shard, gallery, table, state],
            )
            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, state], [gallery, table])

        table.render()

    demo.launch(show_api=False)