Subh775 committed on
Commit
03278d0
·
verified ·
1 Parent(s): b759bd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -21
app.py CHANGED
@@ -14,20 +14,20 @@ from huggingface_hub import hf_hub_download
14
  HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
15
  DEVICE = "cpu"
16
  IMG_SIZE = 256
 
17
 
18
- # --- DATA MODELS FOR API (using Pydantic) ---
19
  class InferenceRequest(BaseModel):
20
  image: str # base64 encoded image string
21
  scribble_mask: str # base64 encoded scribble mask string
22
 
23
  class InferenceResponse(BaseModel):
24
- predicted_mask: str # base64 encoded predicted mask string
25
 
26
  # --- INITIALIZE FASTAPI APP ---
27
  app = FastAPI()
28
 
29
  # --- LOAD MODEL ON STARTUP ---
30
- # The model is loaded once when the application starts to ensure fast inference times.
31
  def load_model():
32
  print(f"Loading model '{HF_MODEL_REPO_ID}'...")
33
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
@@ -47,62 +47,48 @@ def load_model():
47
  model = load_model()
48
 
49
  # --- HELPER FUNCTIONS ---
50
- def base64_to_cv2(base64_string: str):
51
- # Remove the "data:image/..." header
52
  header, encoded = base64_string.split(",", 1)
53
  img_data = base64.b64decode(encoded)
54
-
55
- # Use Pillow to open the image data and convert to OpenCV format
56
  pil_image = Image.open(io.BytesIO(img_data))
57
  return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGRA)
58
 
59
  def cv2_to_base64(image: np.ndarray):
60
- # Convert image back to a base64 string to send to the frontend
61
  _, buffer = cv2.imencode('.png', image)
62
  png_as_text = base64.b64encode(buffer).decode('utf-8')
63
  return f"data:image/png;base64,{png_as_text}"
64
 
65
-
66
  # --- API ENDPOINTS ---
67
  @app.get("/")
68
  def read_root():
69
- # Serve the frontend HTML file
70
  return FileResponse('index.html')
71
 
72
  @app.post("/predict", response_model=InferenceResponse)
73
  async def predict(request: InferenceRequest):
74
- # 1. Decode input data
75
- image_cv = base64_to_cv2(request.image)
76
- scribble_cv = base64_to_cv2(request.scribble_mask)
77
 
78
- # Ensure scribble is grayscale
79
  if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
80
  scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)
81
 
82
  h, w, _ = image_cv.shape
83
 
84
- # 2. Preprocess the data for the model
85
  image_resized = cv2.resize(cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
86
  scribble_resized = cv2.resize(scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
87
 
88
  image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
89
  scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
90
-
91
  input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
92
 
93
- # 3. Run Inference
94
  with torch.no_grad():
95
  output = model(input_tensor)
96
 
97
- # 4. Post-process the output
98
  probs = torch.sigmoid(output)
99
- binary_mask = (probs > 0.5).float().squeeze().cpu().numpy()
100
 
101
- # Resize mask to the original input canvas size
102
  output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
103
  output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)
104
 
105
- # 5. Encode the result and return
106
  result_base64 = cv2_to_base64(output_mask_uint8)
107
 
108
  return InferenceResponse(predicted_mask=result_base64)
 
14
  HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
15
  DEVICE = "cpu"
16
  IMG_SIZE = 256
17
+ CONFIDENCE_THRESHOLD = 0.5
18
 
19
+ # --- DATA MODELS FOR API ---
class InferenceRequest(BaseModel):
    # Request payload: two images, each transported as a base64 string.
    # Kept as `#` comments (not a docstring) so the generated schema
    # description is unchanged.

    # base64 encoded image string
    image: str
    # base64 encoded scribble mask string
    scribble_mask: str
23
 
class InferenceResponse(BaseModel):
    # Response payload carrying the predicted mask.
    # Kept as `#` comments (not a docstring) so the generated schema
    # description is unchanged.

    # base64 encoded raw binary mask string
    predicted_mask: str
26
 
27
  # --- INITIALIZE FASTAPI APP ---
28
  app = FastAPI()
29
 
30
  # --- LOAD MODEL ON STARTUP ---
 
31
  def load_model():
32
  print(f"Loading model '{HF_MODEL_REPO_ID}'...")
33
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
 
47
  model = load_model()
48
 
49
  # --- HELPER FUNCTIONS ---
def base64_to_cv2_rgba(base64_string: str):
    """Decode a base64 image string into a 4-channel BGRA OpenCV array.

    Args:
        base64_string: Either a data URL ("data:image/...;base64,<payload>")
            or the bare base64 payload without any header.

    Returns:
        A NumPy array holding the image with four channels in BGRA order.

    Raises:
        binascii.Error: If the payload is not valid base64 (this is a
            ValueError subclass).
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            Pillow can open.
    """
    # partition() never fails to unpack, so a payload sent without the
    # "data:image/...;base64," header is accepted instead of raising.
    _header, separator, payload = base64_string.partition(",")
    encoded = payload if separator else base64_string
    img_data = base64.b64decode(encoded)

    pil_image = Image.open(io.BytesIO(img_data))
    # Normalise to RGBA first: JPEG, greyscale and palette images decode to
    # RGB / L / P, and COLOR_RGBA2BGRA only accepts 4-channel input.
    rgba_pixels = np.array(pil_image.convert("RGBA"))
    return cv2.cvtColor(rgba_pixels, cv2.COLOR_RGBA2BGRA)
55
 
def cv2_to_base64(image: np.ndarray):
    """Encode an OpenCV image as a PNG data URL.

    Args:
        image: The image array to encode.

    Returns:
        A string of the form "data:image/png;base64,<payload>".
    """
    # imencode returns (success_flag, encoded_buffer); only the buffer is used.
    _success, png_buffer = cv2.imencode('.png', image)
    payload = base64.b64encode(png_buffer).decode('utf-8')
    return "data:image/png;base64," + payload
60
 
 
61
  # --- API ENDPOINTS ---
@app.get("/")
def read_root():
    # Root route: hand back the frontend page stored next to the app.
    # (A `#` comment is used so the route's OpenAPI description is unchanged.)
    index_page = FileResponse('index.html')
    return index_page
65
 
@app.post("/predict", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
    # Pipeline: decode both inputs -> resize to the model resolution ->
    # stack image + scribble as one tensor -> run the model -> threshold ->
    # resize the mask back to the input size -> return it as a PNG data URL.
    # (`#` comments are used so the route's OpenAPI description is unchanged.)
    image_cv = base64_to_cv2_rgba(request.image)
    scribble_cv = base64_to_cv2_rgba(request.scribble_mask)

    # The scribble is fed to the model as a single channel.
    if scribble_cv.ndim > 2 and scribble_cv.shape[2] > 1:
        scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)

    # Remember the incoming size so the mask can be scaled back to it.
    h, w, _ = image_cv.shape

    model_size = (IMG_SIZE, IMG_SIZE)
    image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB)
    image_resized = cv2.resize(image_rgb, model_size, interpolation=cv2.INTER_AREA)
    # Nearest-neighbour keeps scribble pixels at their original values.
    scribble_resized = cv2.resize(scribble_cv, model_size, interpolation=cv2.INTER_NEAREST)

    # Channels-first float tensors scaled to [0, 1].
    image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0

    # Image channels followed by the scribble channel, plus a batch dimension.
    stacked = torch.cat([image_tensor, scribble_tensor], dim=0)
    input_tensor = stacked.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)

    probs = torch.sigmoid(output)
    above_threshold = probs > CONFIDENCE_THRESHOLD
    binary_mask = above_threshold.float().squeeze().cpu().numpy()

    # cv2.resize takes (width, height); scale 0/1 up to 0/255 for PNG output.
    output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)

    return InferenceResponse(predicted_mask=cv2_to_base64(output_mask_uint8))