|
|
import base64
import io
import os
from typing import Optional

import cv2
import gdown
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.projects.point_rend import add_pointrend_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="Rooftop Segmentation API")

# Fully permissive CORS: any origin, any method, any header.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins/methods/headers with allow_credentials=True. No cookie or
# auth handling is visible in this file — confirm credentials are really
# required before keeping this combination.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preset polygon-simplification tolerances. Each value is later multiplied by
# a contour's perimeter in the post-processing helpers. Presumably these are
# the choices a client offers as /predict's ``epsilon`` — confirm with the
# frontend.
EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001]


@app.get("/epsilons")
def get_epsilons():
    """Return the preset epsilon values under the ``epsilons`` key."""
    return dict(epsilons=EPSILONS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local file the irregular-flat weights are downloaded to and loaded from.
MODEL_PATH_IRREGULAR = "/tmp/model_irregular_flat.pth"
# Google Drive file ID of the irregular-flat weights.
DRIVE_FILE_ID = "15vi4zPhCs3aBnGepVnXFOqQjxdK1jpnA"


def download_irregular_model(
    output_path: str = MODEL_PATH_IRREGULAR,
    file_id: str = DRIVE_FILE_ID,
) -> bool:
    """Download the irregular-flat Detectron2 weights unless already present.

    Args:
        output_path: Destination file for the weights.
        file_id: Google Drive file ID to fetch.

    Returns:
        True if ``output_path`` exists when the function returns. False if
        the download reported failure or produced no file. Exceptions raised
        by ``gdown.download`` propagate unchanged.
    """
    if os.path.exists(output_path):
        print("Irregular-flat model already exists, skipping download.")
        return True

    url = f"https://drive.google.com/uc?id={file_id}"

    tmp_dir = "/tmp/gdown"
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): meant to point gdown's cache at a writable /tmp dir.
    # Confirm the installed gdown version actually reads GDOWN_CACHE_DIR.
    # This mutates the environment for the whole process.
    os.environ["GDOWN_CACHE_DIR"] = tmp_dir

    print("Downloading irregular-flat Detectron2 model...")
    result = gdown.download(
        url,
        output_path,
        quiet=False,
        fuzzy=True,
        use_cookies=False,
    )

    # gdown may report failure by returning None rather than raising. The
    # previous code printed "Download complete." regardless, which hid the
    # real cause of the later model-loading error.
    if result is None or not os.path.exists(output_path):
        print(f"Download failed: no file written to {output_path}")
        return False

    print("Download complete.")
    return True
|
|
|
|
|
|
|
|
# Import-time side effect: fetch the irregular-flat weights (if missing)
# before the predictor that loads them is constructed further down.
download_irregular_model()

# Report the outcome. This only logs; a missing file is not treated as fatal
# here.
if not os.path.exists(MODEL_PATH_IRREGULAR):
    print("Irregular-flat model NOT found! Something went wrong!")
else:
    print("Irregular-flat model is ready at", MODEL_PATH_IRREGULAR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PointRend base config shared by both models. This is a relative path, so it
# is resolved against the process working directory.
# NOTE(review): the service must be started from the directory that contains
# the detectron2 checkout — confirm for the deployment image.
_POINTREND_CFG_PATH = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"


def _setup_pointrend_predictor(
    weights_path: str,
    num_classes: int,
    score_thresh: float = 0.5,
    device: str = "cpu",
):
    """Build a Detectron2 ``DefaultPredictor`` for a PointRend model.

    Args:
        weights_path: Path to the trained ``.pth`` weights.
        num_classes: Number of foreground classes the model was trained with.
        score_thresh: Minimum detection score kept at test time.
        device: Device string stored in ``cfg.MODEL.DEVICE``.
    """
    cfg = get_cfg()
    add_pointrend_config(cfg)
    cfg.merge_from_file(_POINTREND_CFG_PATH)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    # Point head class count is kept equal to the ROI heads' class count.
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = num_classes
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)


def setup_model_rect(weights_path: str, score_thresh: float = 0.5, device: str = "cpu"):
    """Build the predictor for the rectangular-rooftop model (2 classes)."""
    return _setup_pointrend_predictor(
        weights_path, num_classes=2, score_thresh=score_thresh, device=device
    )


def setup_model_irregular(weights_path: str, score_thresh: float = 0.5, device: str = "cpu"):
    """Build the predictor for the irregular-flat-rooftop model (1 class)."""
    return _setup_pointrend_predictor(
        weights_path, num_classes=1, score_thresh=score_thresh, device=device
    )
|
|
|
|
|
|
|
|
# Both predictors are created once at import time and reused by /predict.
predictor_rect = setup_model_rect(weights_path="/app/model_rect_final.pth")
predictor_irregular_flat = setup_model_irregular(weights_path=MODEL_PATH_IRREGULAR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess_rect(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Simplify the largest blob in ``mask`` and re-rasterise it.

    The outer contour with the greatest area is approximated by a polygon
    whose tolerance is ``epsilon`` times the contour's perimeter. That polygon
    is then filled into a fresh uint8 image (255 inside, 0 outside) with the
    same shape as the input.

    Args:
        mask: Binary mask (0/1 values expected by the ``* 255`` scaling).
        epsilon: Simplification tolerance as a fraction of the perimeter.

    Returns:
        The filled simplified mask, or None if ``mask`` has no contours.
    """
    binary = (mask * 255).astype(np.uint8)
    found, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(found) == 0:
        return None

    biggest = max(found, key=cv2.contourArea)
    tolerance = cv2.arcLength(biggest, True) * epsilon
    simplified = cv2.approxPolyDP(biggest, tolerance, True)

    canvas = np.zeros_like(binary)
    cv2.fillPoly(canvas, [simplified], 255)
    return canvas
|
|
|
|
|
def postprocess_irregular(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Return a simplified polygon for the largest blob in ``mask``.

    The outer contour with the greatest area is approximated by a polygon
    whose tolerance is ``epsilon`` times the contour's perimeter.

    Args:
        mask: Binary mask (0/1 values expected by the ``* 255`` scaling).
        epsilon: Simplification tolerance as a fraction of the perimeter.

    Returns:
        An ``(N, 2)`` array of OpenCV contour points, or None if ``mask``
        has no contours.
    """
    binary = (mask * 255).astype(np.uint8)
    found, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(found) == 0:
        return None

    biggest = max(found, key=cv2.contourArea)
    tolerance = cv2.arcLength(biggest, True) * epsilon
    simplified = cv2.approxPolyDP(biggest, tolerance, True)
    # approxPolyDP yields an (N, 1, 2) array; flatten to one point per row.
    return np.reshape(simplified, (-1, 2))
|
|
|
|
|
def mask_to_polygon(mask: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return the largest outer contour of ``mask`` as an ``(N, 2)`` array.

    Any non-zero pixel counts as foreground. ``findContours`` treats non-zero
    as foreground too, so uint8 input gives the same result as before. Bool,
    float and other integer masks are now accepted instead of being rejected
    by OpenCV.

    Args:
        mask: 2-D mask, or None (e.g. when ``postprocess_rect`` found no
            contour).

    Returns:
        The contour points of the largest-area outer contour. None if
        ``mask`` is None or contains no foreground.
    """
    if mask is None:
        return None

    binary = (np.asarray(mask) != 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    largest = max(contours, key=cv2.contourArea)
    return largest.reshape(-1, 2)
|
|
|
|
|
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode ``im`` as PNG and return the bytes as a base64 ASCII string.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed. Previously
            the success flag was ignored, so a failed encode could yield an
            empty or meaningless string.
    """
    ok, buffer = cv2.imencode(".png", im)
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as PNG")
    return base64.b64encode(buffer).decode("ascii")
|
|
|
|
|
def overlay_polygon(im: np.ndarray, polygon: Optional[np.ndarray]) -> np.ndarray:
    """Return a copy of ``im`` with ``polygon`` outlined.

    The outline is a closed 2-px polyline drawn with colour tuple
    ``(0, 0, 255)`` (red if ``im`` is BGR). The input image is never
    modified. When ``polygon`` is None the copy is returned untouched.
    """
    canvas = im.copy()
    if polygon is None:
        return canvas

    points = polygon.astype(np.int32)
    cv2.polylines(canvas, [points], True, (0, 0, 255), 2)
    return canvas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Return a static message confirming the API is up."""
    status_message = "Rooftop Segmentation API is running!"
    return {"message": status_message}
|
|
|
|
|
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    rooftop_type: str = Form(...),
    epsilon: float = Form(0.004)
):
    """Segment the highest-scoring rooftop in an uploaded image.

    Form fields:
        file: Image to segment. Anything PIL can open is converted to RGB.
        rooftop_type: ``"rectangular"`` or ``"irregular"``
            (case-insensitive; surrounding whitespace is ignored). Selects
            the model and the post-processing routine.
        epsilon: Polygon-simplification tolerance as a fraction of the
            detected contour's perimeter. Must be finite and non-negative.

    Returns:
        A dict with these keys:
        - ``polygon``: list of points, or None.
        - ``image``: base64 PNG of the input with the polygon drawn, or None
          when nothing was detected.
        - ``model_used``.
        - ``rooftop_type`` and ``epsilon``: echoes of the inputs.
        Invalid input yields a 400 JSONResponse with an ``error`` key.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})

    # Reverse channel order (RGB -> BGR). Presumably this is the order the
    # Detectron2 predictor and the cv2 drawing code expect — confirm against
    # cfg.INPUT.FORMAT.
    im = np.array(im_pil)[:, :, ::-1].copy()

    kind = rooftop_type.strip().lower()
    if kind == "rectangular":
        predictor = predictor_rect
        model_used = "model_rect_final.pth"
    elif kind == "irregular":
        predictor = predictor_irregular_flat
        model_used = "model_irregular_flat.pth"
    else:
        return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})

    # OpenCV's approxPolyDP rejects negative or non-finite tolerances. Without
    # this check such input surfaced as an unhandled 500. The bad value is not
    # echoed back because NaN/inf are not valid JSON.
    if not np.isfinite(epsilon) or epsilon < 0:
        return JSONResponse(status_code=400, content={"error": "Invalid epsilon. Must be a finite, non-negative number."})

    # The predictor call is synchronous and CPU-heavy. Running it in the
    # threadpool keeps the event loop free to serve other requests meanwhile.
    outputs = await run_in_threadpool(predictor, im)
    instances = outputs["instances"].to("cpu")

    if len(instances) == 0:
        return {"polygon": None, "image": None, "model_used": model_used, "rooftop_type": rooftop_type, "epsilon": epsilon}

    # Keep only the single highest-scoring detection.
    idx = int(instances.scores.argmax().item())
    raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)

    if kind == "rectangular":
        # postprocess_rect returns a filled mask (or None when the mask has no
        # contour), so extract its outline. The None guard avoids passing None
        # into cv2.findContours.
        simplified = postprocess_rect(raw_mask, epsilon)
        polygon = mask_to_polygon(simplified) if simplified is not None else None
    else:
        # postprocess_irregular already returns polygon points (or None).
        polygon = postprocess_irregular(raw_mask, epsilon)

    overlay = overlay_polygon(im, polygon)
    img_b64 = im_to_b64_png(overlay)

    return {
        "polygon": polygon.tolist() if polygon is not None else None,
        "image": img_b64,
        "model_used": model_used,
        "rooftop_type": rooftop_type,
        "epsilon": epsilon,
    }
|
|
|