Chayanat committed on
Commit 781c0d1 · verified · 1 Parent(s): 1ece780
Files changed (1)
  1. app.py +102 -14
app.py CHANGED
@@ -310,50 +310,138 @@ def calculate_ctr(landmarks, corrected_landmarks=None):
     return round(ctr, 3), abs(tilt_angle)


 def segment(input_img):
     global hybrid, device

     if hybrid is None:
         hybrid = loadModel(device)

-    input_img = cv2.imread(input_img, 0) / 255.0
-    original_shape = input_img.shape[:2]

-    img, (h, w, padding) = preprocess(input_img)

     data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()

     with torch.no_grad():
         output = hybrid(data)[0].cpu().numpy().reshape(-1, 2)

     output = removePreprocess(output, (h, w, padding))

-    output = output.astype('int')
-
-    outseg, corrected_data = drawOnTop(input_img, output, original_shape)

     seg_to_save = (outseg.copy() * 255).astype('uint8')
     cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))

     ctr_value, tilt_angle = calculate_ctr(output, corrected_data)

-    # Add tilt warning to interpretation
     tilt_warning = ""
     if tilt_angle > 5:
-        tilt_warning = f" (⚠️ Image tilted {tilt_angle:.1f}° - measurement corrected)"
     elif tilt_angle > 2:
-        tilt_warning = f" (Image tilted {tilt_angle:.1f}° - corrected)"

     if ctr_value < 0.5:
-        interpretation = f"Normal{tilt_warning}"
     elif 0.51 <= ctr_value <= 0.55:
-        interpretation = f"Mild Cardiomegaly (CTR 51-55%){tilt_warning}"
     elif 0.56 <= ctr_value <= 0.60:
-        interpretation = f"Moderate Cardiomegaly (CTR 56-60%){tilt_warning}"
     elif ctr_value > 0.60:
-        interpretation = f"Severe Cardiomegaly (CTR > 60%){tilt_warning}"
     else:
-        interpretation = f"Cardiomegaly{tilt_warning}"

     return outseg, "tmp/overlap_segmentation.png", ctr_value, interpretation

 
     return round(ctr, 3), abs(tilt_angle)


+def detect_image_rotation(img):
+    """Detect rotation angle of chest X-ray using basic image analysis"""
+    # Apply edge detection
+    edges = cv2.Canny((img * 255).astype(np.uint8), 50, 150)
+
+    # Find lines using Hough transform
+    lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)
+
+    if lines is not None:
+        angles = []
+        for rho, theta in lines[:min(10, len(lines)), 0]:  # Consider top 10 lines (HoughLines returns shape (N, 1, 2))
+            angle = np.degrees(theta) - 90  # Convert to rotation angle
+            # Filter for nearly horizontal or vertical lines
+            if abs(angle) < 30 or abs(angle) > 60:
+                angles.append(angle)
+
+        if angles:
+            # Take median angle to avoid outliers
+            rotation_angle = np.median(angles)
+            if abs(rotation_angle) > 2:  # Only if significant rotation
+                return rotation_angle
+
+    return 0
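For intuition on the np.degrees(theta) - 90 conversion above (an illustration, not part of the commit): cv2.HoughLines reports theta as the orientation of a line's normal, so near-horizontal structures map to angles near 0° and near-vertical ones to angles near ±90°, which is what the two-sided filter keeps:

import numpy as np

theta_horizontal = np.pi / 2   # normal of a horizontal line (x*cosθ + y*sinθ = ρ)
theta_vertical = 0.0           # normal of a vertical line
print(np.degrees(theta_horizontal) - 90)  # 0.0   -> passes abs(angle) < 30
print(np.degrees(theta_vertical) - 90)    # -90.0 -> passes abs(angle) > 60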
+
+def rotate_image(img, angle):
+    """Rotate image by given angle"""
+    if abs(angle) < 1:
+        return img, 0
+
+    h, w = img.shape[:2]
+    center = (w // 2, h // 2)
+
+    # Get rotation matrix
+    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
+
+    # Calculate new dimensions
+    cos_angle = abs(rotation_matrix[0, 0])
+    sin_angle = abs(rotation_matrix[0, 1])
+    new_w = int((h * sin_angle) + (w * cos_angle))
+    new_h = int((h * cos_angle) + (w * sin_angle))
+
+    # Adjust translation
+    rotation_matrix[0, 2] += (new_w / 2) - center[0]
+    rotation_matrix[1, 2] += (new_h / 2) - center[1]
+
+    # Rotate image
+    rotated = cv2.warpAffine(img, rotation_matrix, (new_w, new_h),
+                             borderMode=cv2.BORDER_CONSTANT, borderValue=0)
+
+    return rotated, angle
+
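As a quick sanity check of the two helpers above (an illustrative sketch, not part of commit 781c0d1; it assumes numpy and OpenCV are imported as in app.py and that detect_image_rotation and rotate_image are in scope), a synthetically tilted test image should come back with roughly the applied tilt:

import numpy as np

# Synthetic frame: a bright horizontal band on a dark background, normalised to
# [0, 1] the same way segment() feeds images into detect_image_rotation().
test_img = np.zeros((256, 256), dtype=np.float64)
test_img[120:136, 20:236] = 1.0

rotated, applied = rotate_image(test_img, 8.0)   # tilt by a known angle
estimated = detect_image_rotation(rotated)

print(estimated)  # expected magnitude ≈ 8° (sign follows the helpers' shared convention)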
 def segment(input_img):
     global hybrid, device

     if hybrid is None:
         hybrid = loadModel(device)

+    original_img = cv2.imread(input_img, 0) / 255.0
+    original_shape = original_img.shape[:2]
+
+    # Step 1: Detect and correct rotation BEFORE AI processing
+    detected_rotation = detect_image_rotation(original_img)
+
+    if abs(detected_rotation) > 2:
+        # Rotate image to make it upright for AI processing
+        corrected_img, rotation_applied = rotate_image(original_img, -detected_rotation)
+        processing_img = corrected_img
+        was_rotated = True
+    else:
+        processing_img = original_img
+        rotation_applied = 0
+        was_rotated = False

+    # Step 2: Preprocess the (potentially corrected) image
+    img, (h, w, padding) = preprocess(processing_img)

+    # Step 3: AI segmentation on corrected image
     data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()

     with torch.no_grad():
         output = hybrid(data)[0].cpu().numpy().reshape(-1, 2)

+    # Step 4: Remove preprocessing
     output = removePreprocess(output, (h, w, padding))
+
+    # Step 5: If we rotated the image, rotate landmarks back to original orientation
+    if was_rotated:
+        corrected_h, corrected_w = processing_img.shape[:2]
+        corrected_center = np.array([corrected_w/2, corrected_h/2])
+        output_rotated = rotate_points(output.astype(float), detected_rotation, corrected_center)
+
+        # Adjust coordinates back to original image size and position
+        # This is a simplified approach - you might need more sophisticated coordinate transformation
+        scale_x = original_shape[1] / corrected_w
+        scale_y = original_shape[0] / corrected_h
+        output_rotated[:, 0] *= scale_x
+        output_rotated[:, 1] *= scale_y
+
+        output = output_rotated.astype('int')
+    else:
+        output = output.astype('int')
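rotate_points(), called in Step 5, is defined elsewhere in app.py and is not part of this hunk. Inferring its contract from the call site (rotate an (N, 2) array of (x, y) landmarks by an angle in degrees about a given centre), a minimal compatible sketch could look like this; the actual helper in app.py may differ:

def rotate_points(points, angle, center):
    # Hypothetical stand-in, not the committed implementation.
    theta = np.radians(angle)
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta),  np.cos(theta)]])
    return (points - center) @ rotation.T + center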
 
+    # Step 6: Draw results on original image
+    outseg, corrected_data = drawOnTop(original_img, output, original_shape)

     seg_to_save = (outseg.copy() * 255).astype('uint8')
     cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))

     ctr_value, tilt_angle = calculate_ctr(output, corrected_data)

+    # Add rotation info to interpretation
+    rotation_warning = ""
+    if was_rotated:
+        rotation_warning = f" (🔄 Image was rotated {detected_rotation:.1f}° for AI processing)"
+
+    # Add remaining tilt warning (after AI processing correction)
     tilt_warning = ""
     if tilt_angle > 5:
+        tilt_warning = f" (⚠️ Remaining tilt: {tilt_angle:.1f}°)"
     elif tilt_angle > 2:
+        tilt_warning = f" (Minor tilt: {tilt_angle:.1f}°)"

     if ctr_value < 0.5:
+        interpretation = f"Normal{rotation_warning}{tilt_warning}"
     elif 0.51 <= ctr_value <= 0.55:
+        interpretation = f"Mild Cardiomegaly (CTR 51-55%){rotation_warning}{tilt_warning}"
     elif 0.56 <= ctr_value <= 0.60:
+        interpretation = f"Moderate Cardiomegaly (CTR 56-60%){rotation_warning}{tilt_warning}"
     elif ctr_value > 0.60:
+        interpretation = f"Severe Cardiomegaly (CTR > 60%){rotation_warning}{tilt_warning}"
     else:
+        interpretation = f"Cardiomegaly{rotation_warning}{tilt_warning}"

     return outseg, "tmp/overlap_segmentation.png", ctr_value, interpretation
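Outside the Gradio UI, the updated segment() can be exercised end-to-end roughly as follows (a sketch, not part of the commit; "sample_cxr.png" is a placeholder path, and it assumes the model weights load via loadModel() and that a tmp/ directory exists for the overlay PNG):

overlay, overlay_path, ctr, interpretation = segment("sample_cxr.png")
print(f"CTR = {ctr:.3f} -> {interpretation}")
print(f"Overlay saved to {overlay_path}")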