from fastapi import FastAPI, File, UploadFile
import uvicorn
from typing import List
from io import BytesIO
import numpy as np
import rasterio
from pydantic import BaseModel
import torch
from huggingface_hub import hf_hub_download
from mmcv import Config
from mmseg.apis import init_segmentor
import gradio as gr
from functools import partial
import time
import os

# Initialize the FastAPI app
app = FastAPI()

# Load the model config and checkpoint from the Hugging Face Hub
config_path = hf_hub_download(
    repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
    filename="multi_temporal_crop_classification_Prithvi_100M.py",
    token=os.environ.get("token"))
ckpt = hf_hub_download(
    repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
    filename="multi_temporal_crop_classification_Prithvi_100M.pth",
    token=os.environ.get("token"))

config = Config.fromfile(config_path)
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device='cpu')

# Use the test pipeline defined in the model config directly
custom_test_pipeline = model.cfg.data.test.pipeline

# Response schema for the /predict endpoint
class PredictionOutput(BaseModel):
    t1: List[float]
    t2: List[float]
    t3: List[float]
    prediction: List[float]

# Run inference on a GeoTIFF and return an RGB preview for each of the
# three time steps plus the raw model output
def inference_on_file(file_path, model, custom_test_pipeline):
    with rasterio.open(file_path) as src:
        img = src.read()  # (18, H, W): 3 time steps x 6 bands each
    # Apply preprocessing using the custom pipeline
    processed_img = apply_pipeline(custom_test_pipeline, img)
    # Run inference (assumes apply_pipeline returns data in the form the
    # model expects)
    output = model.inference(processed_img)
    # Build an RGB preview from each 6-band time step of the input
    rgb1 = postprocess_output(img[0:6])
    rgb2 = postprocess_output(img[6:12])
    rgb3 = postprocess_output(img[12:18])
    return rgb1, rgb2, rgb3, output

def apply_pipeline(pipeline, img):
    # Placeholder: run the mmseg test pipeline (normalization, resizing,
    # etc.) over the raw bands; see the sketch below for one way to wire
    # this up
    return img
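
# Illustrative sketch (an assumption, not the original author's code):
# one way to run an mmseg-style test pipeline end to end, shown under a
# separate name so the stub above stays intact. The entry point and dict
# keys follow the standard mmseg 0.x inference flow; the Prithvi config
# uses custom geospatial transforms, so the exact keys may need adapting.
def apply_pipeline_from_path(pipeline, file_path):
    from mmseg.datasets.pipelines import Compose
    from mmcv.parallel import collate
    test_pipeline = Compose(pipeline)
    # mmseg's LoadImageFromFile-style transforms read the path from
    # results['img_info']['filename']
    data = test_pipeline(dict(img_info=dict(filename=file_path),
                              img_prefix=None, seg_fields=[]))
    # Batch the single sample the way mmseg dataloaders do
    return collate([data], samples_per_gpu=1)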

def postprocess_output(output):
    # Placeholder: convert a 6-band time step (or the prediction map)
    # into an RGB image; see the sketch below
    return output
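
# Illustrative sketch (an assumption): turn one 6-band (bands, H, W)
# time step into a display-ready RGB uint8 image with a percentile
# stretch. The Blue, Green, Red band order follows the channel
# description in the Gradio markdown below; the 2%/98% clip is a common
# visualization default, not taken from the original app.
def bands_to_rgb(timestep_bands):
    # Bands arrive as Blue, Green, Red, ... so index [2, 1, 0] for RGB
    rgb = timestep_bands[[2, 1, 0]].astype(np.float32)
    lo, hi = np.percentile(rgb, (2, 98))
    rgb = np.clip((rgb - lo) / max(hi - lo, 1e-6), 0.0, 1.0)
    # (3, H, W) -> (H, W, 3) uint8 for display
    return (np.transpose(rgb, (1, 2, 0)) * 255).astype(np.uint8)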

# Prediction endpoint: accepts an HLS GeoTIFF upload and returns the
# three RGB time-step previews plus the predicted class map
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Read the uploaded file
    target_image = BytesIO(await file.read())
    # Persist the upload so rasterio can open it by path
    with open("temp_image.tif", "wb") as f:
        f.write(target_image.getvalue())
    # Run the prediction
    rgb1, rgb2, rgb3, output = inference_on_file("temp_image.tif", model, custom_test_pipeline)
    # Return the results
    return {
        "t1": rgb1.tolist(),
        "t2": rgb2.tolist(),
        "t3": rgb3.tolist(),
        "prediction": output.tolist(),
    }
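
# Example client call (illustrative; assumes the server is running on
# localhost:8000 and that a local GeoTIFF such as
# "chip_102_345_merged.tif" is available):
#
#   import requests
#   with open("chip_102_345_merged.tif", "rb") as f:
#       resp = requests.post("http://localhost:8000/predict",
#                            files={"file": f})
#   resp.raise_for_status()
#   result = resp.json()  # keys: "t1", "t2", "t3", "prediction"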

# Optional: serve the Gradio interface (if you still want to use it
# alongside FastAPI)
def run_gradio_interface():
    func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
    with gr.Blocks() as demo:
        gr.Markdown(value='# Prithvi multi-temporal crop classification')
        gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision Transformer pre-trained by the IBM and NASA team on continental US Harmonized Landsat Sentinel-2 (HLS) data. This demo showcases how the model was fine-tuned to classify crops and other land use categories using multi-temporal data. More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
The user needs to provide an HLS GeoTIFF image with 18 bands covering 3 time-steps; each time-step contains the following channels in order: Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2.''')
        with gr.Row():
            with gr.Column():
                inp = gr.File()
                btn = gr.Button("Submit")
        with gr.Row():
            inp1 = gr.Image(image_mode='RGB', scale=10, label='T1')
            inp2 = gr.Image(image_mode='RGB', scale=10, label='T2')
            inp3 = gr.Image(image_mode='RGB', scale=10, label='T3')
            out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')
        btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
        with gr.Row():
            with gr.Column():
                gr.Examples(examples=["chip_102_345_merged.tif",
                                      "chip_104_104_merged.tif",
                                      "chip_109_421_merged.tif"],
                            inputs=inp,
                            outputs=[inp1, inp2, inp3, out],
                            fn=func,
                            cache_examples=True)
            with gr.Column():
                gr.Markdown(value='### Model prediction legend')
                gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
    demo.launch()

if __name__ == "__main__":
    # Note: demo.launch() blocks, so the uvicorn call below only runs
    # after the Gradio server shuts down; run one or the other here, or
    # mount the Blocks app onto FastAPI (see the sketch below)
    run_gradio_interface()
    uvicorn.run(app, host="0.0.0.0", port=8000)
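
# Illustrative alternative (an assumption, not part of the original
# app): serve the Gradio UI and the FastAPI routes from a single
# process by mounting the Blocks app onto FastAPI. This requires
# run_gradio_interface() to return `demo` instead of calling
# demo.launch():
#
#   demo = run_gradio_interface()
#   app = gr.mount_gradio_app(app, demo, path="/ui")
#   uvicorn.run(app, host="0.0.0.0", port=8000)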