#!/usr/bin/env python
# furniture_bbox_to_files.py ─────────────────────────────────────────
# Florence-2 + SAM-2 batch processor with retries *and* file-based images
# ----------------------------------------------------------------------
import os, json, random, time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import torch, supervision as sv
from PIL import Image, ImageDraw, ImageColor, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset, Image as HFImage, disable_progress_bar
# ───── global models ─────────────────────────────────────────────────
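# Florence-2 proposes boxes for each text prompt and SAM-2 refines them into
# masks; both models are loaded once here and shared by all worker threads.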
from utils.florence import (
    load_florence_model, run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
# annotators
_PALETTE = sv.ColorPalette.from_hex(
    ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2'])
BOX_ANN  = sv.BoxAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
MASK_ANN = sv.MaskAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LBL_ANN  = sv.LabelAnnotator(
    color=_PALETTE, color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000"), border_radius=5)
# ───── config ────────────────────────────────────────────────────────
os.environ["TOKENIZERS_PARALLELISM"] = "false"
disable_progress_bar()

DATASET_NAME  = "fotographerai/furniture_captioned_segment_prompt"
SPLIT         = "train"
IMAGE_COL     = "img2"
PROMPT_COL    = "segmenting_prompt"
INFLATE_RANGE = (0.01, 0.05)
FILL_COLOR    = "#00FF00"
TARGET_SIDE   = 1500

QA_DIR    = Path("bbox_review_recaptioned")
GREEN_DIR = QA_DIR / "green"; GREEN_DIR.mkdir(parents=True, exist_ok=True)
ANNO_DIR  = QA_DIR / "anno";  ANNO_DIR.mkdir(parents=True, exist_ok=True)
JSON_DIR  = QA_DIR / "json";  JSON_DIR.mkdir(parents=True, exist_ok=True)

MAX_WORKERS = 100
MAX_RETRIES = 5
RETRY_SLEEP = .3
FAILED_LOG  = QA_DIR / "failed_rows.jsonl"
PROMPT_MAP: dict[str, str] = {}  # optional overrides
# ───── helpers ───────────────────────────────────────────────────────
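# make_square letterboxes an image to a TARGET_SIDE square, padding with the
# top-left pixel colour; img_to_file writes a PNG to disk once and returns the
# {"path", "bytes"} dict that the datasets Image feature can decode lazily.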
def make_square(img: Image.Image, side: int = TARGET_SIDE) -> Image.Image:
    img = ImageOps.contain(img, (side, side))
    pad_w, pad_h = side - img.width, side - img.height
    return ImageOps.expand(img, border=(pad_w//2, pad_h//2,
                                        pad_w - pad_w//2, pad_h - pad_h//2),
                           fill=img.getpixel((0, 0)))
def img_to_file(img: Image.Image, fname: str, folder: Path) -> dict:
    path = folder / f"{fname}.png"
    if not path.exists():
        img.save(path)
    return {"path": str(path), "bytes": None}
# ───── core functions ────────────────────────────────────────────────
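# detect_and_segment runs Florence-2 open-vocabulary detection once per
# comma-separated prompt, refines each detection with SAM-2 and merges the
# results; fill_detected_bboxes then paints every box, inflated by inflate_pct,
# solid FILL_COLOR on a copy of the image.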
def detect_and_segment(img: Image.Image, prompts: str | List[str]) -> sv.Detections:
    if isinstance(prompts, str):
        prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    all_dets = []
    for p in prompts:
        _, res = run_florence_inference(
            model=FLORENCE_MODEL, processor=FLORENCE_PROC, device=DEVICE,
            image=img, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text=p)
        d = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, res, resolution_wh=img.size)
        all_dets.append(run_sam_inference(SAM_IMAGE_MODEL, img, d))
    return sv.Detections.merge(all_dets)
def fill_detected_bboxes(img: Image.Image, prompt: str,
                         inflate_pct: float) -> tuple[Image.Image, sv.Detections]:
    dets = detect_and_segment(img, prompt)
    filled = img.copy()
    draw = ImageDraw.Draw(filled)
    rgb = ImageColor.getrgb(FILL_COLOR)
    w, h = img.size
    for box in dets.xyxy:
        x1, y1, x2, y2 = box.astype(float)
        dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
        draw.rectangle([max(0, x1 - dw), max(0, y1 - dh),
                        min(w, x2 + dw), min(h, y2 + dh)], fill=rgb)
    return filled, dets
# ───── threaded worker ───────────────────────────────────────────────
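# process_row retries the detect → fill → annotate pipeline up to MAX_RETRIES
# times per row; a row that yields zero detections counts as a failure and is
# retried before finally being reported as ("fail", reason).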
def process_row(idx: int, sample):
    prompt = PROMPT_MAP.get(sample[PROMPT_COL],
                            sample[PROMPT_COL].split(",", 1)[0].strip())
    img_sq = make_square(sample[IMAGE_COL].convert("RGB"))
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            filled, dets = fill_detected_bboxes(
                img_sq, prompt, inflate_pct=random.uniform(*INFLATE_RANGE))
            if len(dets.xyxy) == 0:
                raise ValueError("no detections")
            sid = f"{idx:06d}"
            json_p = JSON_DIR / f"{sid}_bbox.json"
            json_p.write_text(json.dumps({"xyxy": dets.xyxy.tolist()}))
            anno = img_sq.copy()
            for ann in (MASK_ANN, BOX_ANN, LBL_ANN):
                anno = ann.annotate(anno, dets)
            return ("ok",
                    img_to_file(filled, sid, GREEN_DIR),
                    img_to_file(anno, sid, ANNO_DIR),
                    json_p.read_text())
        except Exception as e:
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_SLEEP)
            else:
                return ("fail", str(e))
# ───── run batch ─────────────────────────────────────────────────────
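# Fan the rows out over a thread pool, store results by their original index so
# row order is preserved, and append rows that exhausted their retries to
# FAILED_LOG for later inspection.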
ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N = len(ds)
print("Rows:", N)

filled_col, anno_col, json_col = [None]*N, [None]*N, [None]*N
fails = 0

with ThreadPoolExecutor(MAX_WORKERS) as pool:
    fut2idx = {pool.submit(process_row, i, ds[i]): i for i in range(N)}
    for fut in tqdm(as_completed(fut2idx), total=N, desc="Florence+SAM"):
        idx = fut2idx[fut]
        status, *data = fut.result()
        if status == "ok":
            filled_col[idx], anno_col[idx], json_col[idx] = data
        else:
            fails += 1
            with FAILED_LOG.open("a") as f:   # append one JSON line per failure
                f.write(json.dumps({"idx": idx, "reason": data[0]}) + "\n")
| print(f"β permanently failed rows: {fails}") | |
| keep = [i for i,x in enumerate(filled_col) if x] | |
| new_ds = ds.select(keep) | |
| new_ds = new_ds.add_column("bbox_filled", [filled_col[i] for i in keep]) | |
| new_ds = new_ds.add_column("annotated", [anno_col[i] for i in keep]) | |
| new_ds = new_ds.add_column("bbox_json", [json_col[i] for i in keep]) | |
| new_ds = new_ds.cast_column("bbox_filled", HFImage()) | |
| new_ds = new_ds.cast_column("annotated", HFImage()) | |
| print(f"β successes: {len(new_ds)} / {N}") | |
| print("Columns:", new_ds.column_names) | |
| print("QA artefacts β", QA_DIR.resolve()) | |
| # optional push | |
| new_ds.push_to_hub("fotographerai/surround_furniture_bboxfilled", | |
| private=True, max_shard_size="500MB") | |
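
# Optional sanity check (a minimal sketch, assuming the push above succeeded and
# the same Hugging Face token is still configured): reload the pushed repo and
# confirm the image columns decode back into PIL images.
check_ds = load_dataset("fotographerai/surround_furniture_bboxfilled", split="train")
print("reloaded rows:", len(check_ds))
print("bbox_filled type:", type(check_ds[0]["bbox_filled"]))  # expected: a PIL image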