"""
Minimal segmentation manager.

Wraps the BRIA RMBG-2.0 model from Hugging Face and produces soft
foreground masks in [0, 1] for BGR (OpenCV) images.
"""

import logging
import os
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class SegmentationManager:
"""Minimal BRIA segmentation.""" |
|
|
|
|
|
def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto", |
|
|
threshold: float = 0.5, trust_remote_code: bool = True, |
|
|
cache_dir: Optional[str] = None, local_files_only: bool = False): |
|
|
"""Initialize segmentation.""" |
|
|
self.model_name = model_name |
|
|
self.threshold = threshold |
|
|
self.device = "cuda" if device == "auto" and torch.cuda.is_available() else device |
|
|
|
|
|
|
|
|
        # An HF token is only required for gated or private checkpoints; None is fine otherwise.
        hf_token = os.environ.get("HF_TOKEN")

        # Default to a writable cache directory when none is provided.
        if cache_dir is None:
            cache_dir = "/tmp/huggingface_cache"

        logger.info(f"Loading BRIA model: {model_name} (cache: {cache_dir})")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            token=hf_token,
            low_cpu_mem_usage=True,
        ).eval().to(self.device)

        # Preprocessing: resize, tensorize, and normalize with ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        logger.info(f"BRIA model loaded on device: {self.device}")
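
    # The constructor stores `threshold` but the original code never applies it.
    # The helper below is a sketch of one plausible use: binarizing the soft mask.
    # The method name `segment_image` is an assumption, not part of the original API.
    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """Return a hard uint8 mask (0 or 255) by thresholding the soft mask at self.threshold."""
        soft_mask = self.segment_image_soft(image)
        return (soft_mask >= self.threshold).astype(np.uint8) * 255
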
    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment a BGR image and return a soft foreground mask in [0, 1]."""
        try:
            logger.info(f"Segmentation: input image shape={image.shape}, dtype={image.dtype}")
            # OpenCV images are BGR; the model expects RGB input.
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            logger.info(f"Segmentation: tensor shape={input_tensor.shape}, device={self.device}")

            with torch.no_grad():
                # The model returns multi-scale outputs; the last entry is the final prediction.
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
                logger.info(f"Segmentation: raw preds shape={preds.shape}, dtype={preds.dtype}")

            # cv2.resize takes (width, height); restore the mask to the input resolution.
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
            logger.info(f"Segmentation: resized soft_mask shape={soft_mask.shape}, dtype={soft_mask.dtype}")
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            logger.error(f"Segmentation failed: {e}")
            # Fail soft: return an all-background mask matching the input's spatial size.
            return np.zeros(image.shape[:2], dtype=np.float32)
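

# Usage sketch. Assumptions not present in the original module: an example image
# at "input.jpg", and either network access to the Hugging Face Hub or a warm
# local cache for briaai/RMBG-2.0.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = SegmentationManager()
    bgr = cv2.imread("input.jpg")  # OpenCV reads BGR, matching segment_image_soft's expected input.
    soft = manager.segment_image_soft(bgr)
    # Use the soft mask as an alpha channel to produce a background-removed cutout.
    rgba = cv2.cvtColor(bgr, cv2.COLOR_BGR2BGRA)
    rgba[:, :, 3] = (soft * 255).astype(np.uint8)
    cv2.imwrite("cutout.png", rgba)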