Spaces:
Sleeping
Sleeping
| from datasets import load_dataset, IterableDataset | |
| from functools import partial | |
| from pandas import DataFrame | |
| import gradio as gr | |
| import numpy as np | |
| import tqdm | |
| import json | |
| import os | |
# When True, the handlers below skip the real dataset and use placeholder data.
DEBUG = False

# Per-subset settings, keyed by the name offered in the "Subset" dropdown:
#   "shards" - total shard count, used in the parquet file name and slider maximum
#   "config" - config name passed to load_dataset (falls back to the subset name)
#   "path"   - directory holding the parquet files (falls back to the subset name)
sets = {
    "satellogic": {"shards": 3676},
    "sentinel_1": {"shards": 1763},
    "neon": {"config": "default", "shards": 607, "path": "data"},
}
def open_dataset(dataset, set_name, split, batch_size, shard = -1):
    """Open a streaming dataset (optionally a single shard) and fetch the first batch.

    Rebinds the module-level ``ds`` (the dataset) and ``dsi`` (an iterator
    over it), which ``get_images`` and ``skip`` read afterwards.

    Args:
        dataset: Dataset repository id passed to ``load_dataset``.
        set_name: Key into ``sets`` selecting the subset.
        split: Split name; also used as the parquet file-name prefix.
        batch_size: Number of records to fetch for the initial gallery.
        shard: Shard index to open, or -1 to let ``load_dataset`` resolve
            the data files itself.

    Returns:
        A 3-tuple: an update for the shard slider (label, value, maximum)
        followed by the ``(items, metadata_table)`` pair from ``get_images``.
    """
    # I should really move ds.config_name and dsi to a gr.State()
    global dsi, ds

    # Numeric UI components can deliver floats; the ``:05d`` format below
    # only accepts integers.
    shard = int(shard)

    # Resolve the config for every call. It used to be bound only in the
    # per-shard branch, so shard == -1 failed with UnboundLocalError.
    config = sets[set_name].get("config", set_name)

    if shard == -1:
        data_files = None
        shards = 100
    else:
        shards = sets[set_name]["shards"]
        path = sets[set_name].get("path", set_name)
        data_files = {"train":[f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]}

    if DEBUG:
        # Stand-in objects so the UI can be exercised without any download.
        ds = lambda:None
        ds.n_shards = 1234
        dsi = range(100)
    else:
        ds = load_dataset(
            dataset,
            config,
            split=split,
            cache_dir="dataset",
            data_files=data_files,
            streaming=True,
            token=os.environ.get("HF_TOKEN", None))
        dsi = iter(ds)

    # NOTE(review): shard file names look 0-indexed, so ``maximum=shards``
    # may allow one index past the last file - confirm against the repo.
    return (
        gr.update(label=f"Shards (max {shards})", value=shard, maximum=shards),
        *get_images(batch_size)
    )
def get_images(batch_size):
    """Read up to ``batch_size`` records from the open stream and convert them for display.

    Uses the module-level ``dsi`` iterator and ``ds`` dataset set up by
    ``open_dataset``. Stops early when the iterator is exhausted.

    Args:
        batch_size: Maximum number of records to consume.

    Returns:
        ``(items, table)``: ``items`` is a list of image arrays (one record
        may contribute several), ``table`` is a ``DataFrame`` with one row
        of metadata per consumed record.
    """
    global dsi, ds

    items = []
    metadatas = []
    # The UI may hand over a float; trange requires an int.
    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG:
            # Random placeholder image; it must be appended or the gallery
            # stays empty in DEBUG mode.
            image = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
            items.append(image)
            metadata = {"bounds":[[1,1,4,4]], }
        else:
            try:
                item = next(dsi)
            except StopIteration:
                break
            metadata = item["metadata"]
            config_name = ds.config_name

            if config_name == "satellogic":
                # image = (np.asarray(item["1m"])).astype("uint8")
                # items.append(image[0,0,:,:])
                image = np.asarray(item["rgb"][0]).astype(np.uint8)
                # Move the leading axis last (presumably channels-first
                # to channels-last - TODO confirm against the dataset card).
                items.append(image.transpose(1,2,0))
            elif config_name == "sentinel_1":
                metadata = json.loads(metadata)
                data = np.asarray(item["10m"])
                for sample in data:
                    # Mapping of V and H to RGB. May not be correct
                    # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
                    image = np.zeros((3,384,384), "uint8")
                    image[0] = sample[0]
                    image[1] = sample[1]
                    # The scaled ratio can reach 256 and beyond, which is
                    # out of range for uint8; clamp it before storing.
                    ratio = image[0] / (image[1] + 0.1)
                    image[2] = np.clip(ratio * 256, 0, 255)
                    items.append(image.transpose(1,2,0))
            elif config_name == "default":
                data_rgb = np.asarray(item["rgb"]).astype("uint8")
                data_chm = np.asarray(item["chm"]).astype("uint8")
                data_1m = np.asarray(item["1m"]).astype("uint8")
                # Each entry along the first axis yields three gallery
                # images: the RGB image, then band 0 of "chm" and of "1m".
                for idx in range(data_rgb.shape[0]):
                    items.append(data_rgb[idx,:,:,:].transpose(1,2,0))
                    items.append(data_chm[idx,0,:,:])
                    items.append(data_1m[idx,0,:,:])

        metadatas.append(metadata)

    return items, DataFrame(metadatas)
def skip(count, batch_size):
    """Discard ``count`` batches from the open stream, then return the next batch.

    Args:
        count: Number of batches to skip.
        batch_size: Records per batch; also the size of the batch returned.

    Returns:
        The ``(items, table)`` pair produced by ``get_images(batch_size)``.
    """
    global dsi

    # Distinct local name (the function itself is called ``skip``), and an
    # int because trange rejects floats coming from numeric UI inputs.
    to_skip = int(count * batch_size)
    gr.Info(f"Skipping {to_skip} images (it's slow)")

    for _ in tqdm.trange(to_skip, desc=f"Skipping {to_skip} images"):
        if DEBUG:
            continue
        try:
            next(dsi)
        except StopIteration:
            # Stream exhausted: stop quietly, as get_images does.
            break

    return get_images(batch_size)
def update_shape(rows, columns):
    """Return a component update carrying the new ``rows`` and ``columns`` values."""
    grid = {"rows": rows, "columns": columns}
    return gr.update(**grid)
# UI definition: a loader row, the image gallery, gallery/batch controls,
# shard navigation, and the metadata table.
with gr.Blocks(title="Dataset Explorer", fill_height = True) as demo:
    gr.Markdown("# [satellogic/EarthView](https://huggingface.co/datasets/satellogic/EarthView) Dataset Viewer")

    # Created with render=False so they can be referenced by event handlers
    # before being placed in the layout with .render() further down.
    # precision=0 makes the numeric inputs deliver ints to the handlers.
    batch_size = gr.Number(10, label = "Batch Size", precision=0, render=False)
    shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
    table = gr.DataFrame(render = False)
    # headers=["Index","TimeStamp","Bounds","CRS"],

    gallery = gr.Gallery(
        label="satellogic/EarthView",
        interactive=False,
        columns=5, rows=2, render=False)

    with gr.Row():
        dataset = gr.Textbox(label="Dataset", value="satellogic/EarthView")
        config = gr.Dropdown(choices=["satellogic", "sentinel_1", "neon"], label="Subset", value="satellogic", )
        split = gr.Textbox(label="Split", value="train")
        initial_shard = gr.Number(label = "Initial shard", value=0, precision=0)

        gr.Button("Load (minutes)").click(
            open_dataset,
            inputs=[dataset, config, split, batch_size, initial_shard],
            outputs=[shard, gallery, table])

    gallery.render()

    with gr.Row():
        batch_size.render()

        rows = gr.Number(2, label="Rows", precision=0)
        columns = gr.Number(5, label="Columns", precision=0)

        # Either control changing re-applies both values to the gallery.
        rows.change(update_shape, [rows, columns], [gallery])
        columns.change(update_shape, [rows, columns], [gallery])

    with gr.Row():
        shard.render()
        # Releasing the slider reopens the dataset at the selected shard.
        shard.release(
            open_dataset,
            inputs=[dataset, config, split, batch_size, shard],
            outputs=[shard, gallery, table])

        btn = gr.Button("Get More Images", scale=0)
        btn.click(get_images, [batch_size], [gallery, table])

    # btn = gr.Button("Skip 10 Batches", scale=0)
    # btn.click(partial(skip, 10), [batch], gallery)
    # btn = gr.Button("Skip 25 Batches", scale=0)
    # btn.click(partial(skip, 25), [batch], gallery)

    table.render()

demo.launch(show_api=False)