# danielhshi8224
# standalone cls
# e322980
# # app.py — Object Detection only (multi-image YOLO, up to 10)
# import os
# import csv
# import tempfile
# from pathlib import Path
# from typing import List, Tuple
# import gradio as gr
# from PIL import Image
# # Try import ultralytics (ensure it's in requirements.txt)
# try:
# from ultralytics import YOLO
# except Exception:
# YOLO = None
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# MAX_BATCH = 10
# # Option A: local file baked into Space (easiest if allowed)
# YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
# # Option B (optional): pull from a private HF model repo using a Space secret
# # Set these env vars in your Space if you want auto-download:
# # HF_TOKEN=<read token> YOLO_REPO_ID="yourname/yolo-detector"
# HF_TOKEN = os.environ.get("HF_TOKEN")
# YOLO_REPO_ID = os.environ.get("YOLO_REPO_ID")
# def _download_from_hub_if_needed() -> str | None:
# """If YOLO_REPO_ID is set, download weights with huggingface_hub; else return None."""
# if not YOLO_REPO_ID:
# return None
# try:
# from huggingface_hub import snapshot_download
# local_dir = snapshot_download(
# repo_id=YOLO_REPO_ID, repo_type="model", token=HF_TOKEN
# )
# # try common filenames
# for name in ("yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt"):
# cand = Path(local_dir) / name
# if cand.exists():
# return str(cand)
# except Exception as e:
# print("[YOLO] Hub download failed:", e)
# return None
# _yolo_model = None
# def _load_yolo():
# """Load YOLO weights either from local file or HF Hub."""
# global _yolo_model
# if _yolo_model is not None:
# return _yolo_model
# if YOLO is None:
# raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")
# model_path = None
# if os.path.exists(YOLO_WEIGHTS):
# model_path = YOLO_WEIGHTS
# else:
# hub_path = _download_from_hub_if_needed()
# if hub_path:
# model_path = hub_path
# if not model_path:
# raise FileNotFoundError(
# "YOLO weights not found. Either include 'yolo11_best.pt' in the repo root, "
# "or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
# )
# _yolo_model = YOLO(model_path)
# return _yolo_model
# def detect_objects_batch(files, conf=0.25, iou=0.25):
# """
# Run YOLO detection on multiple images (up to 10).
# Returns: gallery of annotated images, rows table, csv filepath
# """
# if YOLO is None:
# return [], [], None
# if not files:
# return [], [], None
# try:
# ymodel = _load_yolo()
# except Exception as e:
# print("YOLO load error:", e)
# return [], [], None
# gallery, table_rows = [], []
# for f in files[:MAX_BATCH]:
# path = getattr(f, "name", None) or getattr(f, "path", None) or f
# try:
# results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
# except Exception as e:
# print(f"Detection failed for {path}:", e)
# continue
# res = results[0]
# # annotated image
# ann_path = None
# try:
# ann_img = res.plot()
# ann_pil = Image.fromarray(ann_img)
# out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
# os.makedirs(out_dir, exist_ok=True)
# ann_filename = Path(path).stem + "_annotated.jpg"
# ann_path = os.path.join(out_dir, ann_filename)
# ann_pil.save(ann_path)
# except Exception:
# try:
# out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
# res.save(save_dir=out_dir)
# saved_files = getattr(res, "files", [])
# ann_path = saved_files[0] if saved_files else None
# except Exception:
# ann_path = None
# # extract detections
# boxes = getattr(res, "boxes", None)
# if boxes is None or len(boxes) == 0:
# table_rows.append([os.path.basename(path), 0, "", "", ""])
# img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
# else Image.open(path).convert("RGB")
# gallery.append((img_for_gallery, f"{os.path.basename(path)}\nNo detections"))
# continue
# det_labels, det_scores, det_boxes = [], [], []
# for box in boxes:
# cls = int(box.cls.cpu().item()) if hasattr(box, "cls") else None
# # conf
# try:
# confscore = float(box.conf.cpu().item()) if hasattr(box, "conf") else None
# except Exception:
# try:
# confscore = float(box.conf.item())
# except Exception:
# confscore = None
# # xyxy
# coords = []
# if hasattr(box, "xyxy"):
# try:
# arr = box.xyxy.cpu().numpy()
# if getattr(arr, "ndim", None) == 2 and arr.shape[0] == 1:
# coords = arr[0].tolist()
# elif getattr(arr, "ndim", None) == 1:
# coords = arr.tolist()
# else:
# coords = arr.reshape(-1).tolist()
# except Exception:
# try:
# coords = box.xyxy.tolist()
# except Exception:
# coords = []
# det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
# det_scores.append(round(confscore, 4) if confscore is not None else "")
# try:
# det_boxes.append([round(float(x), 2) for x in coords])
# except Exception:
# det_boxes.append([str(coords)])
# label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
# boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
# table_rows.append([
# os.path.basename(path),
# len(det_labels),
# ", ".join(label_conf_pairs),
# ", ".join(boxes_repr),
# "; ".join([str(b) for b in det_boxes]),
# ])
# img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
# else Image.open(path).convert("RGB")
# gallery.append((img_for_gallery, f"{os.path.basename(path)}\n{len(det_labels)} detections"))
# # write CSV
# csv_path = None
# try:
# tmp = tempfile.NamedTemporaryFile(
# delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR,
# mode="w", newline='', encoding='utf-8'
# )
# writer = csv.writer(tmp)
# writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
# for r in table_rows:
# writer.writerow(r)
# tmp.flush(); tmp.close()
# csv_path = tmp.name
# except Exception as e:
# print("Failed to write CSV:", e)
# csv_path = None
# return gallery, table_rows, csv_path
# # ---------- UI ----------
# if YOLO is None:
# demo = gr.Interface(
# fn=lambda *a, **k: ("Ultralytics not installed; add 'ultralytics' to requirements.txt",),
# inputs=[],
# outputs="text",
# title="🌊 BenthicAI — Object Detection",
# description="Ultralytics is not installed."
# )
# else:
# demo = gr.Interface(
# fn=detect_objects_batch,
# inputs=[
# gr.Files(label="Upload images (max 10)"),
# gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
# gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
# ],
# outputs=[
# gr.Gallery(label="Detections (annotated)", height=500, rows=3),
# gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
# label="Detection Table"),
# gr.File(label="Download CSV"),
# ],
# title="🌊 BenthicAI — Object Detection",
# description=(
# "Run YOLO object detection on multiple images. "
# "Place 'yolo11_best.pt' in the repo root, OR set YOLO_REPO_ID (+ HF_TOKEN if private) "
# "to fetch from the Hub."
# ),
# )
# if __name__ == "__main__":
# demo.launch(server_name="0.0.0.0", server_port=7860)
# app.py — Image Classification only (single + batch up to 10)
import os
import csv
import tempfile
from pathlib import Path
from typing import List, Tuple
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Hugging Face Hub identifiers: the fine-tuned classifier weights, and the base
# checkpoint whose image-preprocessing configuration is reused for inputs.
MODEL_ID = "dshi01/convnext-tiny-224-7clss"
PROCESSOR_ID = "facebook/convnext-tiny-224"

print(f"[IC] Loading model: {MODEL_ID}")
processor = AutoImageProcessor.from_pretrained(PROCESSOR_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # switch the network to inference mode


def _label_name(class_index):
    """Return the display label for ``class_index``.

    ``model.config.id2label`` is probed with the string key first, then the
    integer key, and ``"Label_<index>"`` is used when neither is present.
    """
    id2label = model.config.id2label
    fallback = id2label.get(class_index, f"Label_{class_index}")
    return id2label.get(str(class_index), fallback)


# Position i holds the label for logit i (ordered by class index).
ID2LABEL = [_label_name(idx) for idx in range(model.config.num_labels)]
def classify_image(image):
    """Classify a single image.

    Args:
        image: A ``PIL.Image.Image`` or an array accepted by
            ``PIL.Image.fromarray``. ``None`` (nothing uploaded) is allowed.

    Returns:
        dict mapping every label in ``ID2LABEL`` to its softmax probability,
        or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        # Nothing uploaded: return an empty result instead of crashing in fromarray.
        return {}
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Always normalize to 3-channel RGB, matching classify_images_batch.
    # RGBA / grayscale / palette PIL inputs could otherwise reach the processor
    # with an unexpected channel count.
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)[0].tolist()
    return {label: float(prob) for label, prob in zip(ID2LABEL, probs)}
MAX_BATCH = 10  # maximum number of uploads processed per batch request

_CSV_HEADER = ["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"]


def _load_rgb_images(files):
    """Open each upload as an RGB image, skipping (and logging) unreadable ones.

    Returns two parallel lists ``(images, basenames)`` covering only the
    uploads that opened successfully.
    """
    images, names = [], []
    for f in files:
        # Uploads may be objects exposing ``.name`` or ``.path``, or plain paths.
        path = getattr(f, "name", None) or getattr(f, "path", None) or f
        try:
            name = os.path.basename(path)
            # ``with`` releases the file handle; convert() fully loads the
            # pixels, so the returned image no longer depends on the file.
            with Image.open(path) as src:
                rgb = src.convert("RGB")
        except Exception as exc:  # best-effort: one bad upload must not sink the batch
            print(f"[IC] Skipping unreadable file {path!r}: {exc}")
            continue
        # Append only after both steps succeed so the two lists stay aligned.
        images.append(rgb)
        names.append(name)
    return images, names


def _write_predictions_csv(rows, directory):
    """Write ``rows`` (under ``_CSV_HEADER``) to a new CSV file in ``directory``.

    Returns the file path, or ``None`` if the file could not be written.
    """
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", newline="", encoding="utf-8", suffix=".csv",
            prefix="predictions_", dir=directory, delete=False,
        ) as tmp:
            writer = csv.writer(tmp)
            writer.writerow(_CSV_HEADER)
            writer.writerows(rows)
    except (OSError, csv.Error) as exc:
        print(f"[IC] Failed to write CSV: {exc}")
        return None
    return tmp.name


def classify_images_batch(files):
    """Classify up to ``MAX_BATCH`` uploaded images in a single forward pass.

    Args:
        files: Sequence of uploads (file-like objects with ``.name``/``.path``
            or path strings). Items beyond ``MAX_BATCH`` are ignored.

    Returns:
        ``(gallery, table_rows, csv_path)`` where ``gallery`` is a list of
        ``(image, "filename\\nTOP1 (pct)")`` pairs, ``table_rows`` holds
        ``[filename, top1_label, top1_conf, top3_labels, top3_confs]`` per
        readable image, and ``csv_path`` is a downloadable CSV of those rows
        (``None`` if writing failed). Returns ``([], [], None)`` when no
        files were given or none could be read.
    """
    if not files:
        return [], [], None

    images, names = _load_rgb_images(files[:MAX_BATCH])
    if not images:
        return [], [], None

    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)

    gallery, table_rows = [], []
    for img, fname, row_probs in zip(images, names, probs.tolist()):
        # Class indices ordered by descending probability; keep the best three.
        ranked = sorted(range(len(row_probs)), key=row_probs.__getitem__, reverse=True)[:3]
        top1 = ranked[0]
        gallery.append((img, f"{fname}\n{ID2LABEL[top1]} ({row_probs[top1]:.2%})"))
        table_rows.append([
            fname,
            ID2LABEL[top1],
            round(row_probs[top1], 4),
            ", ".join(ID2LABEL[i] for i in ranked),
            ", ".join(str(round(row_probs[i], 4)) for i in ranked),
        ])

    csv_path = _write_predictions_csv(table_rows, BASE_DIR)
    return gallery, table_rows, csv_path
# ---------- UI ----------
# Single-image tab: one upload -> probability for every label.
single = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Underwater Image"),
    outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
    title="🌊 BenthicAI — Single Image",
    # Class count comes from the loaded model rather than a hard-coded number.
    description=f"Classify one image into one of {len(ID2LABEL)} benthic species.",
)

# Batch tab: up to MAX_BATCH uploads -> annotated gallery, table, and CSV download.
# The limit shown to users is derived from MAX_BATCH so it cannot drift from the code.
batch = gr.Interface(
    fn=classify_images_batch,
    inputs=gr.Files(label=f"Upload up to {MAX_BATCH} images"),
    outputs=[
        gr.Gallery(label="Results (Top-1 in caption)", height=500, rows=3),
        gr.Dataframe(
            headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
            label="Predictions Table",
            wrap=True,
        ),
        gr.File(label="Download CSV"),
    ],
    title=f"🌊 BenthicAI — Batch (up to {MAX_BATCH})",
    description=f"Upload multiple images (max {MAX_BATCH}).",
)

demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
if __name__ == "__main__":
    # Bind to all network interfaces ("0.0.0.0") on port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)