import base64
import io
import os  # needed for the local-file fallback in load_model()

import cv2
import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from PIL import Image
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
DEVICE = "cpu"
IMG_SIZE = 256
CONFIDENCE_THRESHOLD = 0.298

# --- DATA MODELS FOR API ---
class InferenceRequest(BaseModel):
    image: str           # input photo as a base64 data URL
    scribble_mask: str   # user scribbles as a base64 data URL

class InferenceResponse(BaseModel):
    predicted_mask: str  # predicted mask as a base64 data URL

# --- INITIALIZE FASTAPI APP ---
app = FastAPI()
# --- LOAD MODEL ON STARTUP ---
def load_model():
    print(f"Loading model '{HF_MODEL_REPO_ID}'...")
    try:
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
    except Exception as e:
        # Fallback for local testing if the model file is in the same directory
        if os.path.exists("best_model.pth"):
            print("Could not download from Hub, using local 'best_model.pth'.")
            model_path = "best_model.pth"
        else:
            raise e

    # U-Net with a MobileNetV2 encoder: 4 input channels (RGB + scribble), 1 output class
    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=4,
        classes=1,
    )
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model

model = load_model()
# --- HELPER FUNCTIONS ---
def base64_to_cv2_rgba(base64_string: str):
    # Drop the "data:image/...;base64," header, then decode
    header, encoded = base64_string.split(",", 1)
    img_data = base64.b64decode(encoded)
    # Force RGBA so cvtColor always receives 4 channels, even for RGB/JPEG inputs
    pil_image = Image.open(io.BytesIO(img_data)).convert("RGBA")
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGRA)

def cv2_to_base64(image: np.ndarray):
    _, buffer = cv2.imencode('.png', image)
    png_as_text = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{png_as_text}"
# --- API ENDPOINTS ---
# NOTE: the route decorators were missing from the listing; "/" and "/predict"
# are assumed paths reconstructed from the handler names.
@app.get("/")
def read_root():
    return FileResponse('index.html')

@app.post("/predict", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
    image_cv = base64_to_cv2_rgba(request.image)
    scribble_cv = base64_to_cv2_rgba(request.scribble_mask)

    # Collapse the scribble overlay to a single grayscale channel
    if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
        scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)

    # Resize both inputs to the model's training resolution
    h, w, _ = image_cv.shape
    image_resized = cv2.resize(cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    scribble_resized = cv2.resize(scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

    # Stack RGB + scribble into a 4-channel tensor, normalized to [0, 1]
    image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.sigmoid(output)
        binary_mask = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()

    # Scale the mask back to the original image size and return it as a data URL
    output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)
    result_base64 = cv2_to_base64(output_mask_uint8)
    return InferenceResponse(predicted_mask=result_base64)
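# --- EXAMPLE CLIENT (illustrative sketch, not part of the original listing) ---
# One way to call the service from Python. The base URL, the /predict path
# restored above, and the file names "leaf.png" / "scribble.png" are assumptions.
def run_example_client():
    import base64
    import requests

    def to_data_url(path: str) -> str:
        # Encode a local PNG as the data-URL string the API expects
        with open(path, "rb") as f:
            return "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "image": to_data_url("leaf.png"),
        "scribble_mask": to_data_url("scribble.png"),
    }
    resp = requests.post("http://localhost:8000/predict", json=payload)
    resp.raise_for_status()

    # The response carries the mask as a data URL; strip the header and save the PNG
    mask_b64 = resp.json()["predicted_mask"].split(",", 1)[1]
    with open("predicted_mask.png", "wb") as f:
        f.write(base64.b64decode(mask_b64))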