Subh775 committed · verified
Commit b79a53d · Parent(s): d77a368

Update app.py

Files changed (1):
  app.py +48 -92
app.py CHANGED
@@ -14,15 +14,15 @@ from huggingface_hub import hf_hub_download
 HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
 DEVICE = "cpu"
 IMG_SIZE = 256
-CONFIDENCE_THRESHOLD = 0.5
+CONFIDENCE_THRESHOLD = 0.298
 
 # --- DATA MODELS FOR API ---
 class InferenceRequest(BaseModel):
-    image: str  # base64 encoded image string
-    scribble_mask: str  # base64 encoded scribble mask string
+    image: str
+    scribble_mask: str
 
 class InferenceResponse(BaseModel):
-    predicted_mask: str  # base64 encoded raw binary mask string
+    predicted_mask: str
 
 # --- INITIALIZE FASTAPI APP ---
 app = FastAPI()
@@ -30,8 +30,16 @@ app = FastAPI()
 # --- LOAD MODEL ON STARTUP ---
 def load_model():
     print(f"Loading model '{HF_MODEL_REPO_ID}'...")
-    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
-
+    try:
+        model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")
+    except Exception as e:
+        # Fallback for local testing if the model file is in the same directory
+        if os.path.exists("best_model.pth"):
+            print("Could not download from Hub, using local 'best_model.pth'.")
+            model_path = "best_model.pth"
+        else:
+            raise e
+
     model = smp.Unet(
         encoder_name="mobilenet_v2",
         encoder_weights=None,
@@ -48,35 +56,15 @@ model = load_model()
 
 # --- HELPER FUNCTIONS ---
 def base64_to_cv2_rgba(base64_string: str):
-    try:
-        # Handle data URL format
-        if "," in base64_string:
-            header, encoded = base64_string.split(",", 1)
-        else:
-            encoded = base64_string
-
-        img_data = base64.b64decode(encoded)
-        pil_image = Image.open(io.BytesIO(img_data))
-
-        # Convert to RGBA if not already
-        if pil_image.mode != 'RGBA':
-            pil_image = pil_image.convert('RGBA')
-
-        # Convert PIL to numpy array and then to OpenCV format
-        np_array = np.array(pil_image)
-        return cv2.cvtColor(np_array, cv2.COLOR_RGBA2BGRA)
-    except Exception as e:
-        print(f"Error in base64_to_cv2_rgba: {e}")
-        raise
+    header, encoded = base64_string.split(",", 1)
+    img_data = base64.b64decode(encoded)
+    pil_image = Image.open(io.BytesIO(img_data))
+    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGRA)
 
 def cv2_to_base64(image: np.ndarray):
-    try:
-        _, buffer = cv2.imencode('.png', image)
-        png_as_text = base64.b64encode(buffer).decode('utf-8')
-        return f"data:image/png;base64,{png_as_text}"
-    except Exception as e:
-        print(f"Error in cv2_to_base64: {e}")
-        raise
+    _, buffer = cv2.imencode('.png', image)
+    png_as_text = base64.b64encode(buffer).decode('utf-8')
+    return f"data:image/png;base64,{png_as_text}"
 
 # --- API ENDPOINTS ---
 @app.get("/")
@@ -85,62 +73,30 @@ def read_root():
 
 @app.post("/predict", response_model=InferenceResponse)
 async def predict(request: InferenceRequest):
-    try:
-        # Convert base64 images to OpenCV format
-        image_cv = base64_to_cv2_rgba(request.image)
-        scribble_cv = base64_to_cv2_rgba(request.scribble_mask)
-
-        # Convert scribble mask to grayscale if it has multiple channels
-        if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
-            scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)
-
-        # Get original dimensions - FIXED SYNTAX ERROR
-        h, w, *_ = image_cv.shape
-
-        # Resize images to model input size
-        image_resized = cv2.resize(
-            cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB),
-            (IMG_SIZE, IMG_SIZE),
-            interpolation=cv2.INTER_AREA
-        )
-        scribble_resized = cv2.resize(
-            scribble_cv,
-            (IMG_SIZE, IMG_SIZE),
-            interpolation=cv2.INTER_NEAREST
-        )
-
-        # Convert to tensors and normalize
-        image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
-        scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
-
-        # Concatenate image and scribble mask as 4-channel input
-        input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
-
-        # Run inference
-        with torch.no_grad():
-            output = model(input_tensor)
-
-        # Post-process output
-        probs = torch.sigmoid(output)
-        binary_mask = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
-
-        # Resize output mask back to original image dimensions
-        output_mask_resized = cv2.resize(
-            binary_mask,
-            (w, h),
-            interpolation=cv2.INTER_NEAREST
-        )
-
-        # Convert to uint8 format
-        output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)
-
-        # Convert to base64 for response
-        result_base64 = cv2_to_base64(output_mask_uint8)
-
-        return InferenceResponse(predicted_mask=result_base64)
-
-    except Exception as e:
-        print(f"Error in predict endpoint: {e}")
-        import traceback
-        traceback.print_exc()
-        raise
+    image_cv = base64_to_cv2_rgba(request.image)
+    scribble_cv = base64_to_cv2_rgba(request.scribble_mask)
+
+    if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
+        scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)
+
+    h, w, _ = image_cv.shape
+
+    image_resized = cv2.resize(cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
+    scribble_resized = cv2.resize(scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
+
+    image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
+    scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
+    input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
+
+    with torch.no_grad():
+        output = model(input_tensor)
+
+    probs = torch.sigmoid(output)
+    binary_mask = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
+
+    output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
+    output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)
+
+    result_base64 = cv2_to_base64(output_mask_uint8)
+
+    return InferenceResponse(predicted_mask=result_base64)
 
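For reference, here is a minimal client sketch for the `/predict` endpoint as it stands after this commit. The server URL, file names, and the use of the `requests` library are assumptions for illustration; the field names (`image`, `scribble_mask`, `predicted_mask`) and the data-URL base64 format follow the InferenceRequest/InferenceResponse models in the diff. Note that the new `base64_to_cv2_rgba()` splits on the first comma unconditionally, so inputs must be sent as data URLs.

import base64
import requests  # assumed HTTP client; any equivalent works

API_URL = "http://localhost:7860"  # hypothetical; substitute the deployed Space URL

def to_data_url(path: str) -> str:
    # Encode a PNG file as the data-URL string the server expects.
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"

payload = {
    "image": to_data_url("leaf.png"),              # hypothetical input image
    "scribble_mask": to_data_url("scribble.png"),  # hypothetical scribble mask
}
resp = requests.post(f"{API_URL}/predict", json=payload, timeout=120)
resp.raise_for_status()

# The response mask is itself a data URL; strip the header and decode to PNG bytes.
mask_b64 = resp.json()["predicted_mask"].split(",", 1)[1]
with open("predicted_mask.png", "wb") as f:
    f.write(base64.b64decode(mask_b64))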