import json import os import time import gradio as gr import numpy as np import pandas as pd import torch from hf_repo_utils import auto_register_kaloscope_variants, validate_models # Import inference implementations from inference_onnx import ONNXInference, softmax from inference_pytorch import PyTorchInference # This import is crucial to register the lsnet models with timm try: import lsnet.lsnet_artist # noqa: F401 except ImportError as e: print(f"Error: {e}") raise gr.Error("Could not import lsnet.lsnet_artist. Please ensure the lsnet folder is in your workspace.") # ------------------------------------------------------------ # CONFIG SECTION - Model Configuration # ------------------------------------------------------------ # Define available models with their configuration # Format: "display_name": { # "type": "onnx" or "pytorch", # "path": "local path or repo:filename", # for local use - if left empty, will use repo id # "repo_id": "huggingface repo id" (optional, but path cant be empty if repo_id empty), # "subfolder": "subfolder in repo" (optional, if applicable), # "arch": "model architecture name" - lsnet_xl_artist is expected for usual Kaloscope releases b, l, s exist but unused # # Optional CSV mapping settings - if omitted the model's repo_id will be used to find class_mapping.csv # # "csv_filename": "class_mapping.csv", # # "csv_subfolder": "optional/subfolder/for/csv", # } MODELS = { "Kaloscope v2.0 ONNX": { "type": "onnx", "path": "", "repo_id": "DraconicDragon/Kaloscope-onnx", "subfolder": "v2.0", "filename": "kaloscope_2-0.onnx", "arch": "lsnet_xl_artist_448", }, "Kaloscope v2.0": { "type": "pytorch", "path": "", "repo_id": "heathcliff01/Kaloscope2.0", "subfolder": "448-90.13", "filename": "best_checkpoint.pth", "arch": "lsnet_xl_artist_448", }, "Kaloscope v1.1 ONNX": { "type": "onnx", "path": "", "repo_id": "DraconicDragon/Kaloscope-onnx", "filename": "kaloscope_1-1.onnx", "arch": "lsnet_xl_artist", }, "Kaloscope v1.1": { "type": "pytorch", "path": "", "repo_id": "heathcliff01/Kaloscope", "subfolder": "224-85.65", "filename": "best_checkpoint.pth", "arch": "lsnet_xl_artist", }, "Kaloscope v1.0 ONNX": { "type": "onnx", "path": "", "repo_id": "DraconicDragon/Kaloscope-onnx", "filename": "kaloscope_1-0.onnx", "arch": "lsnet_xl_artist", }, "Kaloscope v1.0": { "type": "pytorch", "path": "", "repo_id": "heathcliff01/Kaloscope", "filename": "best_checkpoint.pth", "arch": "lsnet_xl_artist", }, "Kaloscope v1.0 ema": { "type": "pytorch", "path": "", "repo_id": "DraconicDragon/Kaloscope-onnx", "filename": "best_checkpoint_ema.pth", "arch": "lsnet_xl_artist", }, } MODELS = validate_models(MODELS) auto_register_kaloscope_variants(MODELS) # Default CSV filename used if a model doesn't specify one CSV_FILENAME = "class_mapping.csv" # Device configuration try: DEVICE = "cuda" if torch.cuda.is_available() else "cpu" except Exception: DEVICE = "cpu" # ------------------------------------------------------------ def load_labels(csv_path): """ Loads the class labels from the provided CSV file into a dictionary. """ try: df = pd.read_csv(csv_path) if "class_id" not in df.columns or "class_name" not in df.columns: raise gr.Error("CSV file must have 'class_id' and 'class_name' columns.") df["class_name"] = df["class_name"].str.strip("'") return dict(zip(df["class_id"], df["class_name"])) except FileNotFoundError: raise gr.Error(f"CSV file not found at '{csv_path}'") except Exception as e: raise gr.Error(f"Error reading CSV file: {e}") # Labels cache - maps model_name -> labels dict labels_cache = {} # Initialize model cache model_cache = {} def get_model_inference(model_name): """ Get or create inference object for the specified model. Uses caching to avoid reloading models. Also downloads the model's class mapping CSV from the same repo as the model unless explicitly overridden in the model config. """ if model_name not in model_cache: if model_name not in MODELS: raise gr.Error(f"Unknown model: {model_name}") config = MODELS[model_name] model_path = config.get("path") or "" print(f"Loading model: {model_name} ({config.get('filename', '')})") # Check if local file exists, otherwise try to download if not model_path or not os.path.exists(model_path): if "repo_id" in config and "filename" in config: target_display = model_path or f"repo {config['repo_id']}" print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...") try: from huggingface_hub import hf_hub_download download_kwargs = { "repo_id": config["repo_id"], "filename": config["filename"], } if config.get("subfolder"): download_kwargs["subfolder"] = config["subfolder"] if config.get("revision"): download_kwargs["revision"] = config["revision"] model_path = hf_hub_download(**download_kwargs) config["path"] = model_path print(f"Downloaded model to: {model_path}") except Exception as e: raise gr.Error(f"Could not load model from local path or Hugging Face: {e}") else: raise gr.Error(f"Model file not found at: {model_path}") # Create inference object based on type if config["type"] == "onnx": model_cache[model_name] = ONNXInference(model_path=model_path, model_arch=config["arch"], device=DEVICE) elif config["type"] == "pytorch": model_cache[model_name] = PyTorchInference( checkpoint_path=model_path, model_arch=config["arch"], device=DEVICE ) else: raise gr.Error(f"Unknown model type: {config['type']}") print(f"Model {model_name} loaded successfully ({config.get('filename', '')})") # Now download and load the class mapping CSV for this model try: # Prefer explicit csv repo if provided, otherwise fall back to the model's repo_id csv_repo = config.get("csv_repo_id") or config.get("repo_id") csv_filename = config.get("csv_filename") or CSV_FILENAME csv_subfolder = config.get("csv_subfolder") or config.get("subfolder") if not csv_repo: raise gr.Error(f"No repo available to find class mapping CSV for model '{model_name}'") print(f"Attempting to download class mapping CSV '{csv_filename}' from repo '{csv_repo}'") from huggingface_hub import hf_hub_download # First try with subfolder csv_path = None try: csv_download_kwargs = {"repo_id": csv_repo, "filename": csv_filename} if csv_subfolder: csv_download_kwargs["subfolder"] = csv_subfolder csv_path = hf_hub_download(**csv_download_kwargs) print(f"Downloaded CSV to: {csv_path}") except Exception as subfolder_error: # If subfolder download fails, try without subfolder (root) print(f"CSV not found in subfolder '{csv_subfolder}', trying root folder...") try: csv_download_kwargs = {"repo_id": csv_repo, "filename": csv_filename} csv_path = hf_hub_download(**csv_download_kwargs) print(f"Downloaded CSV from root folder to: {csv_path}") except Exception as root_error: # If both fail, raise the original subfolder error for better context raise subfolder_error labels_cache[model_name] = load_labels(csv_path) except Exception as e: # If CSV loading fails, surface a helpful error - labels are required for predictions raise gr.Error(f"Could not load class mapping CSV for model '{model_name}': {e}") return model_cache[model_name] def predict(image, model_selection, top_k, threshold): """ Main prediction function that takes UI inputs. """ # check if there even is image and throw error if image is none and dont continue if image is None: raise gr.Error("No image provided for prediction.") # Ensure top_k is an integer for slicing top_k = int(top_k) # Get inference object for selected model - this will also ensure labels are loaded for that model inference = get_model_inference(model_selection) # Retrieve labels for this model labels = labels_cache.get(model_selection) if labels is None: # This should not happen because get_model_inference loads labels, but guard anyway raise gr.Error(f"Label mapping not found for model '{model_selection}'") # Start timing start_time = time.time() # Run inference logits = inference.predict(image, top_k=top_k, threshold=threshold) # End timing inference_time = time.time() - start_time # Compute probabilities probabilities = softmax(logits) # Get all indices and their scores all_indices = np.argsort(probabilities)[::-1] tags = [] json_output = {} table_data = [] predictions_found = 0 for index in all_indices: score = probabilities[index] if score >= threshold and predictions_found < top_k: class_name = labels.get(index, f"Unknown Class #{index}") tags.append(class_name) json_output[class_name] = float(score) # Create Danbooru search URL with markdown link danbooru_url = f"https://danbooru.donmai.us/posts?tags={class_name.replace(' ', '_')}" artist_link = f"[{class_name}]({danbooru_url})" # Add copy button HTML with span instead of button copy_button = f"πŸ“‹" # Add row to table: [Rank, Artist (markdown link), Copy Button, Score] table_data.append([predictions_found + 1, artist_link, copy_button, f"{score:.2%}"]) predictions_found += 1 # Stop early if we have enough predictions if predictions_found >= top_k: break tags_output = ", ".join(tags) # Create DataFrame for display if table_data: df = pd.DataFrame(table_data, columns=["Rank", "Artist", "", "Score"]) else: df = pd.DataFrame(columns=["Rank", "Artist", "", "Score"]) # Get actual device/provider info from inference object if hasattr(inference, "execution_provider"): device_info = inference.execution_provider else: device_info = inference.device # Format time taken with 3-4 decimal places time_taken_str = f"- **{MODELS[model_selection]['type']}:** {device_info} | **Time taken:** {inference_time:.4f}s" return tags_output, df, json.dumps(json_output, indent=4), time_taken_str # --- Gradio Interface --- with gr.Blocks( css=""" * { box-sizing: border-box; } @media (max-width: 1022px) { #slider-row-container { flex-direction: column !important; } #slider-row-container > * { width: 100% !important; } #slider-row-container .block { width: 100% !important; } } #image-upload { max-height: 80vh !important; overflow: hidden !important; display: flex !important; flex-direction: column !important; } #image-upload .image-container { flex: 1 1 auto !important; min-height: 0 !important; overflow: hidden !important; display: flex !important; align-items: center !important; justify-content: center !important; } #image-upload img { max-height: 75vh !important; max-width: 100% !important; width: auto !important; height: auto !important; object-fit: contain !important; } #results-table-wrapper { overflow: hidden !important; width: 100% !important; } #results-table-wrapper .table-wrap, #results-table-wrapper .dataframe-wrap { overflow: hidden !important; width: 100% !important; } #results-table-wrapper table { width: 100% !important; table-layout: fixed !important; border-collapse: collapse !important; } #results-table-wrapper td, #results-table-wrapper th { overflow: hidden !important; text-overflow: ellipsis !important; white-space: nowrap !important; } #results-table-wrapper td:nth-child(1), #results-table-wrapper th:nth-child(1) { width: 55px !important; } #results-table-wrapper td:nth-child(2), #results-table-wrapper th:nth-child(2) { width: auto !important; } #results-table-wrapper td:nth-child(3), #results-table-wrapper th:nth-child(3) { width: 50px !important; border-left: none !important; text-align: center !important; padding: 0 !important; } #results-table-wrapper td:nth-child(4), #results-table-wrapper th:nth-child(4) { width: 69px !important; } #results-table-wrapper th:nth-child(3) { background: transparent !important; } #results-table-wrapper .copy-btn { cursor: pointer; width: 100%; height: 100%; display: flex; align-items: center; justify-content: center; user-select: none; } #results-table-wrapper .copy-btn:hover { background: rgba(128, 128, 128, 0.1); } """, js=""" function() { document.addEventListener('click', function(e) { if (e.target.classList.contains('copy-btn')) { const text = e.target.getAttribute('data-copy'); navigator.clipboard.writeText(text).then(() => { const original = e.target.textContent; e.target.textContent = 'βœ“'; setTimeout(() => { e.target.textContent = original; }, 1000); }); } }); // Fix right-click on links in DataFrame document.addEventListener('contextmenu', function(e) { // Check if the clicked element or its parent is a link inside the results table const link = e.target.closest('#results-table-wrapper a'); if (link) { e.stopPropagation(); // Let the browser's native context menu show for the link return true; } }, true); // Prevent Gradio's default click handling on links document.addEventListener('click', function(e) { const link = e.target.closest('#results-table-wrapper a'); if (link && e.button === 0) { e.stopPropagation(); // Allow normal link behavior (open in new tab due to markdown) return true; } }, true); } """, ) as demo: gr.Markdown("# Kaloscope Artist Style Classification") with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload") with gr.Column(): submit_btn = gr.Button("Predict") tags_output = gr.Textbox(label="Predicted Tags", show_copy_button=True) prettier_output = gr.DataFrame( elem_id="results-table-wrapper", # value=[ # [ # 1, # "[Samplaaaae Artist](https://example.com)", # "πŸ“‹", # "95.00%", # ], # [ # 2, # "[Another Artist](https://example.com)", # "πŸ“‹", # "90.00%", # ], # [ # 3, # "[Third Artist](https://example.com)", # "πŸ“‹", # "85.00%", # ], # ], interactive=False, datatype=["number", "markdown", "html", "str"], headers=["Rank", "Artist", "", "Score"], ) json_accordion = gr.Accordion("JSON Output", open=False) with json_accordion: json_output = gr.Code(language="json", show_label=False, lines=7) with gr.Group(): model_selection = gr.Dropdown( choices=[ ( f"{name}", # f"{name} | Repo: {MODELS[name].get('repo_id') or 'local'}", name, ) for name in MODELS ], value=list(MODELS.keys())[0], label="Select Model", ) with gr.Row(elem_id="slider-row-container"): top_k_slider = gr.Slider(minimum=1, maximum=25, value=5, step=1, label="Top K") threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Threshold") time_display = gr.Markdown() # populated after prediction gr.Markdown( "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) & [heathcliff01/Kaloscope2.0](https://huggingface.co/heathcliff01/Kaloscope2.0) (Original PyTorch releases) " + "and [DraconicDragon/Kaloscope-onnx](https://huggingface.co/DraconicDragon/Kaloscope-onnx) (ONNX converted and EMA weights). \n" + "OpenVINOβ„’ will be used to accelerate ONNX CPU inference with ONNX CPUExecutionProvider as fallback." ) submit_btn.click( fn=predict, inputs=[image_input, model_selection, top_k_slider, threshold_slider], outputs=[tags_output, prettier_output, json_output, time_display], ) if __name__ == "__main__": demo.launch()