Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,92 +1,43 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
from typing import List
|
| 4 |
-
from io import BytesIO
|
| 5 |
-
import numpy as np
|
| 6 |
-
import rasterio
|
| 7 |
-
from pydantic import BaseModel
|
| 8 |
-
import torch
|
| 9 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from mmcv import Config
|
| 11 |
-
from mmseg.apis import init_segmentor
|
| 12 |
-
import gradio as gr
|
| 13 |
-
from functools import partial
|
| 14 |
-
import time
|
| 15 |
-
import os
|
| 16 |
|
| 17 |
-
|
| 18 |
-
app = FastAPI()
|
| 19 |
|
| 20 |
-
|
| 21 |
-
config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
|
| 22 |
-
filename="multi_temporal_crop_classification_Prithvi_100M.py",
|
| 23 |
-
token=os.environ.get("token"))
|
| 24 |
-
ckpt = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
|
| 25 |
-
filename='multi_temporal_crop_classification_Prithvi_100M.pth',
|
| 26 |
-
token=os.environ.get("token"))
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
model = init_segmentor(config, ckpt, device='cpu')
|
| 31 |
|
| 32 |
-
|
| 33 |
-
custom_test_pipeline = model.cfg.data.test.pipeline
|
| 34 |
|
| 35 |
-
|
| 36 |
-
class PredictionOutput(BaseModel):
|
| 37 |
-
t1: List[float]
|
| 38 |
-
t2: List[float]
|
| 39 |
-
t3: List[float]
|
| 40 |
-
prediction: List[float]
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
img = src.read()
|
| 46 |
|
| 47 |
-
|
| 48 |
-
processed_img = apply_pipeline(custom_test_pipeline, img)
|
| 49 |
-
|
| 50 |
-
# Run inference
|
| 51 |
-
output = model.inference(processed_img)
|
| 52 |
-
|
| 53 |
-
# Post-process the output to get the RGB and prediction images
|
| 54 |
-
rgb1 = postprocess_output(output[0])
|
| 55 |
-
rgb2 = postprocess_output(output[1])
|
| 56 |
-
rgb3 = postprocess_output(output[2])
|
| 57 |
-
|
| 58 |
-
return rgb1, rgb2, rgb3, output
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
return img
|
| 64 |
|
| 65 |
-
|
| 66 |
-
# Convert the model's output into an RGB image or other formats as needed
|
| 67 |
-
return output
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
target_image = BytesIO(await file.read())
|
| 73 |
-
|
| 74 |
-
# Save the file temporarily if needed
|
| 75 |
-
with open("temp_image.tif", "wb") as f:
|
| 76 |
-
f.write(target_image.getvalue())
|
| 77 |
|
| 78 |
-
# Run the prediction
|
| 79 |
-
rgb1, rgb2, rgb3, output = inference_on_file("temp_image.tif", model, custom_test_pipeline)
|
| 80 |
-
|
| 81 |
-
# Return the results
|
| 82 |
-
return {
|
| 83 |
-
"t1": rgb1.tolist(),
|
| 84 |
-
"t2": rgb2.tolist(),
|
| 85 |
-
"t3": rgb3.tolist(),
|
| 86 |
-
"prediction": output.tolist()
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
# Optional: Serve the Gradio interface (if you still want to use it with FastAPI)
|
| 90 |
cdl_color_map = [{'value': 1, 'label': 'Natural vegetation', 'rgb': (233,255,190)},
|
| 91 |
{'value': 2, 'label': 'Forest', 'rgb': (149,206,147)},
|
| 92 |
{'value': 3, 'label': 'Corn', 'rgb': (255,212,0)},
|
|
@@ -316,8 +267,4 @@ with gr.Blocks() as demo:
|
|
| 316 |
gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
|
| 317 |
|
| 318 |
|
| 319 |
-
demo.launch()
|
| 320 |
-
|
| 321 |
-
if __name__ == "__main__":
|
| 322 |
-
run_gradio_interface()
|
| 323 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
| 1 |
+
######### pull files
|
| 2 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from huggingface_hub import hf_hub_download
|
| 4 |
+
config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
|
| 5 |
+
filename="multi_temporal_crop_classification_Prithvi_100M.py",
|
| 6 |
+
token=os.environ.get("token"))
|
| 7 |
+
ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
|
| 8 |
+
filename='multi_temporal_crop_classification_Prithvi_100M.pth',
|
| 9 |
+
token=os.environ.get("token"))
|
| 10 |
+
##########
|
| 11 |
+
import argparse
|
| 12 |
from mmcv import Config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
from mmseg.models import build_segmentor
|
|
|
|
| 15 |
|
| 16 |
+
from mmseg.datasets.pipelines import Compose, LoadImageFromFile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
import rasterio
|
| 19 |
+
import torch
|
|
|
|
| 20 |
|
| 21 |
+
from mmseg.apis import init_segmentor
|
|
|
|
| 22 |
|
| 23 |
+
from mmcv.parallel import collate, scatter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
import numpy as np
|
| 26 |
+
import glob
|
| 27 |
+
import os
|
|
|
|
| 28 |
|
| 29 |
+
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
import numpy as np
|
| 32 |
+
import gradio as gr
|
| 33 |
+
from functools import partial
|
|
|
|
| 34 |
|
| 35 |
+
import pdb
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
import matplotlib.pyplot as plt
|
| 38 |
+
|
| 39 |
+
from skimage import exposure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
cdl_color_map = [{'value': 1, 'label': 'Natural vegetation', 'rgb': (233,255,190)},
|
| 42 |
{'value': 2, 'label': 'Forest', 'rgb': (149,206,147)},
|
| 43 |
{'value': 3, 'label': 'Corn', 'rgb': (255,212,0)},
|
|
|
|
| 267 |
gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
|
| 268 |
|
| 269 |
|
| 270 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|