import base64
import io
import os
import cv2
import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from PIL import Image
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
DEVICE = "cpu"
IMG_SIZE = 256
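# Probability cutoff for binarizing the sigmoid output (presumably a tuned,
# model-specific value; treat it as part of the checkpoint's configuration).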
CONFIDENCE_THRESHOLD = 0.298
# --- DATA MODELS FOR API ---
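# Both request fields and the response field carry images as base64 data URLs
# (e.g. "data:image/png;base64,...").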
class InferenceRequest(BaseModel):
    image: str
    scribble_mask: str

class InferenceResponse(BaseModel):
    predicted_mask: str
# --- INITIALIZE FASTAPI APP ---
app = FastAPI()
# --- LOAD MODEL ON STARTUP ---
def load_model():
    """Fetch the checkpoint from the Hub (or a local copy) and build the U-Net."""
    print(f"Loading model '{HF_MODEL_REPO_ID}'...")
    try:
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
    except Exception as e:
        # Fallback for local testing if the model file is in the same directory
        if os.path.exists("best_model.pth"):
            print("Could not download from Hub, using local 'best_model.pth'.")
            model_path = "best_model.pth"
        else:
            raise e
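    # The U-Net consumes a 4-channel input: 3 RGB channels plus the user's
    # scribble mask as a guidance channel, and predicts a single leaf class.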
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=4,
        classes=1,
    )
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model
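
# Load the model once at import time so every request reuses the same instance.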
model = load_model()
# --- HELPER FUNCTIONS ---
def base64_to_cv2_rgba(base64_string: str) -> np.ndarray:
    """Decode a base64 data URL into a BGRA OpenCV image."""
    header, encoded = base64_string.split(",", 1)
    img_data = base64.b64decode(encoded)
    # Force RGBA so the cv2 color conversion below always sees 4 channels,
    # even for RGB or palette-mode uploads.
    pil_image = Image.open(io.BytesIO(img_data)).convert("RGBA")
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGRA)
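
# Note: base64_to_cv2_rgba expects a full data URL ("data:image/png;base64,<data>");
# a bare base64 string without the comma-delimited header will fail to unpack.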
def cv2_to_base64(image: np.ndarray) -> str:
    """Encode a single-channel or BGR(A) image as a PNG data URL."""
    _, buffer = cv2.imencode('.png', image)
    png_as_text = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{png_as_text}"
# --- API ENDPOINTS ---
@app.get("/")
def read_root():
return FileResponse('index.html')
@app.post("/predict", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
image_cv = base64_to_cv2_rgba(request.image)
scribble_cv = base64_to_cv2_rgba(request.scribble_mask)
if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)
h, w, _ = image_cv.shape
image_resized = cv2.resize(cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
scribble_resized = cv2.resize(scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
with torch.no_grad():
output = model(input_tensor)
probs = torch.sigmoid(output)
binary_mask = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)
result_base64 = cv2_to_base64(output_mask_uint8)
return InferenceResponse(predicted_mask=result_base64)
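
# Example client call (a sketch; assumes the API is served locally with
# `uvicorn app:app --port 8000`, and "leaf.png" / "scribble.png" are
# hypothetical files on disk):
#
#   import base64, requests
#
#   def to_data_url(path):
#       with open(path, "rb") as f:
#           return "data:image/png;base64," + base64.b64encode(f.read()).decode()
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"image": to_data_url("leaf.png"),
#             "scribble_mask": to_data_url("scribble.png")},
#   )
#   print(resp.json()["predicted_mask"][:50])  # "data:image/png;base64,..."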