| import os | |
| import time | |
| import h5py | |
| import numpy as np | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from railnet_model import RailNetSystem | |
# Process-wide environment setup. It has to run before the model below is
# moved to the GPU, so keep these statements in this order.
# NOTE(review): KMP_DUPLICATE_LIB_OK='True' is the usual workaround for the
# Intel OpenMP "duplicate runtime" abort. It is assigned *after* the imports
# above, so confirm it still takes effect if those imports already
# initialised an OpenMP runtime.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Expose only GPU 0 to this process. This overwrites any value the caller
# exported. It presumably needs to precede the first CUDA initialisation
# (the .cuda() call below) — TODO confirm railnet_model does not touch CUDA on import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Single module-level model used by process_cbct_file for every request.
# .cuda() makes a CUDA-capable device mandatory for running this app.
model = RailNetSystem.from_pretrained(".").cuda()
# NOTE(review): from_pretrained(".") may already restore weights; confirm
# whether this second load from the current directory is still required.
model.load_weights(".")
def wait_for_stable_file(file_path, timeout=5, check_interval=0.2):
    """Poll *file_path* until its size stops changing.

    The file counts as stable once two consecutive size readings, taken
    ``check_interval`` seconds apart, are equal. A file that does not exist
    (yet) or cannot be stat'ed counts as "not stable", and polling continues
    instead of raising.

    Args:
        file_path: Path of the file to watch.
        timeout: Maximum number of seconds to keep polling.
        check_interval: Seconds to sleep between two size readings.

    Returns:
        True if the size was unchanged across one interval before the
        timeout elapsed, False otherwise.
    """
    # monotonic() is immune to wall-clock jumps (NTP sync, manual changes),
    # which could otherwise shorten or stretch the timeout.
    deadline = time.monotonic() + timeout
    last_size = None  # sentinel: no successful reading yet
    while time.monotonic() < deadline:
        try:
            current_size = os.path.getsize(file_path)
        except OSError:
            # The upload may not have materialised yet; reset and keep waiting
            # rather than crashing the caller.
            current_size = None
        else:
            if current_size == last_size:
                return True
        last_size = current_size
        time.sleep(check_interval)
    return False
def process_cbct_file(h5_file, save_dir="./output"):
    """Read an uploaded CBCT .h5 file, run the model on it, and format the metrics.

    Args:
        h5_file: The upload from ``gr.File``. Depending on the Gradio version
            and ``gr.File(type=...)`` this is either an object exposing the
            temp-file path as ``.name`` or the path string itself; both are
            accepted.
        save_dir: Directory passed through to the model call.

    Returns:
        ``(pred, metrics_text)``. ``pred`` is the first value returned by the
        model. ``metrics_text`` is a one-line summary of Dice, Jaccard, 95HD
        and ASD.

    Raises:
        RuntimeError: if no file was given, the upload never settled, or the
            file could not be read or lacks the ``image``/``label`` datasets.
    """
    if h5_file is None:
        raise RuntimeError("No file was uploaded, please upload a .h5 file.")
    # Accept both a temp-file wrapper (has .name) and a plain path string.
    file_path = getattr(h5_file, "name", h5_file)

    if not wait_for_stable_file(file_path):
        raise RuntimeError("File upload has not been completed or is unstable, please try again.")
    try:
        with h5py.File(file_path, "r") as f:
            if "image" not in f or "label" not in f:
                raise KeyError("The file is missing ‘image’ or ‘label’ value")
            image = f["image"][:]
            label = f["label"][:]
    except Exception as e:
        # Chain the original error so the traceback keeps the root cause.
        raise RuntimeError(f"Failed to read the .h5 file: {str(e)}") from e

    # Strip only a trailing ".h5"; str.replace would also delete ".h5"
    # appearing elsewhere in the file name.
    name = os.path.basename(file_path)
    if name.endswith(".h5"):
        name = name[: -len(".h5")]
    pred, dice, jc, hd, asd = model(image, label, save_dir, name)
    return pred, f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}"
def render_plotly_volume(pred, x_eye=1.25, y_eye=1.25, z_eye=1.25, downsample_factor=2):
    """Build an interactive Plotly isosurface figure of a 3-D prediction volume.

    Args:
        pred: 3-D array-like volume indexed as ``pred[i, j, k]``. It is drawn
            as one semi-transparent red isosurface with isomin=0.5 and
            isomax=1.0. Presumably a binary segmentation mask — confirm
            against the model output.
        x_eye, y_eye, z_eye: Camera eye position of the 3-D scene.
        downsample_factor: Keep every N-th voxel along each axis before
            plotting, to keep the figure light in the browser. Must be >= 1;
            1 disables downsampling. Defaults to 2, the previously fixed value.

    Returns:
        A ``plotly.graph_objects.Figure`` with hidden axes and zero margins.

    Raises:
        ValueError: if ``downsample_factor`` is < 1 or ``pred`` is not 3-D.
    """
    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")
    if np.ndim(pred) != 3:
        raise ValueError(f"pred must be a 3-D volume, got {np.ndim(pred)} dimension(s)")

    pred_ds = pred[::downsample_factor, ::downsample_factor, ::downsample_factor]
    # np.indices gives the (i, j, k) index of every voxel. Flattening each grid
    # in C order lines the coordinates up with pred_ds.flatten().
    xs, ys, zs = np.indices(pred_ds.shape)
    fig = go.Figure(data=go.Volume(
        x=xs.ravel(),
        y=ys.ravel(),
        z=zs.ravel(),
        value=pred_ds.flatten(),
        isomin=0.5,
        isomax=1.0,
        opacity=0.1,
        surface_count=1,
        # Both ends of the colour scale are the same red: a single flat colour.
        colorscale=[[0, 'rgb(255, 0, 0)'], [1, 'rgb(255, 0, 0)']],
        showscale=False,
    ))
    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=x_eye, y=y_eye, z=z_eye)),
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig
def clear_all():
    """Produce the reset values for (file_input, result_text, plot_output).

    The upload widget and the plot are emptied with None, and the metrics
    textbox is blanked with an empty string.
    """
    empty_upload = None
    empty_metrics = ""
    empty_plot = None
    return (empty_upload, empty_metrics, empty_plot)
with gr.Blocks() as demo:
    gr.Markdown("<div style='text-align: center; font-size: 28px; font-weight: bold;'>🦷 Demo of RailNet: A CBCT Tooth Segmentation System</div>")
    gr.Markdown("<div style='text-align: center; font-size: 20px'>✅ Steps: Upload a CBCT example file (.h5) → Automatic inference and metrics display → View 3D segmentation result (Mouse drag and scroll wheel zooming)</div>")
    gr.Markdown("<div style='height: 20px;'></div>")

    # Step 1: upload + action buttons.
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📂 Step 1: Upload the .h5 example file containing both ‘image’ and ‘label’ values</div>")
    file_input = gr.File()
    with gr.Row():
        # Button labels: "清除" = "Clear", "提交" = "Submit".
        clear_btn = gr.Button("清除", variant="secondary")
        submit_btn = gr.Button("提交", variant="primary")
    gr.Markdown("<div style='height: 20px;'></div>")

    # Step 2: metrics text.
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📊 Step 2: Metrics (Dice, Jaccard, 95HD, ASD)</div>")
    result_text = gr.Textbox()
    # Per-session copy of the last prediction (read by update_view below).
    hidden_pred = gr.State(value=None)
    gr.Markdown("<div style='height: 20px;'></div>")

    # Step 3: 3-D view.
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>👁️ Step 3: 3D Visualisation</div>")
    plot_output = gr.Plot()

    def handle_upload(h5_file):
        """Run inference on the uploaded file; return (metrics text, prediction, figure)."""
        if h5_file is None:
            raise gr.Error("Please upload a .h5 file first.")
        try:
            pred, metrics = process_cbct_file(h5_file)
        except RuntimeError as e:
            # gr.Error messages are always shown in the UI, whereas other
            # exceptions only surface with launch(show_error=True). Re-raise
            # so the user sees e.g. "please try again".
            raise gr.Error(str(e)) from e
        fig = render_plotly_volume(pred)
        return metrics, pred, fig

    submit_btn.click(
        fn=handle_upload,
        inputs=[file_input],
        outputs=[result_text, hidden_pred, plot_output],
    )

    def update_view(pred, x_eye, y_eye, z_eye):
        """Re-render the stored prediction from a new camera position.

        Returns gr.update() (leave the plot untouched) when nothing is stored.
        """
        # NOTE(review): no event in this layout is wired to update_view (there
        # are no camera controls), so it is currently unused.
        if pred is None:
            return gr.update()
        return render_plotly_volume(pred, x_eye, y_eye, z_eye)

    def handle_clear():
        """Reset upload, metrics and plot, and drop the stored prediction."""
        file_value, text_value, plot_value = clear_all()
        # Also reset hidden_pred so the previous volume does not linger in
        # session state after the user clears the page.
        return file_value, text_value, plot_value, None

    clear_btn.click(
        fn=handle_clear,
        inputs=[],
        outputs=[file_input, result_text, plot_output, hidden_pred],
    )

demo.launch()