Chayanat committed on
Commit 1c9b95e · verified · 1 Parent(s): b4e1268

Update app.py

Files changed (1)
  1. app.py +63 -272
app.py CHANGED
@@ -108,8 +108,8 @@ def drawOnTop(img, landmarks, original_shape):
     image = cv2.line(image, (int(rl_top[0]), int(rl_top[1])), (int(ll_top[0]), int(ll_top[1])), (0, 1, 0), 1)
 
     # Add tilt angle text
-    tilt_text = f"Tilt: {tilt_angle:.1f}°"
-    cv2.putText(image, tilt_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 1, 0), 2)
+    tilt_text = f"Tilt: {tilt_angle:.1f} degrees"
+    cv2.putText(image, tilt_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 1, 0), 2)
 
     # Correct landmarks for tilt
     if abs(tilt_angle) > 2:  # Only correct if tilt is significant
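The overlay colours in this function, e.g. `(0, 1, 0)`, are in the 0-1 range because `drawOnTop` receives an image already normalised by `/ 255.0` (see the `segment` hunk further down). A minimal, self-contained sketch of that convention, using a dummy canvas rather than the app's real image:

```python
import cv2
import numpy as np

# Hypothetical standalone example: OpenCV drawing on a float image in [0, 1].
canvas = np.zeros((256, 256, 3), dtype=np.float32)        # normalised float canvas
cv2.putText(canvas, "Tilt: 3.2 degrees", (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 1, 0), 2)  # colours also in [0, 1]
cv2.imwrite("overlay_demo.png", (canvas * 255).astype("uint8"))  # rescale before saving
```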
@@ -120,51 +120,24 @@ def drawOnTop(img, landmarks, original_shape):
     else:
         RL_corrected, LL_corrected, H_corrected = RL, LL, H
 
-    # Get interpolated points for more precise measurements (MOVED UP)
-    # Reduced for HuggingFace memory constraints
-    H_dense = interpolate_contour_points(H_corrected, num_points=150)
-    RL_dense = interpolate_contour_points(RL_corrected, num_points=100)
-    LL_dense = interpolate_contour_points(LL_corrected, num_points=100)
-
-    # Heart (red line) - calculate positions from corrected coordinates using precise method
-    heart_left, heart_right = find_precise_extremes(H_dense, direction='horizontal')
-    heart_y_corrected = np.mean([heart_left[1], heart_right[1]])
-
-    # Thorax (blue line) - calculate positions from corrected coordinates using precise method
-    rl_left, rl_right = find_precise_extremes(RL_dense, direction='horizontal')
-    ll_left, ll_right = find_precise_extremes(LL_dense, direction='horizontal')
-
-    # Get the overall leftmost and rightmost points
-    thorax_left = rl_left if rl_left[0] < ll_left[0] else ll_left
-    thorax_right = rl_right if rl_right[0] > ll_right[0] else ll_right
-    thorax_y_corrected = np.mean([thorax_left[1], thorax_right[1]])
-
-    # Add precision info (NOW AFTER CREATING THE VARIABLES)
-    precision_text = f"Enhanced: {len(H_dense)} heart pts, {len(RL_dense)+len(LL_dense)} lung pts"
-    cv2.putText(image, precision_text, (10, h-10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (1, 1, 1), 1)
-
     # Draw the landmarks as dots
     for l in RL:
-        image = cv2.circle(image, (int(l[0]), int(l[1])), 3, (1, 0, 1), -1)
+        image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 0, 1), -1)
     for l in LL:
-        image = cv2.circle(image, (int(l[0]), int(l[1])), 3, (1, 0, 1), -1)
+        image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 0, 1), -1)
     for l in H:
-        image = cv2.circle(image, (int(l[0]), int(l[1])), 3, (1, 1, 0), -1)
-
-    # Highlight the precise extreme points used for measurement
-    # Heart extreme points (larger red circles)
-    image = cv2.circle(image, (int(heart_left[0]), int(heart_left[1])), 8, (1, 0, 0), -1)
-    image = cv2.circle(image, (int(heart_right[0]), int(heart_right[1])), 8, (1, 0, 0), -1)
-
-    # Thorax extreme points (larger blue circles)
-    image = cv2.circle(image, (int(thorax_left[0]), int(thorax_left[1])), 8, (0, 0, 1), -1)
-    image = cv2.circle(image, (int(thorax_right[0]), int(thorax_right[1])), 8, (0, 0, 1), -1)
+        image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 1, 0), -1)
 
     # Draw measurement lines that follow the image tilt for visual accuracy
     # Use corrected coordinates for accurate measurement, but draw tilted lines for visual appeal
 
+    # Heart (red line) - calculate positions from corrected coordinates
+    heart_xmin_corrected = np.min(H_corrected[:, 0])
+    heart_xmax_corrected = np.max(H_corrected[:, 0])
+    heart_y_corrected = np.mean([H_corrected[np.argmin(H_corrected[:, 0]), 1], H_corrected[np.argmax(H_corrected[:, 0]), 1]])
+
     # Rotate back to match the tilted image for display
-    heart_points_corrected = np.array([[heart_left[0], heart_y_corrected], [heart_right[0], heart_y_corrected]])
+    heart_points_corrected = np.array([[heart_xmin_corrected, heart_y_corrected], [heart_xmax_corrected, heart_y_corrected]])
     heart_points_display = rotate_points(heart_points_corrected, -tilt_angle, image_center)  # Rotate back for display
 
     heart_start = (int(heart_points_display[0, 0]), int(heart_points_display[0, 1]))
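Both measurement lines lean on a `rotate_points(points, angle, center)` helper that this diff does not show. Presumably it applies a plain 2D rotation about `center`; a sketch under that assumption (not the repository's actual implementation, and the sign or degree conventions may differ):

```python
import numpy as np

def rotate_points(points, angle_deg, center):
    """Rotate an (N, 2) array of (x, y) points by angle_deg around center.

    Sketch only: the real rotate_points in app.py may use a different
    sign convention or take the angle in radians.
    """
    theta = np.deg2rad(angle_deg)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    shifted = np.asarray(points, dtype=float) - np.asarray(center, dtype=float)
    return shifted @ rot.T + np.asarray(center, dtype=float)
```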
@@ -191,9 +164,24 @@ def drawOnTop(img, landmarks, original_shape):
                      (int(heart_end[0] + perp_x), int(heart_end[1] + perp_y)),
                      (int(heart_end[0] - perp_x), int(heart_end[1] - perp_y)),
                      (1, 0, 0), 2)
+
+    # Thorax (blue line) - calculate positions from corrected coordinates
+    thorax_xmin_corrected = min(np.min(RL_corrected[:, 0]), np.min(LL_corrected[:, 0]))
+    thorax_xmax_corrected = max(np.max(RL_corrected[:, 0]), np.max(LL_corrected[:, 0]))
+
+    # Find y at leftmost and rightmost points (corrected)
+    if np.min(RL_corrected[:, 0]) < np.min(LL_corrected[:, 0]):
+        thorax_ymin_corrected = RL_corrected[np.argmin(RL_corrected[:, 0]), 1]
+    else:
+        thorax_ymin_corrected = LL_corrected[np.argmin(LL_corrected[:, 0]), 1]
+    if np.max(RL_corrected[:, 0]) > np.max(LL_corrected[:, 0]):
+        thorax_ymax_corrected = RL_corrected[np.argmax(RL_corrected[:, 0]), 1]
+    else:
+        thorax_ymax_corrected = LL_corrected[np.argmax(LL_corrected[:, 0]), 1]
+    thorax_y_corrected = np.mean([thorax_ymin_corrected, thorax_ymax_corrected])
 
     # Rotate back to match the tilted image for display
-    thorax_points_corrected = np.array([[thorax_left[0], thorax_y_corrected], [thorax_right[0], thorax_y_corrected]])
+    thorax_points_corrected = np.array([[thorax_xmin_corrected, thorax_y_corrected], [thorax_xmax_corrected, thorax_y_corrected]])
     thorax_points_display = rotate_points(thorax_points_corrected, -tilt_angle, image_center)  # Rotate back for display
 
     thorax_start = (int(thorax_points_display[0, 0]), int(thorax_points_display[0, 1]))
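The added thorax block selects the y-coordinate at the overall leftmost and rightmost lung points with explicit if/else branches. An equivalent, more compact formulation is sketched below, assuming `RL_corrected` and `LL_corrected` are (N, 2) arrays of (x, y) landmarks; the measurement line's y is then the mean of the two returned points' y values, matching `thorax_y_corrected` above.

```python
import numpy as np

def thorax_extremes(RL_corrected, LL_corrected):
    """Return (leftmost_point, rightmost_point) over both lung contours."""
    lungs = np.vstack([RL_corrected, LL_corrected])  # stack both contours, shape (N_RL + N_LL, 2)
    left = lungs[np.argmin(lungs[:, 0])]             # point with the smallest x
    right = lungs[np.argmax(lungs[:, 0])]            # point with the largest x
    return left, right
```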
@@ -305,162 +293,19 @@ def removePreprocess(output, info):
     return output
 
 
-def interpolate_contour_points(points, num_points=100):
-    """Interpolate additional points along a contour for more precise measurements"""
-    try:
-        if len(points) < 3:
-            return points
-
-        # Reduce memory usage for HuggingFace
-        num_points = min(num_points, 150)  # Limit max points
-
-        # Close the contour if not already closed
-        if not np.array_equal(points[0], points[-1]):
-            points = np.vstack([points, points[0]])
-
-        # Calculate cumulative distances along the contour
-        distances = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1))
-        cumulative_distances = np.concatenate([[0], np.cumsum(distances)])
-
-        # Create new evenly spaced parameter values
-        total_length = cumulative_distances[-1]
-        if total_length == 0:
-            return points
-
-        new_params = np.linspace(0, total_length, num_points)
-
-        # Interpolate x and y coordinates
-        new_x = np.interp(new_params, cumulative_distances, points[:, 0])
-        new_y = np.interp(new_params, cumulative_distances, points[:, 1])
-
-        result = np.column_stack([new_x, new_y]).astype(int)
-        return result
-    except Exception as e:
-        print(f"Error in interpolation: {e}")
-        return points
-
-def find_precise_extremes(points, direction='horizontal'):
-    """Find extreme points with sub-pixel precision using parabolic fitting"""
-    try:
-        if len(points) < 3:
-            return points[0], points[-1] if len(points) > 1 else points[0]
-
-        if direction == 'horizontal':
-            # Find leftmost and rightmost points
-            coord_idx = 0  # x-coordinate
-        else:
-            # Find topmost and bottommost points
-            coord_idx = 1  # y-coordinate
-
-        # Find approximate extremes
-        min_idx = np.argmin(points[:, coord_idx])
-        max_idx = np.argmax(points[:, coord_idx])
-
-        def refine_extreme(idx, is_maximum=False):
-            """Refine extreme point using neighboring points"""
-            try:
-                if len(points) < 3:
-                    return points[idx]
-
-                # Get neighboring indices (with wrapping)
-                prev_idx = (idx - 1) % len(points)
-                next_idx = (idx + 1) % len(points)
-
-                # Get the three points
-                p1 = points[prev_idx]
-                p2 = points[idx]
-                p3 = points[next_idx]
-
-                # Simple parabolic fitting in the coordinate direction
-                coords = np.array([p1[coord_idx], p2[coord_idx], p3[coord_idx]])
-                if is_maximum:
-                    best_idx = np.argmax(coords)
-                else:
-                    best_idx = np.argmin(coords)
-
-                if best_idx == 1:  # Original point is still the best
-                    return p2
-                elif best_idx == 0:
-                    return p1
-                else:
-                    return p3
-            except Exception as e:
-                print(f"Error refining extreme: {e}")
-                return points[idx]
-
-        # Refine the extreme points
-        min_point = refine_extreme(min_idx, is_maximum=False)
-        max_point = refine_extreme(max_idx, is_maximum=True)
-
-        return min_point, max_point
-    except Exception as e:
-        print(f"Error finding extremes: {e}")
-        # Fallback to simple min/max
-        if direction == 'horizontal':
-            coord_idx = 0
-        else:
-            coord_idx = 1
-        min_idx = np.argmin(points[:, coord_idx])
-        max_idx = np.argmax(points[:, coord_idx])
-        return points[min_idx], points[max_idx]
-
 def calculate_ctr(landmarks, corrected_landmarks=None):
-    try:
-        if corrected_landmarks is not None:
-            RL, LL, H, tilt_angle = corrected_landmarks
-        else:
-            H = landmarks[94:]
-            RL = landmarks[0:44]
-            LL = landmarks[44:94]
-            tilt_angle = 0
-
-        # Reduced points for HuggingFace
-        RL_dense = interpolate_contour_points(RL, num_points=100)
-        LL_dense = interpolate_contour_points(LL, num_points=100)
-        H_dense = interpolate_contour_points(H, num_points=150)
-
-        # Find precise extreme points for heart
-        heart_left, heart_right = find_precise_extremes(H_dense, direction='horizontal')
-        cardiac_width = heart_right[0] - heart_left[0]
-
-        # Find precise extreme points for thorax (combining both lungs)
-        rl_left, rl_right = find_precise_extremes(RL_dense, direction='horizontal')
-        ll_left, ll_right = find_precise_extremes(LL_dense, direction='horizontal')
-
-        # Get the overall leftmost and rightmost points
-        thorax_left = rl_left if rl_left[0] < ll_left[0] else ll_left
-        thorax_right = rl_right if rl_right[0] > ll_right[0] else ll_right
-        thoracic_width = thorax_right[0] - thorax_left[0]
-
-        # Calculate CTR with additional precision checks
-        if thoracic_width > 0 and cardiac_width > 0:
-            ctr = cardiac_width / thoracic_width
-
-            # Sanity check: CTR should be between 0.2 and 1.0
-            if ctr < 0.2 or ctr > 1.0:
-                # Fallback to simple calculation
-                cardiac_width_simple = np.max(H[:, 0]) - np.min(H[:, 0])
-                thoracic_width_simple = max(np.max(RL[:, 0]), np.max(LL[:, 0])) - min(np.min(RL[:, 0]), np.min(LL[:, 0]))
-                ctr = cardiac_width_simple / thoracic_width_simple if thoracic_width_simple > 0 else 0
-        else:
-            ctr = 0
-
-        return round(ctr, 3), abs(tilt_angle)
-    except Exception as e:
-        print(f"Error in CTR calculation: {e}")
-        # Simple fallback calculation
-        try:
-            H = landmarks[94:] if corrected_landmarks is None else corrected_landmarks[2]
-            RL = landmarks[0:44] if corrected_landmarks is None else corrected_landmarks[0]
-            LL = landmarks[44:94] if corrected_landmarks is None else corrected_landmarks[1]
-            tilt_angle = 0 if corrected_landmarks is None else corrected_landmarks[3]
-
-            cardiac_width = np.max(H[:, 0]) - np.min(H[:, 0])
-            thoracic_width = max(np.max(RL[:, 0]), np.max(LL[:, 0])) - min(np.min(RL[:, 0]), np.min(LL[:, 0]))
-            ctr = cardiac_width / thoracic_width if thoracic_width > 0 else 0
-            return round(ctr, 3), abs(tilt_angle)
-        except:
-            return 0.5, 0
+    if corrected_landmarks is not None:
+        RL, LL, H, tilt_angle = corrected_landmarks
+    else:
+        H = landmarks[94:]
+        RL = landmarks[0:44]
+        LL = landmarks[44:94]
+        tilt_angle = 0
+
+    cardiac_width = np.max(H[:, 0]) - np.min(H[:, 0])
+    thoracic_width = max(np.max(RL[:, 0]), np.max(LL[:, 0])) - min(np.min(RL[:, 0]), np.min(LL[:, 0]))
+    ctr = cardiac_width / thoracic_width if thoracic_width > 0 else 0
+    return round(ctr, 3), abs(tilt_angle)
 
 
 def detect_image_rotation(img):
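After this change `calculate_ctr` reduces to a bounding-box ratio over the landmark groups, with the lungs at indices 0-43 and 44-93 and the heart at 94 onward, exactly as sliced above. A small usage sketch with synthetic landmarks (coordinate values and the heart point count are illustrative assumptions):

```python
import numpy as np

# Synthetic landmark array laid out like the model output: one (x, y) row per point.
rng = np.random.default_rng(0)
RL = rng.uniform([100, 200], [300, 600], size=(44, 2))   # right lung contour (indices 0-43)
LL = rng.uniform([400, 200], [620, 600], size=(50, 2))   # left lung contour (indices 44-93)
H = rng.uniform([250, 350], [480, 550], size=(26, 2))    # heart contour (indices 94+)
landmarks = np.vstack([RL, LL, H]).astype(int)

ctr, tilt = calculate_ctr(landmarks)   # no corrected landmarks, so tilt comes back as 0
# cardiac width  = max(H[:, 0]) - min(H[:, 0])
# thoracic width = rightmost lung x - leftmost lung x, taken across both lungs
print(ctr, tilt)
```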
@@ -527,79 +372,45 @@ def segment(input_img):
     global hybrid, device
 
     try:
-        # Validate input
-        if input_img is None:
-            return None, None, 0, "Error: No image provided"
-
         if hybrid is None:
-            print("Loading model...")
             hybrid = loadModel(device)
-            print("Model loaded successfully")
 
-        # Load and validate image
-        original_img = cv2.imread(input_img, 0)
-        if original_img is None:
-            return None, None, 0, "Error: Could not load image"
-
-        original_img = original_img / 255.0
+        original_img = cv2.imread(input_img, 0) / 255.0
         original_shape = original_img.shape[:2]
-        print(f"Image loaded: {original_shape}")
 
         # Step 1: For now, skip rotation detection to avoid errors
+        # TODO: Re-implement rotation detection after fixing coordinate transformation
        detected_rotation = 0  # Temporarily disabled
        was_rotated = False
        processing_img = original_img
 
         # Step 2: Preprocess the image
-        print("Preprocessing image...")
        img, (h, w, padding) = preprocess(processing_img)
 
         # Step 3: AI segmentation
-        print("Running AI segmentation...")
        data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
 
        with torch.no_grad():
            output = hybrid(data)[0].cpu().numpy().reshape(-1, 2)
 
         # Step 4: Remove preprocessing
-        print("Post-processing...")
        output = removePreprocess(output, (h, w, padding))
 
         # Step 5: Convert output to int
        output = output.astype('int')
 
         # Step 6: Draw results on original image
-        print("Drawing results...")
        outseg, corrected_data = drawOnTop(original_img, output, original_shape)
 
-        if outseg is None:
-            return None, None, 0, "Error: Failed to generate output image"
-
     except Exception as e:
        print(f"Error in segmentation: {e}")
-        import traceback
-        traceback.print_exc()
+        # Return a basic error response
        return None, None, 0, f"Error: {str(e)}"
 
-    try:
-        # Save output image
-        seg_to_save = (outseg.copy() * 255).astype('uint8')
-
-        # Create tmp directory if it doesn't exist
-        import os
-        os.makedirs("tmp", exist_ok=True)
-
-        success = cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))
-        if not success:
-            print("Warning: Could not save output image")
-
-        print("Calculating CTR...")
-        ctr_value, tilt_angle = calculate_ctr(output, corrected_data)
-        print(f"CTR calculated: {ctr_value}")
-
-    except Exception as e:
-        print(f"Error in final processing: {e}")
-        return outseg, None, 0.5, f"Segmentation completed but error in final processing: {str(e)}"
+    seg_to_save = (outseg.copy() * 255).astype('uint8')
+    cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))
+
+    ctr_value, tilt_angle = calculate_ctr(output, corrected_data)
 
     # Add rotation info to interpretation
     rotation_warning = ""
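For reference, the inference steps above follow the usual shape flow: the preprocessed grayscale image becomes a (1, 1, H, W) float tensor, and the flattened network output is reshaped into an (N, 2) array of (x, y) landmarks. A stand-in sketch of that reshaping (the 120-landmark count is an assumption based on the 0:44 / 44:94 / 94: slices plus 26 heart points; the real call is `hybrid(data)`):

```python
import numpy as np
import torch

img = np.zeros((1024, 1024), dtype=np.float32)            # stand-in for the preprocessed image
data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)    # shape (1, 1, 1024, 1024): batch, channel

fake_output = torch.zeros(120 * 2)                        # stand-in for hybrid(data)[0]
landmarks = fake_output.cpu().numpy().reshape(-1, 2)      # shape (120, 2): one (x, y) row per landmark
print(data.shape, landmarks.shape)
```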
@@ -613,44 +424,34 @@ def segment(input_img):
     elif tilt_angle > 2:
         tilt_warning = f" (Minor tilt: {tilt_angle:.1f}°)"
 
-    # Enhanced interpretation with precision info (reduced for HuggingFace)
-    precision_note = " [Enhanced precision: 350+ interpolated points]"
-
     if ctr_value < 0.5:
-        interpretation = f"Normal{rotation_warning}{tilt_warning}{precision_note}"
+        interpretation = f"Normal{rotation_warning}{tilt_warning}"
     elif 0.51 <= ctr_value <= 0.55:
-        interpretation = f"Mild Cardiomegaly (CTR 51-55%){rotation_warning}{tilt_warning}{precision_note}"
+        interpretation = f"Mild Cardiomegaly (CTR 51-55%){rotation_warning}{tilt_warning}"
     elif 0.56 <= ctr_value <= 0.60:
-        interpretation = f"Moderate Cardiomegaly (CTR 56-60%){rotation_warning}{tilt_warning}{precision_note}"
+        interpretation = f"Moderate Cardiomegaly (CTR 56-60%){rotation_warning}{tilt_warning}"
     elif ctr_value > 0.60:
-        interpretation = f"Severe Cardiomegaly (CTR > 60%){rotation_warning}{tilt_warning}{precision_note}"
+        interpretation = f"Severe Cardiomegaly (CTR > 60%){rotation_warning}{tilt_warning}"
     else:
-        interpretation = f"Cardiomegaly{rotation_warning}{tilt_warning}{precision_note}"
+        interpretation = f"Cardiomegaly{rotation_warning}{tilt_warning}"
 
     return outseg, "tmp/overlap_segmentation.png", ctr_value, interpretation
 
 
 if __name__ == "__main__":
-    # Clear any existing CUDA cache for HuggingFace
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-    with gr.Blocks(title="Chest X-ray Segmentation") as demo:
+    with gr.Blocks() as demo:
         gr.Markdown("""
-        # Chest X-ray HybridGNet Segmentation
+        # Chest X-ray HybridGNet Segmentation.
+
+        Demo of the HybridGNet model introduced in "Improving anatomical plausibility in medical image segmentation via hybrid graph neural networks: applications to chest x-ray analysis."
 
-        Enhanced CTR calculation with improved precision using interpolated landmarks.
+        Instructions:
+        1. Upload a chest X-ray image (PA or AP) in PNG or JPEG format.
+        2. Click on "Segment Image".
 
-        **Instructions:**
-        1. Upload a chest X-ray image (PA or AP) in PNG or JPEG format
-        2. Click "Segment Image"
-        3. Wait for processing (may take 30-60 seconds)
+        Note: Pre-processing is not needed, it will be done automatically and removed after the segmentation.
 
-        **Features:**
-        - AI-powered lung and heart segmentation
-        - Enhanced CTR calculation with 350+ interpolated points
-        - Automatic tilt correction
-        - Visual measurement lines with precision markers
+        Please check citations below.
         """)
 
         with gr.Tab("Segment Image"):
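The interpretation branches above map the CTR value to a severity label (this commit only drops the `precision_note` suffix). Note that values falling in the gaps between thresholds, e.g. 0.505 or 0.553, reach the final `else` branch. The same mapping, pulled out as a standalone sketch:

```python
def interpret_ctr(ctr_value: float) -> str:
    """Mirror of the thresholds used in segment(); gap values fall through to the last branch."""
    if ctr_value < 0.5:
        return "Normal"
    elif 0.51 <= ctr_value <= 0.55:
        return "Mild Cardiomegaly (CTR 51-55%)"
    elif 0.56 <= ctr_value <= 0.60:
        return "Moderate Cardiomegaly (CTR 56-60%)"
    elif ctr_value > 0.60:
        return "Severe Cardiomegaly (CTR > 60%)"
    else:
        return "Cardiomegaly"
```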
@@ -721,16 +522,6 @@ if __name__ == "__main__":
         clear_button.click(lambda: None, None, ctr_output, queue=False)
         clear_button.click(lambda: None, None, ctr_interpretation, queue=False)
 
-        image_button.click(
-            fn=segment,
-            inputs=image_input,
-            outputs=[image_output, results, ctr_output, ctr_interpretation],
-            show_progress=True
-        )
-
-    demo.launch(
-        share=False,
-        server_name="0.0.0.0",
-        server_port=7860,
-        enable_queue=True
-    )
+        image_button.click(segment, inputs=image_input, outputs=[image_output, results, ctr_output, ctr_interpretation], queue=False)
+
+    demo.launch()
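The new wiring collapses the keyword-heavy `click` call and the configured `launch` into one-liners; `enable_queue` is no longer a `launch()` argument in current Gradio releases, which is presumably why the old call was dropped. A minimal self-contained sketch of the same Blocks pattern with a stub handler (the component types are guesses; only the variable names come from the diff):

```python
import gradio as gr

def segment_stub(path):
    # Stand-in for segment(): returns (overlay image, file path, CTR value, interpretation).
    return None, None, 0.5, "stub"

with gr.Blocks() as demo:
    image_input = gr.Image(type="filepath", label="Chest X-ray")
    image_button = gr.Button("Segment Image")
    image_output = gr.Image(label="Segmentation")
    results = gr.File(label="Download")
    ctr_output = gr.Number(label="CTR")
    ctr_interpretation = gr.Textbox(label="Interpretation")

    image_button.click(segment_stub, inputs=image_input,
                       outputs=[image_output, results, ctr_output, ctr_interpretation],
                       queue=False)

demo.launch()
```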