# Page header captured with this file (not code): fahimehorvatinia's picture
# Commit message: update
# Commit: 2592b1e (verified)
import os
# Redirect the Hugging Face, transformers and torch cache directories to /tmp.
# This runs before gradio and the pipeline wrapper are imported below, since
# those libraries may read these variables when they are first imported.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; it is kept here for older versions — confirm still needed.
os.environ.update(
    {
        'HF_HOME': '/tmp/huggingface',
        'TRANSFORMERS_CACHE': '/tmp/huggingface/transformers',
        'TORCH_HOME': '/tmp/torch',
    }
)
import gradio as gr
import tempfile
from pathlib import Path
from wrapper import run_pipeline_on_image
from PIL import Image
# Absolute directory containing this script; bundled assets live beside it.
BASE_DIR = Path(__file__).resolve().parent

# Preset label -> absolute path (as str) of the bundled sample image.
# Insertion order is the order the presets appear in the selector.
PRESET_IMAGES = {
    label: str(BASE_DIR.joinpath(filename))
    for label, filename in (
        ("Sorghum", "Sorghum.tif"),
        ("Corn", "Corn.tif"),
        ("Cotton", "Cotton.tif"),
    )
}

# Absolute path (as str) of the header logo image.
LOGO_PATH = str(BASE_DIR.joinpath("logo.png"))
def process(file_path, preset_choice):
"""Process image and yield results progressively for immediate display."""
# Determine dataset type (single-plant mode for Corn, multi-plant for others)
single_plant_mode = False
if preset_choice:
chosen = PRESET_IMAGES.get(preset_choice)
if chosen:
file_path = chosen
# Corn uses single-plant mode
single_plant_mode = (preset_choice == "Corn")
if not file_path:
# Return 10 outputs (removed YOLO tips)
return None, None, None, None, None, None, None, [], None, ""
with tempfile.TemporaryDirectory() as tmpdir:
src = Path(file_path)
ext = src.suffix.lstrip('.') or 'tif'
img_path = Path(tmpdir) / f"input.{ext}"
try:
# Copy raw uploaded bytes
img_bytes = src.read_bytes()
img_path.write_bytes(img_bytes)
except Exception:
# Fallback: save via PIL if direct copy fails
Image.open(src).save(img_path)
# Don't show immediate preview - wait for pipeline's InputImage for correctness
input_preview = None
# Helper to load PIL images
def load_pil(path_str):
try:
if not path_str:
return None
im = Image.open(path_str)
copied = im.copy()
im.close()
return copied
except Exception:
return None
# Run the pipeline progressively (generator)
for outputs in run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True, single_plant_mode=single_plant_mode):
# Load all available outputs progressively
composite = load_pil(outputs.get('Composite'))
overlay = load_pil(outputs.get('Overlay'))
mask = load_pil(outputs.get('Mask'))
input_img = load_pil(outputs.get('InputImage')) or input_preview
size_img = load_pil(str(Path(tmpdir) / 'results/size.size_analysis.png'))
# Texture images (green band)
lbp_path = Path(tmpdir) / 'texture_output/lbp_green.png'
hog_path = Path(tmpdir) / 'texture_output/hog_green.png'
lac1_path = Path(tmpdir) / 'texture_output/lac1_green.png'
texture_img = load_pil(str(lbp_path)) if lbp_path.exists() else None
hog_img = load_pil(str(hog_path)) if hog_path.exists() else None
lac1_img = load_pil(str(lac1_path)) if lac1_path.exists() else None
# Vegetation indices
order = ['NDVI', 'GNDVI', 'SAVI']
gallery_items = [load_pil(outputs[k]) for k in order if k in outputs]
stats_text = outputs.get('StatsText', '')
# Yield intermediate/final results as they become available
yield (
input_img,
composite,
mask,
overlay,
texture_img,
hog_img,
lac1_img,
gallery_items,
size_img,
stats_text,
)
# ---------------------------------------------------------------------------
# Gradio UI. Components appear in the order they are created inside each
# ``with`` context, so the statement order below defines the page layout.
# ---------------------------------------------------------------------------
with gr.Blocks(fill_width=True) as demo:
    # Header logo (no share/download/fullscreen buttons), full width.
    # NOTE(review): show_download_button / show_share_button /
    # show_fullscreen_button are gr.Image keyword arguments in the Gradio
    # versions this targets; newer releases may have replaced them — confirm
    # against the pinned Gradio version.
    gr.Image(
        value=LOGO_PATH,
        show_label=False,
        interactive=False,
        height=80,
        container=False,
        show_download_button=False,
        show_share_button=False,
        show_fullscreen_button=False,
    )
    gr.Markdown("# 🌿 Automated Plant Analysis Demo")
    gr.Markdown("Upload a plant image (TIFF preferred) to compute and visualize composite, mask, overlay, texture, vegetation indices, and statistics.")
    # Input controls: file upload, preset picker, and the run button.
    with gr.Row():
        with gr.Column():
            # Use File input (not gr.Image) to preserve raw TIFFs; with
            # type="filepath" the handler receives a path to the upload.
            inp = gr.File(
                type="filepath",
                file_types=[".tif", ".tiff", ".png", ".jpg"],
                label="Upload Image"
            )
            # NOTE(review): process() uses a selected preset in place of any
            # uploaded file. Confirm users can clear the preset; otherwise
            # uploads made after picking a preset are ignored.
            preset = gr.Radio(
                choices=list(PRESET_IMAGES.keys()),
                label="Or choose a preset image",
                value=None
            )
            run = gr.Button("Run Pipeline", variant="primary")
    # Row 1: input image
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image", interactive=False, height=380)
    # Row 2: composite, mask, overlay
    with gr.Row():
        composite_img = gr.Image(type="pil", label="Composite (Segmentation Input)", interactive=False)
        mask_img = gr.Image(type="pil", label="Mask", interactive=False)
        overlay_img = gr.Image(type="pil", label="Segmentation Overlay", interactive=False)
    # Row 3: textures
    with gr.Row():
        texture_img = gr.Image(type="pil", label="Texture LBP (Green Band)", interactive=False)
        hog_img = gr.Image(type="pil", label="Texture HOG (Green Band)", interactive=False)
        lac1_img = gr.Image(type="pil", label="Texture Lac1 (Green Band)", interactive=False)
    # Row 4: vegetation indices (NDVI / GNDVI / SAVI images from process())
    gallery = gr.Gallery(label="Vegetation Indices", columns=3, height="auto")
    # Row 5: plant morphology / size analysis image
    with gr.Row():
        size_img = gr.Image(type="pil", label="Plant Morphology", interactive=False)
    # Final: statistics text
    stats = gr.Textbox(label="Statistics", lines=4)
    # Wire the button to the generator handler. Each tuple that process()
    # yields updates these components, so this ``outputs`` order must match
    # that tuple's order exactly.
    run.click(
        process,
        inputs=[inp, preset],
        outputs=[
            input_img,
            composite_img,
            mask_img,
            overlay_img,
            texture_img,
            hog_img,
            lac1_img,
            gallery,
            size_img,
            stats,
        ]
    )

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()