"""Model loading and device configuration.

Handles device selection (CPU, CUDA, or Hugging Face Spaces Zero GPU),
warning filters, and lazy loading of the detection, segmentation, and
color-extraction models used by the processing pipeline.
"""
import logging
import os
import sys
import threading
import time
import warnings

import torch

# Silence known-noisy warnings from meta-device loading and non-flash-attention builds.
warnings.filterwarnings("ignore", message=".*copying from a non-meta parameter.*")
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")

# Make the project root importable when this module is loaded from its package directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from transformers import (
    AutoProcessor,
    AutoImageProcessor,
    AutoModelForObjectDetection,
    DetrImageProcessor,
    DetrForObjectDetection,
    AutoModelForImageSegmentation,
    YolosImageProcessor,
    YolosForObjectDetection,
    SamModel,
    SamProcessor,
)

# timm is optional; only the color-extraction model depends on it.
try:
    import timm
    TIMM_AVAILABLE = True
except ImportError:
    TIMM_AVAILABLE = False
    logging.warning("TIMM library not available - color extraction model will not be loaded")


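# --- Device and hardware environment helpers ---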
def setup_device():
    """Pick the runtime device: CPU on Zero GPU Spaces, else CUDA when available."""
    if os.getenv("SPACE_ID"):
        # Zero GPU Spaces only expose the GPU inside decorated functions.
        return "cpu"
    elif torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        if device_count >= 1:
            return "cuda"
    return "cpu"


def check_cuda_availability():
    """Log the GPU situation and return True when CUDA is usable right now."""
    if os.getenv("SPACE_ID"):
        logging.info("Running in Hugging Face Spaces (Zero GPU) - GPU will be available in decorated functions")
        return False

    if not torch.cuda.is_available():
        logging.warning("\n" + "="*60 + "\n" +
                        "WARNING: CUDA NOT AVAILABLE!\n" +
                        "Running on CPU. Performance will be significantly reduced.\n" +
                        "="*60 + "\n")
        return False

    device_count = torch.cuda.device_count()
    if device_count > 0:
        for i in range(device_count):
            props = torch.cuda.get_device_properties(i)
            logging.info(f"GPU {i}: {props.name} (Memory: {props.total_memory / (1024**3):.1f} GB)")
    else:
        logging.info("CUDA available but no GPUs detected")
    return True


def check_hardware_environment():
    """Log the hardware in use and, on Spaces, verify the Zero GPU tier."""
    gpu_available = check_cuda_availability()

    if os.getenv("SPACE_ID"):
        ensure_zerogpu()
    else:
        if gpu_available:
            logging.info(f"Running on {setup_device().upper()}")
        else:
            logging.info("Running on CPU")


def ensure_zerogpu():
    """Verify the Space runs on zero-a10g hardware, requesting a change if not."""
    space_id = os.getenv("SPACE_ID")
    hf_token = os.getenv("HF_TOKEN")

    if not space_id:
        logging.info("Not running in Hugging Face Spaces")
        return

    try:
        from huggingface_hub import HfApi

        api = HfApi(token=hf_token) if hf_token else HfApi()
        space_info = api.get_space_runtime(space_id)

        current_hardware = getattr(space_info, 'hardware', None)
        logging.info(f"Current space hardware: {current_hardware}")

        if current_hardware and "a10g" not in current_hardware.lower():
            logging.warning(f"Space is running on {current_hardware}, not zero-a10g")

            if hf_token:
                try:
                    api.request_space_hardware(repo_id=space_id, hardware="zero-a10g")
                    logging.info("Requested hardware change to zero-a10g")
                except Exception as e:
                    logging.error(f"Failed to request hardware change: {e}")
            else:
                logging.warning("Cannot request hardware change without HF_TOKEN")
        else:
            logging.info("Space is already running on zero-a10g")

    except ImportError:
        logging.warning("huggingface_hub not available, cannot verify space hardware")
    except Exception as e:
        logging.error(f"Unexpected error in ensure_zerogpu: {e}")


# Resolved once at import time; Zero GPU Spaces always start on CPU here.
DEVICE = setup_device()

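# Per-model precision flags: True keeps float32 weights; False loads float16
# (the model is halved on CUDA in load_model_with_precision).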
RTDETR_FULL_PRECISION = True
HEAD_DETECTION_FULL_PRECISION = True
RMBG_FULL_PRECISION = True
BIREFNET_FULL_PRECISION = True
YOLOS_FASHIONPEDIA_FULL_PRECISION = True
TIMM_COLOR_EXTRACTOR_FULL_PRECISION = True

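# torch.compile and CUDA performance settings. USE_TORCH_COMPILE is applied to
# the RMBG model in load_rmbg_model; the remaining flags are presumably
# consumed by downstream inference code.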
USE_TORCH_COMPILE = True
TORCH_COMPILE_MODE = "reduce-overhead"
TORCH_COMPILE_BACKEND = "inductor"
ENABLE_CHANNELS_LAST = True
ENABLE_CUDA_GRAPHS = True
USE_MIXED_PRECISION = True

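# Hugging Face Hub repositories for each model.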
RTDETR_REPO = "PekingU/rtdetr_r50vd"
HEAD_DETECTION_REPO = "sanali209/DT_face_head_char"
RMBG_REPO = "briaai/RMBG-2.0"
BIREFNET_REPO = "ZhengPeng7/BiRefNet-matting"
YOLOS_FASHIONPEDIA_REPO = "valentinafeve/yolos-fashionpedia"
TIMM_COLOR_EXTRACTOR_REPO = "hf-hub:timm/resnet50.a2_in1k"
SAM_REPO = "facebook/sam-vit-base"

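# BiRefNet ships custom modeling code. These config stubs are written into the
# local model directory (see create_config_files) so that
# AutoModelForImageSegmentation can resolve the auto_map entries when loading
# with trust_remote_code=True.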
BIREFNET_CONFIG_PYTHON_TEMPLATE = """from transformers.configuration_utils import PretrainedConfig

class BiRefNetConfig(PretrainedConfig):
    model_type = "SegformerForSemanticSegmentation"
    num_channels = 3
    backbone = "mit_b5"
    hidden_size = 768
    num_hidden_layers = 12
    num_attention_heads = 12
    bb_pretrained = False
"""

BIREFNET_CONFIG_JSON = """{
  "_name_or_path": "briaai/RMBG-2.0",
  "architectures": ["BiRefNet"],
  "auto_map": {
    "AutoConfig": "BiRefNet_config.BiRefNetConfig",
    "AutoModelForImageSegmentation": "birefnet.BiRefNet"
  },
  "bb_pretrained": false
}"""

BIREFNET_CONFIG_FILES = {
    "BiRefNet_config.py": BIREFNET_CONFIG_PYTHON_TEMPLATE,
    "config.json": BIREFNET_CONFIG_JSON
}

BIREFNET_DOWNLOAD_FILES = ["birefnet.py", "preprocessor_config.json"]
BIREFNET_WEIGHT_FILES = ["model.safetensors", "pytorch_model.bin"]
DEFAULT_LOCAL_RMBG_DIR = "models/rmbg2"

ERROR_NO_HF_TOKEN = "HF_TOKEN environment variable not set. Please set it in your Space secrets."
ERROR_ACCESS_DENIED = "Access denied to model. Please check your credentials."
ERROR_AUTH_FAILED = "Authentication failed. Please set HF_TOKEN environment variable."

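# Lazily populated model/processor singletons. All start as None and are
# filled in by the load_* functions below.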
RTDETR_PROCESSOR = None
RTDETR_MODEL = None
HEAD_PROCESSOR = None
HEAD_MODEL = None
RMBG_MODEL = None
BIREFNET_MODEL = None
YOLOS_PROCESSOR = None
YOLOS_MODEL = None
TIMM_COLOR_MODEL = None
TIMM_COLOR_TRANSFORMS = None
SAM_MODEL = None
SAM_PROCESSOR = None

# Shared load state. An RLock is required here: load_models() re-acquires the
# lock when invoked via ensure_models_loaded(), which would deadlock a plain Lock.
MODELS_LOADED = False
LOAD_ERROR = ""
LOAD_LOCK = threading.RLock()


def patch_spaces_device_handling():
    """Work around string-device arguments in the spaces Zero GPU torch patching.

    Wraps the private spaces hook so that a device passed as a string is
    converted to a torch.device before the patched storage constructor runs.
    """
    try:
        import spaces.zero.torch.patching as spaces_patching
        original_untyped_storage_new = spaces_patching._untyped_storage_new_register

        def patched_untyped_storage_new_register(storage_cls):
            def wrapper(*args, **kwargs):
                device = kwargs.get('device')
                if device is not None and isinstance(device, str):
                    kwargs['device'] = torch.device(device)
                return original_untyped_storage_new(storage_cls)(*args, **kwargs)
            return wrapper

        spaces_patching._untyped_storage_new_register = patched_untyped_storage_new_register
        logging.info("Successfully patched spaces device handling")
        return True
    except Exception as e:
        logging.debug(f"Spaces patching not available or failed: {e}")
        return False


def is_spaces_environment():
    """Return True when running under Hugging Face Spaces or with spaces imported."""
    return os.getenv("SPACE_ID") is not None or "spaces" in sys.modules


def create_config_files(local_dir: str) -> None:
    """Write the BiRefNet config stubs into local_dir if they are missing."""
    os.makedirs(local_dir, exist_ok=True)

    for filename, content in BIREFNET_CONFIG_FILES.items():
        file_path = os.path.join(local_dir, filename)
        if not os.path.exists(file_path):
            with open(file_path, "w") as f:
                f.write(content)
            logging.info(f"Created {filename} in {local_dir}")


def download_birefnet_files(local_dir: str, token: str) -> None:
    """Fetch the BiRefNet support files from the RMBG repo into local_dir."""
    from huggingface_hub import hf_hub_download

    for file in BIREFNET_DOWNLOAD_FILES:
        file_path = os.path.join(local_dir, file)
        if not os.path.exists(file_path):
            try:
                hf_hub_download(
                    repo_id=RMBG_REPO,
                    filename=file,
                    token=token,
                    local_dir=local_dir,
                    local_dir_use_symlinks=False
                )
                logging.info(f"Downloaded {file} to {local_dir}")
            except Exception as e:
                logging.error(f"Failed to download {file}: {e}")
                raise RuntimeError(f"Failed to download {file} from {RMBG_REPO}")


def download_model_weights(local_dir: str, token: str) -> None:
    """Download RMBG weights into local_dir, preferring safetensors over .bin."""
    from huggingface_hub import hf_hub_download

    weights_exist = any(
        os.path.exists(os.path.join(local_dir, weight_file))
        for weight_file in BIREFNET_WEIGHT_FILES
    )
    if weights_exist:
        return

    # BIREFNET_WEIGHT_FILES is ordered by preference: model.safetensors first,
    # then pytorch_model.bin as the fallback.
    last_error = None
    for weight_file in BIREFNET_WEIGHT_FILES:
        try:
            hf_hub_download(
                repo_id=RMBG_REPO,
                filename=weight_file,
                token=token,
                local_dir=local_dir,
                local_dir_use_symlinks=False
            )
            logging.info(f"Downloaded {weight_file} to {local_dir}")
            return
        except Exception as e:
            last_error = e
            logging.warning(f"Failed to download {weight_file}: {e}")

    logging.error(f"Failed to download model weights: {last_error}")
    raise RuntimeError(f"Failed to download model weights from {RMBG_REPO}")


def ensure_birefnet_files(local_dir: str, token: str) -> None:
    """Make sure config stubs, support files, and weights all exist locally."""
    create_config_files(local_dir)
    download_birefnet_files(local_dir, token)
    download_model_weights(local_dir, token)


def ensure_models_loaded() -> None:
    """Load all models exactly once; on Zero GPU Spaces, defer to GPU context."""
    global MODELS_LOADED, LOAD_ERROR

    if not MODELS_LOADED:
        if is_spaces_environment():
            # On Zero GPU the models must be loaded on-demand inside GPU-decorated
            # functions, so startup only announces the deferred strategy.
            time.sleep(1)
            print("="*70)
            print("ZERO GPU MODEL LOADING: 1. Models NOT loaded at startup")
            print("="*70)
            logging.info("ZERO GPU MODEL LOADING: Models NOT loaded at startup")
            logging.info("ZERO GPU MODEL LOADING: Models will be loaded on-demand in GPU context")
            return

        with LOAD_LOCK:
            # Double-checked locking: re-test after acquiring the lock.
            if not MODELS_LOADED:
                if LOAD_ERROR:
                    raise RuntimeError(f"Models failed to load: {LOAD_ERROR}")

                try:
                    load_models()
                except Exception as e:
                    LOAD_ERROR = str(e)
                    raise


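# Loading strategy used by load_model_with_precision, in order:
#   1. safetensors at the requested precision (low_cpu_mem_usage)
#   2. on checkpoint/decode errors: purge the Hub cache and force_download in fp32
#   3. fall back to a TensorFlow checkpoint (from_tf=True)
#   4. last resort: plain .bin weights (use_safetensors=False)
# Any other load error retries once with a minimal fp32 configuration.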
def load_model_with_precision(model_class, repo_id: str, full_precision: bool, device_map: bool = True, trust_remote_code: bool = False):
    """Load a model at the requested precision with layered fallback strategies."""
    global DEVICE

    try:
        spaces_env = is_spaces_environment()

        if spaces_env:
            # Zero GPU: keep everything on CPU at load time.
            torch_device = torch.device("cpu")
            patch_spaces_device_handling()
        else:
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            torch_device = torch.device(DEVICE)

        load_kwargs = {
            "torch_dtype": torch.float32 if full_precision else torch.float16,
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True,
            "use_safetensors": True
        }

        if spaces_env:
            load_kwargs["device_map"] = None
        elif DEVICE == "cuda" and device_map and torch.cuda.device_count() > 1:
            load_kwargs["device_map"] = "auto"

        try:
            model = model_class.from_pretrained(repo_id, **load_kwargs)

            # Models sharded via device_map manage their own placement.
            if not spaces_env and not hasattr(model, 'hf_device_map'):
                model = model.to(torch_device)

                if not full_precision and DEVICE == "cuda":
                    model = model.half()

        except (ValueError, RuntimeError, OSError, UnicodeDecodeError) as e:
            logging.warning(f"Failed to load model with initial configuration: {e}")

            if "Unable to load weights from pytorch checkpoint" in str(e) or "UnicodeDecodeError" in str(e):
                logging.info(f"Attempting to clear cache and retry for {repo_id}")

                # A corrupt cached checkpoint is the usual culprit; purge it.
                try:
                    from huggingface_hub import scan_cache_dir
                    cache_info = scan_cache_dir()
                    for repo in cache_info.repos:
                        if repo_id.replace("/", "--") in repo.repo_id:
                            repo.delete()
                            logging.info(f"Cleared cache for {repo_id}")
                            break
                except Exception as cache_e:
                    logging.warning(f"Cache clearing failed: {cache_e}")

                # Fallback 1: force a fresh fp32 download.
                try:
                    load_kwargs_retry = {
                        "torch_dtype": torch.float32,
                        "trust_remote_code": trust_remote_code,
                        "force_download": True,
                        "device_map": None,
                        "low_cpu_mem_usage": True
                    }
                    model = model_class.from_pretrained(repo_id, **load_kwargs_retry)
                    model = model.to(torch_device)

                except Exception as retry_e:
                    logging.warning(f"Retry with force_download failed: {retry_e}")

                    # Fallback 2: convert from a TensorFlow checkpoint.
                    try:
                        load_kwargs_tf = {
                            "from_tf": True,
                            "torch_dtype": torch.float32,
                            "trust_remote_code": trust_remote_code,
                            "device_map": None,
                            "low_cpu_mem_usage": True
                        }
                        model = model_class.from_pretrained(repo_id, **load_kwargs_tf)
                        model = model.to(torch_device)
                        logging.info(f"Successfully loaded {repo_id} from TensorFlow checkpoint")

                    except Exception as tf_e:
                        logging.warning(f"TensorFlow fallback failed: {tf_e}")

                        # Fallback 3: plain .bin weights, no safetensors.
                        try:
                            load_kwargs_basic = {
                                "torch_dtype": torch.float32,
                                "trust_remote_code": trust_remote_code,
                                "device_map": None,
                                "use_safetensors": False,
                                "local_files_only": False
                            }
                            model = model_class.from_pretrained(repo_id, **load_kwargs_basic)
                            model = model.to(torch_device)
                            logging.info(f"Successfully loaded {repo_id} with basic configuration")

                        except Exception as basic_e:
                            logging.error(f"All fallback strategies failed for {repo_id}: {basic_e}")
                            raise RuntimeError(f"Unable to load model {repo_id} after all retry attempts: {basic_e}")
            else:
                # Other load errors: retry once with a minimal fp32 configuration.
                load_kwargs_fallback = {
                    "torch_dtype": torch.float32,
                    "trust_remote_code": trust_remote_code,
                    "device_map": None
                }
                model = model_class.from_pretrained(repo_id, **load_kwargs_fallback)
                model = model.to(torch_device)

        model.eval()

        if not spaces_env:
            with torch.no_grad():
                logging.info(f"Verifying model {repo_id} is on correct device")
                param = next(model.parameters())

                if DEVICE == "cuda" and not param.is_cuda:
                    model = model.to(torch_device)
                    logging.warning(f"Forced model {repo_id} to {DEVICE}")

                logging.info(f"Model {repo_id} device: {param.device}")
        else:
            logging.info(f"Model {repo_id} loaded on CPU (Zero GPU environment)")

        return model

    except Exception as e:
        logging.error(f"Failed to load model from {repo_id} on {DEVICE}: {e}")
        raise


def handle_rmbg_access_error(error_msg: str) -> None:
    """Translate RMBG access failures into actionable RuntimeErrors."""
    if "403" in error_msg and "gated repo" in error_msg:
        logging.error("\n" + "="*60 + "\n" +
                      "ERROR: Access denied to RMBG-2.0 model!\n" +
                      "You need to request access at: https://huggingface.co/briaai/RMBG-2.0\n" +
                      "="*60 + "\n")
        raise RuntimeError(ERROR_ACCESS_DENIED)
    elif "401" in error_msg:
        logging.error("\n" + "="*60 + "\n" +
                      "ERROR: Authentication failed!\n" +
                      "Please set your HF_TOKEN environment variable.\n" +
                      "="*60 + "\n")
        raise RuntimeError(ERROR_AUTH_FAILED)
    else:
        raise RuntimeError(error_msg)


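# --- Individual model loaders ---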
def load_rtdetr_model() -> None:
    """Load the RT-DETR object detection model and its processor."""
    global RTDETR_PROCESSOR, RTDETR_MODEL
    logging.info("Loading RT-DETR model...")
    RTDETR_PROCESSOR = AutoProcessor.from_pretrained(RTDETR_REPO)
    RTDETR_MODEL = load_model_with_precision(
        AutoModelForObjectDetection,
        RTDETR_REPO,
        RTDETR_FULL_PRECISION,
        device_map=False
    )
    logging.info("RT-DETR model loaded successfully")


def load_head_detection_model() -> None:
    """Load the DETR-based head detection model and its image processor."""
    global HEAD_PROCESSOR, HEAD_MODEL
    logging.info("Loading Head Detection model...")
    HEAD_PROCESSOR = AutoImageProcessor.from_pretrained(HEAD_DETECTION_REPO)
    HEAD_MODEL = load_model_with_precision(
        DetrForObjectDetection,
        HEAD_DETECTION_REPO,
        HEAD_DETECTION_FULL_PRECISION,
        device_map=False
    )
    logging.info("Head Detection model loaded successfully")


def load_birefnet_model() -> None:
    """Load the BiRefNet-matting segmentation model."""
    global BIREFNET_MODEL
    logging.info("Loading BiRefNet-matting model...")

    try:
        BIREFNET_MODEL = load_model_with_precision(
            AutoModelForImageSegmentation,
            BIREFNET_REPO,
            BIREFNET_FULL_PRECISION,
            trust_remote_code=True,
            device_map=False
        )
        logging.info("BiRefNet-matting model loaded successfully")
    except Exception as e:
        logging.error(f"Failed to load BiRefNet-matting model: {e}")
        raise


def load_rmbg_model() -> None:
    """Load RMBG-2.0 from a local directory, downloading its files as needed."""
    global RMBG_MODEL
    logging.info("Loading RMBG model...")

    token = os.getenv("HF_TOKEN", "")
    if not token:
        logging.error(ERROR_NO_HF_TOKEN)
        logging.warning("RMBG model requires HF_TOKEN. Skipping RMBG model loading...")
        RMBG_MODEL = None
        return

    local_dir = DEFAULT_LOCAL_RMBG_DIR

    try:
        ensure_birefnet_files(local_dir, token)
    except RuntimeError as e:
        handle_rmbg_access_error(str(e))

    os.environ["HF_HOME"] = os.path.dirname(local_dir)

    try:
        RMBG_MODEL = load_model_with_precision(
            AutoModelForImageSegmentation,
            local_dir,
            RMBG_FULL_PRECISION,
            trust_remote_code=True,
            device_map=False
        )

        if USE_TORCH_COMPILE and DEVICE == "cuda":
            try:
                RMBG_MODEL = torch.compile(
                    RMBG_MODEL,
                    mode=TORCH_COMPILE_MODE,
                    backend=TORCH_COMPILE_BACKEND,
                    fullgraph=False,
                    dynamic=False
                )
                logging.info(f"RMBG model compiled with mode={TORCH_COMPILE_MODE}, backend={TORCH_COMPILE_BACKEND}")
            except Exception as e:
                logging.warning(f"Failed to compile RMBG model: {e}")

        logging.info("RMBG-2.0 model loaded successfully from local directory")
    except Exception as e:
        handle_rmbg_access_error(str(e))


def load_yolos_fashionpedia_model() -> None:
    """Load the YOLOS FashionPedia detector, preferring a 512x512 processor."""
    global YOLOS_PROCESSOR, YOLOS_MODEL
    logging.info("Loading YOLOS FashionPedia model...")

    try:
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(
            YOLOS_FASHIONPEDIA_REPO,
            size={"height": 512, "width": 512}
        )
    except Exception:
        logging.warning("Failed to set custom size for YOLOS processor, using default")
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(YOLOS_FASHIONPEDIA_REPO)

    YOLOS_MODEL = load_model_with_precision(
        YolosForObjectDetection,
        YOLOS_FASHIONPEDIA_REPO,
        YOLOS_FASHIONPEDIA_FULL_PRECISION,
        device_map=False
    )
    logging.info("YOLOS FashionPedia model loaded successfully")


def load_timm_color_model() -> None:
    """Load the TIMM ResNet50 feature extractor used for color extraction."""
    global TIMM_COLOR_MODEL, TIMM_COLOR_TRANSFORMS

    if not TIMM_AVAILABLE:
        logging.warning("TIMM not available - skipping color extraction model")
        return

    logging.info("Loading TIMM ResNet50 A2 color extraction model...")

    try:
        # num_classes=0 and global_pool='' strip the classifier head so the
        # model returns spatial feature maps.
        TIMM_COLOR_MODEL = timm.create_model(
            TIMM_COLOR_EXTRACTOR_REPO,
            pretrained=True,
            num_classes=0,
            global_pool=''
        )

        if not is_spaces_environment():
            TIMM_COLOR_MODEL = TIMM_COLOR_MODEL.to(DEVICE)
            if not TIMM_COLOR_EXTRACTOR_FULL_PRECISION and DEVICE == "cuda":
                TIMM_COLOR_MODEL = TIMM_COLOR_MODEL.half()

        TIMM_COLOR_MODEL.eval()

        data_config = timm.data.resolve_data_config({}, model=TIMM_COLOR_MODEL)
        TIMM_COLOR_TRANSFORMS = timm.data.create_transform(**data_config, is_training=False)

        from src.processing.return_images.timm_resnet50_color import timm_color_extractor
        timm_color_extractor.initialize_model(TIMM_COLOR_MODEL, TIMM_COLOR_TRANSFORMS)

        logging.info("TIMM ResNet50 A2 color extraction model loaded successfully")

    except Exception as e:
        logging.warning(f"Failed to load TIMM color extraction model: {e}")
        TIMM_COLOR_MODEL = None
        TIMM_COLOR_TRANSFORMS = None


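# --- Orchestration ---
# load_models() is invoked via ensure_models_loaded() and loads every model,
# collecting per-model status rather than failing fast on the optional ones.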
						def load_models() -> None: | 
					
					
						
						| 
							 | 
						    global MODELS_LOADED, LOAD_ERROR | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    with LOAD_LOCK: | 
					
					
						
						| 
							 | 
						        if MODELS_LOADED: | 
					
					
						
						| 
							 | 
						            logging.info("Models already loaded") | 
					
					
						
						| 
							 | 
						            return | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if is_spaces_environment(): | 
					
					
						
						| 
							 | 
						            logging.info("ZERO GPU MODEL LOADING: User request triggered model loading") | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    check_hardware_environment() | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    from src.config.constants import BACKGROUND_REMOVAL_MODEL | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    models_status = { | 
					
					
						
						| 
							 | 
						        "rtdetr": False, | 
					
					
						
						| 
							 | 
						        "head_detection": False, | 
					
					
						
						| 
							 | 
						        "background_removal": False, | 
					
					
						
						| 
							 | 
						        "yolos": False | 
					
					
						
						| 
							 | 
						    } | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    critical_errors = [] | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        load_rtdetr_model() | 
					
					
						
						| 
							 | 
						        models_status["rtdetr"] = True | 
					
					
						
						| 
							 | 
						    except Exception as e: | 
					
					
						
						| 
							 | 
						        critical_errors.append(f"RT-DETR: {str(e)}") | 
					
					
						
						| 
							 | 
						        logging.error(f"Failed to load RT-DETR model: {e}") | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        load_head_detection_model() | 
					
					
						
						| 
							 | 
						        models_status["head_detection"] = True | 
					
					
						
						| 
							 | 
						    except Exception as e: | 
					
					
						
						| 
							 | 
						        critical_errors.append(f"Head Detection: {str(e)}") | 
					
					
						
						| 
							 | 
						        logging.error(f"Failed to load Head Detection model: {e}") | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
    bg_removal_loaded = False
    if BACKGROUND_REMOVAL_MODEL == 1:
        try:
            load_rmbg_model()
            bg_removal_loaded = RMBG_MODEL is not None
            logging.info(f"RMBG model loaded: {bg_removal_loaded}")
        except Exception as e:
            logging.warning(f"Failed to load RMBG model: {e}")
    elif BACKGROUND_REMOVAL_MODEL == 2:
        try:
            load_birefnet_model()
            bg_removal_loaded = BIREFNET_MODEL is not None
            logging.info(f"BiRefNet model loaded: {bg_removal_loaded}")
        except Exception as e:
            logging.warning(f"Failed to load BiRefNet model: {e}")

    models_status["background_removal"] = bg_removal_loaded

    try:
        load_yolos_fashionpedia_model()
        models_status["yolos"] = True
    except Exception as e:
        critical_errors.append(f"YOLOS: {str(e)}")
        logging.error(f"Failed to load YOLOS model: {e}")

    try:
        load_timm_color_model()
        models_status["timm_color"] = TIMM_COLOR_MODEL is not None
        logging.info(f"TIMM color model loaded: {models_status['timm_color']}")
    except Exception as e:
        logging.warning(f"Failed to load TIMM color model: {e}")
        models_status["timm_color"] = False

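    # SAM is optional: it only supports color-extraction segmentation, so a
    # failure here is logged at info level and never treated as critical.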
    try:
        logging.info("Loading SAM for color extraction...")
        global SAM_PROCESSOR, SAM_MODEL
        SAM_PROCESSOR = SamProcessor.from_pretrained(SAM_REPO)
        SAM_MODEL = load_model_with_precision(
            SamModel,
            SAM_REPO,
            full_precision=True,
            device_map=False
        )

        from src.processing.return_images.segment_for_color import simple_segmentation
        simple_segmentation.initialize_model(SAM_MODEL, SAM_PROCESSOR)

        models_status["sam_color"] = True
        logging.info("SAM loaded for color extraction")
    except Exception as e:
        logging.info(f"SAM not loaded (optional): {e}")
        models_status["sam_color"] = False

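    # The pipeline is considered usable as long as at least one object
    # detector (RT-DETR or YOLOS) loaded successfully.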
    if models_status["rtdetr"] or models_status["yolos"]:
        MODELS_LOADED = True
        LOAD_ERROR = ""

        loaded = [k for k, v in models_status.items() if v]
        failed = [k for k, v in models_status.items() if not v]

        logging.info(f"Models loaded: {', '.join(loaded)}")

        if failed:
            logging.warning(f"Models failed: {', '.join(failed)}")
    else:
        error_msg = "Failed to load critical models. " + "; ".join(critical_errors)
        logging.error(error_msg)
        LOAD_ERROR = error_msg
        raise RuntimeError(error_msg)


def move_models_to_gpu():
    """Move every loaded model to CUDA, restoring DEVICE if anything fails."""
    global RMBG_MODEL, BIREFNET_MODEL, RTDETR_PROCESSOR, RTDETR_MODEL, HEAD_MODEL, YOLOS_PROCESSOR, YOLOS_MODEL, TIMM_COLOR_MODEL, SAM_MODEL, DEVICE

    if not torch.cuda.is_available():
        logging.warning("CUDA not available, cannot move models to GPU")
        return

    original_device = DEVICE
    DEVICE = "cuda"

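    # Models are moved one at a time; each *_FULL_PRECISION flag opts that
    # model out of the fp16 (.half()) cast that saves GPU memory.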
    try:
        if RMBG_MODEL is not None:
            logging.info("Moving RMBG model to GPU...")
            RMBG_MODEL = RMBG_MODEL.to("cuda")
            if not RMBG_FULL_PRECISION:
                RMBG_MODEL = RMBG_MODEL.half()
            logging.info("RMBG model moved to GPU")

        if BIREFNET_MODEL is not None:
            logging.info("Moving BiRefNet model to GPU...")
            BIREFNET_MODEL = BIREFNET_MODEL.to("cuda")
            if not BIREFNET_FULL_PRECISION:
                BIREFNET_MODEL = BIREFNET_MODEL.half()
            logging.info("BiRefNet model moved to GPU")

        if RTDETR_MODEL is not None:
            logging.info("Moving RT-DETR model to GPU...")
            RTDETR_MODEL = RTDETR_MODEL.to("cuda")
            if not RTDETR_FULL_PRECISION:
                RTDETR_MODEL = RTDETR_MODEL.half()
            logging.info("RT-DETR model moved to GPU")

        if HEAD_MODEL is not None:
            logging.info("Moving Head Detection model to GPU...")
            HEAD_MODEL = HEAD_MODEL.to("cuda")
            if not HEAD_DETECTION_FULL_PRECISION:
                HEAD_MODEL = HEAD_MODEL.half()
            logging.info("Head Detection model moved to GPU")

        if YOLOS_MODEL is not None:
            logging.info("Moving YOLOS model to GPU...")
            YOLOS_MODEL = YOLOS_MODEL.to("cuda")
            if not YOLOS_FASHIONPEDIA_FULL_PRECISION:
                YOLOS_MODEL = YOLOS_MODEL.half()
            logging.info("YOLOS model moved to GPU")

        if TIMM_COLOR_MODEL is not None:
            logging.info("Moving TIMM color model to GPU...")
            # The extractor wraps the model and manages its own device placement.
            from src.processing.return_images.timm_resnet50_color import timm_color_extractor
            timm_color_extractor.move_to_gpu()
            logging.info("TIMM color model moved to GPU")

        logging.info(f"SAM_MODEL status: {SAM_MODEL is not None}")
        if SAM_MODEL is not None:
            logging.info("Moving SAM model to GPU...")
            SAM_MODEL = SAM_MODEL.to("cuda")

            # Keep the segmentation module's references in sync with the moved model.
            from src.processing.return_images.segment_for_color import simple_segmentation
            simple_segmentation._sam_model = SAM_MODEL
            simple_segmentation.device = torch.device("cuda")
            logging.info("SAM model moved to GPU")
        else:
            logging.warning("SAM_MODEL is None, cannot move to GPU")

        logging.info("All models moved to GPU successfully")

    except Exception as e:
        logging.error(f"Failed to move models to GPU: {e}")
        DEVICE = original_device
        raise


class ModelLoader:
    """Thin accessor over the module-level model and processor globals."""

    def __init__(self):
        self.device = DEVICE

    def get_model(self, model_name: str):
        """Return the requested model, loading all models on first use."""
        global RMBG_MODEL, BIREFNET_MODEL, RTDETR_MODEL, HEAD_MODEL, YOLOS_MODEL, TIMM_COLOR_MODEL

        ensure_models_loaded()

        if model_name == "rmbg":
            if RMBG_MODEL is None:
                raise RuntimeError("RMBG model not loaded")
            return RMBG_MODEL
        elif model_name == "birefnet":
            if BIREFNET_MODEL is None:
                raise RuntimeError("BiRefNet model not loaded")
            return BIREFNET_MODEL
        elif model_name == "rtdetr":
            if RTDETR_MODEL is None:
                raise RuntimeError("RT-DETR model not loaded")
            return RTDETR_MODEL
        elif model_name == "head":
            if HEAD_MODEL is None:
                raise RuntimeError("Head detection model not loaded")
            return HEAD_MODEL
        elif model_name == "yolos":
            if YOLOS_MODEL is None:
                raise RuntimeError("YOLOS model not loaded")
            return YOLOS_MODEL
        elif model_name == "timm_color":
            if TIMM_COLOR_MODEL is None:
                raise RuntimeError("TIMM color model not loaded")
            return TIMM_COLOR_MODEL
        else:
            raise ValueError(f"Unknown model: {model_name}")

    def get_processor(self, model_name: str):
        """Return the processor paired with the named detection model."""
        global RTDETR_PROCESSOR, HEAD_PROCESSOR, YOLOS_PROCESSOR

        ensure_models_loaded()

        if model_name == "rtdetr":
            if RTDETR_PROCESSOR is None:
                raise RuntimeError("RT-DETR processor not loaded")
            return RTDETR_PROCESSOR
        elif model_name == "head":
            if HEAD_PROCESSOR is None:
                raise RuntimeError("Head detection processor not loaded")
            return HEAD_PROCESSOR
        elif model_name == "yolos":
            if YOLOS_PROCESSOR is None:
                raise RuntimeError("YOLOS processor not loaded")
            return YOLOS_PROCESSOR
        else:
            raise ValueError(f"No processor for model: {model_name}")


# Shared module-level loader instance.
model_loader = ModelLoader()
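

# --- Usage sketch (illustrative only; assumes this module is run directly
# --- and the model weights are available for download) --------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # get_model() triggers ensure_models_loaded() internally on first use.
    detector = model_loader.get_model("rtdetr")
    processor = model_loader.get_processor("rtdetr")
    logging.info(f"RT-DETR ready on device: {model_loader.device}")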