"""
Minimal segmentation manager.
"""

import logging
import os
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class SegmentationManager:
    """Minimal BRIA segmentation."""
    
    def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
                 threshold: float = 0.5, trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None, local_files_only: bool = False):
        """Initialize segmentation."""
        self.model_name = model_name
        self.threshold = threshold
        # Resolve "auto" to a concrete device; fall back to CPU when CUDA is unavailable.
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        
        # Get HF token from the environment (set as a Space secret)
        hf_token = os.environ.get("HF_TOKEN")
        
        # Set cache directory to /tmp to avoid persistent storage issues
        if cache_dir is None:
            cache_dir = "/tmp/huggingface_cache"
        
        logger.info(f"Loading BRIA model: {model_name} (cache: {cache_dir})")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            token=hf_token,
            low_cpu_mem_usage=True,  # Reduce memory usage during loading
        ).eval().to(self.device)
        
        # Resize to 384x384 for speed (~7x fewer pixels than the usual 1024x1024
        # input, at some cost in mask quality); normalize with ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        logger.info(f"BRIA model loaded on device: {self.device}")
    
    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment image and return soft mask [0,1]."""
        try:
            logger.info(f"Segmentation: input image shape={image.shape}, dtype={image.dtype}")
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            logger.info(f"Segmentation: tensor shape={input_tensor.shape}, device={self.device}")
            
            with torch.no_grad():
                # The model returns a list of side outputs; the last entry holds the
                # final mask logits, squashed to [0, 1] with a sigmoid.
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
            logger.info(f"Segmentation: raw preds shape={preds.shape}, dtype={preds.dtype}")
            
            # cv2.resize takes (width, height), the reverse of numpy's (rows, cols) order.
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
            logger.info(f"Segmentation: resized soft_mask shape={soft_mask.shape}, dtype={soft_mask.dtype}")
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            # Fail safe: log the error and return an all-background (zero) mask.
            logger.error(f"Segmentation failed: {e}")
            return np.zeros(image.shape[:2], dtype=np.float32)
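

# Minimal usage sketch, assuming a local BGR image at "input.jpg" (a hypothetical
# path) and network access to download the model weights on the first run.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = SegmentationManager(device="auto")
    image = cv2.imread("input.jpg")
    if image is None:
        raise FileNotFoundError("input.jpg not found")
    soft_mask = manager.segment_image_soft(image)
    # Binarize with the configured threshold and save as an 8-bit PNG mask.
    binary_mask = (soft_mask >= manager.threshold).astype(np.uint8) * 255
    cv2.imwrite("mask.png", binary_mask)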