File size: 3,382 Bytes
9608158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79a53d
9608158
03278d0
9608158
b79a53d
 
9608158
 
b79a53d
9608158
 
 
 
 
 
 
b79a53d
 
 
 
 
 
 
 
 
 
9608158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03278d0
b79a53d
 
 
 
9608158
 
b79a53d
 
 
9608158
 
 
 
 
 
 
 
b79a53d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import base64
import io
import os

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel

# --- CONFIGURATION ---
# Hugging Face Hub repository the checkpoint file is downloaded from.
HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
# Torch device the weights are mapped to and the input tensor is moved to.
DEVICE = "cpu"
# Side length, in pixels, of the square that both the image and the scribble
# are resized to before being fed to the network.
IMG_SIZE = 256
# Sigmoid outputs strictly greater than this become foreground in the mask.
CONFIDENCE_THRESHOLD = 0.298

# --- DATA MODELS FOR API ---
class InferenceRequest(BaseModel):
    # Request body accepted by POST /predict.

    # Base64-encoded image to segment, passed to base64_to_cv2_rgba.
    image: str
    # Base64-encoded scribble image; decoded the same way as `image` and
    # reduced to a single channel before inference.
    scribble_mask: str

class InferenceResponse(BaseModel):
    # Response body returned by POST /predict.

    # PNG mask as a "data:image/png;base64,..." string (see cv2_to_base64).
    predicted_mask: str

# --- INITIALIZE FASTAPI APP ---
# Application object; the route decorators further down register on it.
app = FastAPI()

# --- LOAD MODEL ON STARTUP ---
def load_model():
    """Build the segmentation network and load its trained weights.

    The checkpoint ``best_model.pth`` is fetched from the Hugging Face Hub
    repository named by ``HF_MODEL_REPO_ID``. If the download fails and a file
    named ``best_model.pth`` exists in the current working directory, that
    file is used instead; otherwise the original download error propagates.

    Returns:
        An ``smp.Unet`` (``mobilenet_v2`` encoder, 4 input channels, 1 output
        class) with the checkpoint loaded, moved to ``DEVICE`` and switched
        to evaluation mode.

    Raises:
        Exception: Whatever ``hf_hub_download`` raised, when no local
            fallback checkpoint is present.
    """
    print(f"Loading model '{HF_MODEL_REPO_ID}'...")
    try:
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
    except Exception:
        # Fallback for local testing if the model file is in the same directory.
        if not os.path.exists("best_model.pth"):
            # Bare raise keeps the original traceback of the download failure.
            raise
        print("Could not download from Hub, using local 'best_model.pth'.")
        model_path = "best_model.pth"

    model = smp.Unet(
        encoder_name="mobilenet_v2",
        # Weights come from the checkpoint below, so no pretrained encoder
        # weights are requested here.
        encoder_weights=None,
        in_channels=4,
        classes=1,
    )
    # NOTE(review): torch.load unpickles the checkpoint. If the installed torch
    # version supports it, consider passing weights_only=True since the file
    # originates from a remote repository.
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model

# Runs at import time: the weights are loaded once, and this single instance
# is the `model` that the /predict handler calls.
model = load_model()

# --- HELPER FUNCTIONS ---
def base64_to_cv2_rgba(base64_string: str):
    """Decode a base64 image string into a 4-channel BGRA OpenCV array.

    Args:
        base64_string: Either a data URI such as
            ``data:image/png;base64,<payload>`` or the bare base64 payload.

    Returns:
        A ``numpy.ndarray`` with four channels in BGRA order.

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error`` is
            a ``ValueError`` subclass).
        OSError: If the decoded bytes are not an image PIL can open.
    """
    # Everything up to the first comma is the data-URI header. A string with
    # no comma is treated as the bare payload instead of raising on unpacking.
    _, separator, payload = base64_string.partition(",")
    if not separator:
        payload = base64_string

    img_data = base64.b64decode(payload)
    pil_image = Image.open(io.BytesIO(img_data))

    # cv2.COLOR_RGBA2BGRA needs a 4-channel input, but JPEG, palette and
    # greyscale images decode to fewer channels; normalise the mode first.
    rgba = np.array(pil_image.convert("RGBA"))
    return cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)

def cv2_to_base64(image: np.ndarray):
    """Encode an OpenCV image as a PNG data URI.

    Args:
        image: Array accepted by ``cv2.imencode``.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed.
    """
    encoded_ok, buffer = cv2.imencode('.png', image)
    if not encoded_ok:
        # Previously the flag was discarded and whatever was in the buffer
        # would have been returned as if it were a valid PNG.
        raise ValueError("cv2.imencode failed to encode the image as PNG")
    png_as_text = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{png_as_text}"

# --- API ENDPOINTS ---
@app.get("/")
def read_root():
    # Serve the front-end page; the path is relative to the process's current
    # working directory.
    landing_page = FileResponse('index.html')
    return landing_page

@app.post("/predict", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
    # Segment `request.image` guided by `request.scribble_mask` and return the
    # binary mask, at the image's original resolution, as a PNG data URI.
    #
    # Responds with HTTP 400 when either field cannot be decoded into an image.
    #
    # NOTE(review): inference runs synchronously inside this coroutine, so the
    # event loop is blocked for the duration of each request.
    try:
        image_cv = base64_to_cv2_rgba(request.image)
        scribble_cv = base64_to_cv2_rgba(request.scribble_mask)
    except (ValueError, OSError, cv2.error) as exc:
        # Bad base64, non-image bytes or an unsupported pixel layout are
        # client errors and should not surface as an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="'image' and 'scribble_mask' must be base64-encoded images.",
        ) from exc

    # The scribble is fed to the network as a single channel.
    if scribble_cv.ndim > 2 and scribble_cv.shape[2] > 1:
        scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)

    # Remember the original size so the mask can be scaled back to it.
    h, w = image_cv.shape[:2]

    image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB)
    image_resized = cv2.resize(image_rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    # Nearest-neighbour keeps scribble values from being blended when resized.
    scribble_resized = cv2.resize(scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

    # HWC uint8 -> CHW float in [0, 1]; the scribble becomes the 4th channel.
    image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
    input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        probs = torch.sigmoid(model(input_tensor))

    # squeeze() drops the singleton batch and class dimensions.
    binary_mask = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()

    output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)

    return InferenceResponse(predicted_mask=cv2_to_base64(output_mask_uint8))