Chayanat commited on
Commit
799b7ad
·
verified ·
1 Parent(s): 0be59db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -31
app.py CHANGED
@@ -6,7 +6,6 @@ from models.HybridGNet2IGSC import Hybrid
6
  from utils.utils import scipy_to_torch_sparse, genMatrixesLungsHeart
7
  import scipy.sparse as sp
8
  import torch
9
- from zipfile import ZipFile
10
 
11
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12
  hybrid = None
@@ -178,13 +177,6 @@ def removePreprocess(output, info):
178
  return output
179
 
180
 
181
- def zip_files(files):
182
- with ZipFile("complete_results.zip", "w") as zipObj:
183
- for idx, file in enumerate(files):
184
- zipObj.write(file, arcname=file.split("/")[-1])
185
- return "complete_results.zip"
186
-
187
-
188
  def calculate_ctr(landmarks):
189
  H = landmarks[94:]
190
  RL = landmarks[0:44]
@@ -192,7 +184,7 @@ def calculate_ctr(landmarks):
192
  cardiac_width = np.max(H[:, 0]) - np.min(H[:, 0])
193
  thoracic_width = max(np.max(RL[:, 0]), np.max(LL[:, 0])) - min(np.min(RL[:, 0]), np.min(LL[:, 0]))
194
  ctr = cardiac_width / thoracic_width if thoracic_width > 0 else 0
195
- return ctr
196
 
197
 
198
  def segment(input_img):
@@ -220,28 +212,10 @@ def segment(input_img):
220
  seg_to_save = (outseg.copy() * 255).astype('uint8')
221
  cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))
222
 
223
- RL = output[0:44]
224
- LL = output[44:94]
225
- H = output[94:]
226
-
227
- np.savetxt("tmp/RL_landmarks.txt", RL, delimiter=" ", fmt="%d")
228
- np.savetxt("tmp/LL_landmarks.txt", LL, delimiter=" ", fmt="%d")
229
- np.savetxt("tmp/H_landmarks.txt", H, delimiter=" ", fmt="%d")
230
-
231
- RL_mask, LL_mask, H_mask = getMasks(output, original_shape[0], original_shape[1])
232
-
233
- cv2.imwrite("tmp/RL_mask.png", RL_mask)
234
- cv2.imwrite("tmp/LL_mask.png", LL_mask)
235
- cv2.imwrite("tmp/H_mask.png", H_mask)
236
-
237
- zip = zip_files(
238
- ["tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt", "tmp/RL_mask.png", "tmp/LL_mask.png",
239
- "tmp/H_mask.png", "tmp/overlap_segmentation.png"])
240
-
241
  ctr_value = calculate_ctr(output)
 
242
 
243
- return outseg, ["tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt", "tmp/RL_mask.png",
244
- "tmp/LL_mask.png", "tmp/H_mask.png", "tmp/overlap_segmentation.png", zip], ctr_value
245
 
246
 
247
  if __name__ == "__main__":
@@ -275,8 +249,12 @@ if __name__ == "__main__":
275
 
276
  with gr.Column():
277
  image_output = gr.Image(type="filepath", height=750)
 
 
 
 
 
278
  results = gr.File()
279
- ctr_output = gr.Number(label="CTR (Cardiothoracic Ratio)")
280
 
281
  gr.Markdown("""
282
  If you use this code, please cite:
@@ -322,7 +300,8 @@ if __name__ == "__main__":
322
  clear_button.click(lambda: None, None, image_input, queue=False)
323
  clear_button.click(lambda: None, None, image_output, queue=False)
324
  clear_button.click(lambda: None, None, ctr_output, queue=False)
 
325
 
326
- image_button.click(segment, inputs=image_input, outputs=[image_output, results, ctr_output], queue=False)
327
 
328
  demo.launch()
 
6
  from utils.utils import scipy_to_torch_sparse, genMatrixesLungsHeart
7
  import scipy.sparse as sp
8
  import torch
 
9
 
10
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11
  hybrid = None
 
177
  return output
178
 
179
 
 
 
 
 
 
 
 
180
  def calculate_ctr(landmarks):
181
  H = landmarks[94:]
182
  RL = landmarks[0:44]
 
184
  cardiac_width = np.max(H[:, 0]) - np.min(H[:, 0])
185
  thoracic_width = max(np.max(RL[:, 0]), np.max(LL[:, 0])) - min(np.min(RL[:, 0]), np.min(LL[:, 0]))
186
  ctr = cardiac_width / thoracic_width if thoracic_width > 0 else 0
187
+ return round(ctr, 3)
188
 
189
 
190
  def segment(input_img):
 
212
  seg_to_save = (outseg.copy() * 255).astype('uint8')
213
  cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  ctr_value = calculate_ctr(output)
216
+ interpretation = "Normal" if ctr_value < 0.5 else "Cardiomegaly"
217
 
218
+ return outseg, "tmp/overlap_segmentation.png", ctr_value, interpretation
 
219
 
220
 
221
  if __name__ == "__main__":
 
249
 
250
  with gr.Column():
251
  image_output = gr.Image(type="filepath", height=750)
252
+
253
+ with gr.Row():
254
+ ctr_output = gr.Number(label="CTR (Cardiothoracic Ratio)")
255
+ ctr_interpretation = gr.Textbox(label="Interpretation", interactive=False)
256
+
257
  results = gr.File()
 
258
 
259
  gr.Markdown("""
260
  If you use this code, please cite:
 
300
  clear_button.click(lambda: None, None, image_input, queue=False)
301
  clear_button.click(lambda: None, None, image_output, queue=False)
302
  clear_button.click(lambda: None, None, ctr_output, queue=False)
303
+ clear_button.click(lambda: None, None, ctr_interpretation, queue=False)
304
 
305
+ image_button.click(segment, inputs=image_input, outputs=[image_output, results, ctr_output, ctr_interpretation], queue=False)
306
 
307
  demo.launch()