Chayanat commited on
Commit
dba1a4d
·
verified ·
1 Parent(s): 781c0d1
Files changed (1) hide show
  1. app.py +78 -86
app.py CHANGED
@@ -210,8 +210,6 @@ def drawOnTop(img, landmarks, original_shape):
210
  # Store corrected landmarks for CTR calculation
211
  return image, (RL_corrected, LL_corrected, H_corrected, tilt_angle)
212
 
213
- return image
214
-
215
 
216
  def loadModel(device):
217
  A, AD, D, U = genMatrixesLungsHeart()
@@ -312,108 +310,102 @@ def calculate_ctr(landmarks, corrected_landmarks=None):
312
 
313
  def detect_image_rotation(img):
314
  """Detect rotation angle of chest X-ray using basic image analysis"""
315
- # Apply edge detection
316
- edges = cv2.Canny((img * 255).astype(np.uint8), 50, 150)
317
-
318
- # Find lines using Hough transform
319
- lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)
320
-
321
- if lines is not None:
322
- angles = []
323
- for rho, theta in lines[:min(10, len(lines))]: # Consider top 10 lines
324
- angle = np.degrees(theta) - 90 # Convert to rotation angle
325
- # Filter for nearly horizontal or vertical lines
326
- if abs(angle) < 30 or abs(angle) > 60:
327
- angles.append(angle)
328
 
329
- if angles:
330
- # Take median angle to avoid outliers
331
- rotation_angle = np.median(angles)
332
- if abs(rotation_angle) > 2: # Only if significant rotation
333
- return rotation_angle
334
-
335
- return 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  def rotate_image(img, angle):
338
  """Rotate image by given angle"""
339
- if abs(angle) < 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  return img, 0
341
-
342
- h, w = img.shape[:2]
343
- center = (w // 2, h // 2)
344
-
345
- # Get rotation matrix
346
- rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
347
-
348
- # Calculate new dimensions
349
- cos_angle = abs(rotation_matrix[0, 0])
350
- sin_angle = abs(rotation_matrix[0, 1])
351
- new_w = int((h * sin_angle) + (w * cos_angle))
352
- new_h = int((h * cos_angle) + (w * sin_angle))
353
-
354
- # Adjust translation
355
- rotation_matrix[0, 2] += (new_w / 2) - center[0]
356
- rotation_matrix[1, 2] += (new_h / 2) - center[1]
357
-
358
- # Rotate image
359
- rotated = cv2.warpAffine(img, rotation_matrix, (new_w, new_h),
360
- borderMode=cv2.BORDER_CONSTANT, borderValue=0)
361
-
362
- return rotated, angle
363
 
364
  def segment(input_img):
365
  global hybrid, device
366
 
367
- if hybrid is None:
368
- hybrid = loadModel(device)
 
369
 
370
- original_img = cv2.imread(input_img, 0) / 255.0
371
- original_shape = original_img.shape[:2]
372
-
373
- # Step 1: Detect and correct rotation BEFORE AI processing
374
- detected_rotation = detect_image_rotation(original_img)
375
-
376
- if abs(detected_rotation) > 2:
377
- # Rotate image to make it upright for AI processing
378
- corrected_img, rotation_applied = rotate_image(original_img, -detected_rotation)
379
- processing_img = corrected_img
380
- was_rotated = True
381
- else:
382
- processing_img = original_img
383
- rotation_applied = 0
384
  was_rotated = False
 
385
 
386
- # Step 2: Preprocess the (potentially corrected) image
387
- img, (h, w, padding) = preprocess(processing_img)
388
 
389
- # Step 3: AI segmentation on corrected image
390
- data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
391
 
392
- with torch.no_grad():
393
- output = hybrid(data)[0].cpu().numpy().reshape(-1, 2)
394
 
395
- # Step 4: Remove preprocessing
396
- output = removePreprocess(output, (h, w, padding))
397
-
398
- # Step 5: If we rotated the image, rotate landmarks back to original orientation
399
- if was_rotated:
400
- corrected_h, corrected_w = processing_img.shape[:2]
401
- corrected_center = np.array([corrected_w/2, corrected_h/2])
402
- output_rotated = rotate_points(output.astype(float), detected_rotation, corrected_center)
403
 
404
- # Adjust coordinates back to original image size and position
405
- # This is a simplified approach - you might need more sophisticated coordinate transformation
406
- scale_x = original_shape[1] / corrected_w
407
- scale_y = original_shape[0] / corrected_h
408
- output_rotated[:, 0] *= scale_x
409
- output_rotated[:, 1] *= scale_y
410
-
411
- output = output_rotated.astype('int')
412
- else:
413
  output = output.astype('int')
414
 
415
- # Step 6: Draw results on original image
416
- outseg, corrected_data = drawOnTop(original_img, output, original_shape)
 
 
 
 
 
417
 
418
  seg_to_save = (outseg.copy() * 255).astype('uint8')
419
  cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))
 
210
  # Store corrected landmarks for CTR calculation
211
  return image, (RL_corrected, LL_corrected, H_corrected, tilt_angle)
212
 
 
 
213
 
214
  def loadModel(device):
215
  A, AD, D, U = genMatrixesLungsHeart()
 
310
 
311
  def detect_image_rotation(img):
312
  """Detect rotation angle of chest X-ray using basic image analysis"""
313
+ try:
314
+ # Apply edge detection
315
+ edges = cv2.Canny((img * 255).astype(np.uint8), 50, 150)
 
 
 
 
 
 
 
 
 
 
316
 
317
+ # Find lines using Hough transform
318
+ lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)
319
+
320
+ if lines is not None and len(lines) > 0:
321
+ angles = []
322
+ for line in lines[:min(10, len(lines))]: # Consider top 10 lines
323
+ rho, theta = line[0]
324
+ angle = np.degrees(theta) - 90 # Convert to rotation angle
325
+ # Filter for nearly horizontal or vertical lines
326
+ if abs(angle) < 30 or abs(angle) > 60:
327
+ angles.append(angle)
328
+
329
+ if angles:
330
+ # Take median angle to avoid outliers
331
+ rotation_angle = np.median(angles)
332
+ if abs(rotation_angle) > 2: # Only if significant rotation
333
+ return rotation_angle
334
+
335
+ return 0
336
+ except Exception as e:
337
+ print(f"Error in rotation detection: {e}")
338
+ return 0
339
 
340
  def rotate_image(img, angle):
341
  """Rotate image by given angle"""
342
+ try:
343
+ if abs(angle) < 1:
344
+ return img, 0
345
+
346
+ h, w = img.shape[:2]
347
+ center = (w // 2, h // 2)
348
+
349
+ # Get rotation matrix
350
+ rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
351
+
352
+ # Calculate new dimensions
353
+ cos_angle = abs(rotation_matrix[0, 0])
354
+ sin_angle = abs(rotation_matrix[0, 1])
355
+ new_w = int((h * sin_angle) + (w * cos_angle))
356
+ new_h = int((h * cos_angle) + (w * sin_angle))
357
+
358
+ # Adjust translation
359
+ rotation_matrix[0, 2] += (new_w / 2) - center[0]
360
+ rotation_matrix[1, 2] += (new_h / 2) - center[1]
361
+
362
+ # Rotate image
363
+ rotated = cv2.warpAffine(img, rotation_matrix, (new_w, new_h),
364
+ borderMode=cv2.BORDER_CONSTANT, borderValue=0)
365
+
366
+ return rotated, angle
367
+ except Exception as e:
368
+ print(f"Error in image rotation: {e}")
369
  return img, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  def segment(input_img):
372
  global hybrid, device
373
 
374
+ try:
375
+ if hybrid is None:
376
+ hybrid = loadModel(device)
377
 
378
+ original_img = cv2.imread(input_img, 0) / 255.0
379
+ original_shape = original_img.shape[:2]
380
+
381
+ # Step 1: For now, skip rotation detection to avoid errors
382
+ # TODO: Re-implement rotation detection after fixing coordinate transformation
383
+ detected_rotation = 0 # Temporarily disabled
 
 
 
 
 
 
 
 
384
  was_rotated = False
385
+ processing_img = original_img
386
 
387
+ # Step 2: Preprocess the image
388
+ img, (h, w, padding) = preprocess(processing_img)
389
 
390
+ # Step 3: AI segmentation
391
+ data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
392
 
393
+ with torch.no_grad():
394
+ output = hybrid(data)[0].cpu().numpy().reshape(-1, 2)
395
 
396
+ # Step 4: Remove preprocessing
397
+ output = removePreprocess(output, (h, w, padding))
 
 
 
 
 
 
398
 
399
+ # Step 5: Convert output to int
 
 
 
 
 
 
 
 
400
  output = output.astype('int')
401
 
402
+ # Step 6: Draw results on original image
403
+ outseg, corrected_data = drawOnTop(original_img, output, original_shape)
404
+
405
+ except Exception as e:
406
+ print(f"Error in segmentation: {e}")
407
+ # Return a basic error response
408
+ return None, None, 0, f"Error: {str(e)}"
409
 
410
  seg_to_save = (outseg.copy() * 255).astype('uint8')
411
  cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))