import json
import os
import time
import gradio as gr
import numpy as np
import pandas as pd
import torch
from hf_repo_utils import auto_register_kaloscope_variants, validate_models
# Import inference implementations
from inference_onnx import ONNXInference, softmax
from inference_pytorch import PyTorchInference
# This import is crucial to register the lsnet models with timm
try:
    import lsnet.lsnet_artist  # noqa: F401
except ImportError as e:
    print(f"Error: {e}")
    raise gr.Error("Could not import lsnet.lsnet_artist. Please ensure the lsnet folder is in your workspace.")
# ------------------------------------------------------------
# CONFIG SECTION - Model Configuration
# ------------------------------------------------------------
# Define available models with their configuration.
# Format: "display_name": {
#     "type": "onnx" or "pytorch",
#     "path": "local path",  # if left empty, the model is downloaded via repo_id
#     "repo_id": "huggingface repo id",  # optional, but "path" can't be empty if "repo_id" is empty
#     "subfolder": "subfolder in repo",  # optional, if applicable
#     "filename": "model file name in the repo",  # required when downloading via repo_id
#     "arch": "model architecture name",  # lsnet_xl_artist is expected for usual Kaloscope releases; b, l, and s variants exist but are unused
#     # Optional CSV mapping settings - if omitted, the model's repo_id is used to find class_mapping.csv:
#     # "csv_repo_id": "huggingface repo id hosting the CSV",
#     # "csv_filename": "class_mapping.csv",
#     # "csv_subfolder": "optional/subfolder/for/csv",
# }
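# A hypothetical example entry (illustrative values only; no such entry ships with
# this Space) showing a local checkpoint plus the optional explicit CSV overrides:
# "My Local Kaloscope": {
#     "type": "pytorch",
#     "path": "checkpoints/best_checkpoint.pth",
#     "arch": "lsnet_xl_artist",
#     "csv_repo_id": "some-user/some-repo",
#     "csv_filename": "class_mapping.csv",
# },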
MODELS = {
    "Kaloscope v2.0 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "subfolder": "v2.0",
        "filename": "kaloscope_2-0.onnx",
        "arch": "lsnet_xl_artist_448",
    },
    "Kaloscope v2.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope2.0",
        "subfolder": "448-90.13",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist_448",
    },
    "Kaloscope v1.1 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-1.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.1": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "subfolder": "224-85.65",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-0.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ema": {
        "type": "pytorch",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "best_checkpoint_ema.pth",
        "arch": "lsnet_xl_artist",
    },
}
MODELS = validate_models(MODELS)
auto_register_kaloscope_variants(MODELS)
# Default CSV filename used if a model doesn't specify one
CSV_FILENAME = "class_mapping.csv"
# Device configuration
try:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except Exception:
    DEVICE = "cpu"
# ------------------------------------------------------------
def load_labels(csv_path):
    """
    Load the class labels from the provided CSV file into a {class_id: class_name} dict.
    """
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        raise gr.Error(f"CSV file not found at '{csv_path}'")
    except Exception as e:
        raise gr.Error(f"Error reading CSV file: {e}")
    # Validate the columns outside the try block so this error isn't re-wrapped above
    if "class_id" not in df.columns or "class_name" not in df.columns:
        raise gr.Error("CSV file must have 'class_id' and 'class_name' columns.")
    # Class names may be stored wrapped in single quotes; strip them
    df["class_name"] = df["class_name"].str.strip("'")
    return dict(zip(df["class_id"], df["class_name"]))
# Labels cache - maps model_name -> labels dict
labels_cache = {}
# Initialize model cache
model_cache = {}
def get_model_inference(model_name):
    """
    Get or create the inference object for the specified model.
    Uses caching to avoid reloading models. Also downloads the model's class mapping CSV
    from the same repo as the model unless explicitly overridden in the model config.
    """
    if model_name not in model_cache:
        if model_name not in MODELS:
            raise gr.Error(f"Unknown model: {model_name}")
        config = MODELS[model_name]
        model_path = config.get("path") or ""
        print(f"Loading model: {model_name} ({config.get('filename', '<no filename>')})")
        # Check if a local file exists; otherwise try to download it
        if not model_path or not os.path.exists(model_path):
            if "repo_id" in config and "filename" in config:
                target_display = model_path or f"repo {config['repo_id']}"
                print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...")
                try:
                    from huggingface_hub import hf_hub_download

                    download_kwargs = {
                        "repo_id": config["repo_id"],
                        "filename": config["filename"],
                    }
                    if config.get("subfolder"):
                        download_kwargs["subfolder"] = config["subfolder"]
                    if config.get("revision"):
                        download_kwargs["revision"] = config["revision"]
                    model_path = hf_hub_download(**download_kwargs)
                    config["path"] = model_path
                    print(f"Downloaded model to: {model_path}")
                except Exception as e:
                    raise gr.Error(f"Could not load model from local path or Hugging Face: {e}")
            else:
                raise gr.Error(f"Model file not found at: {model_path}")
        # Create the inference object based on type
        if config["type"] == "onnx":
            model_cache[model_name] = ONNXInference(model_path=model_path, model_arch=config["arch"], device=DEVICE)
        elif config["type"] == "pytorch":
            model_cache[model_name] = PyTorchInference(
                checkpoint_path=model_path, model_arch=config["arch"], device=DEVICE
            )
        else:
            raise gr.Error(f"Unknown model type: {config['type']}")
        print(f"Model {model_name} loaded successfully ({config.get('filename', '<no filename>')})")
        # Now download and load the class mapping CSV for this model
        try:
            # Prefer an explicit CSV repo if provided, otherwise fall back to the model's repo_id
            csv_repo = config.get("csv_repo_id") or config.get("repo_id")
            csv_filename = config.get("csv_filename") or CSV_FILENAME
            csv_subfolder = config.get("csv_subfolder") or config.get("subfolder")
            if not csv_repo:
                raise gr.Error(f"No repo available to find class mapping CSV for model '{model_name}'")
            print(f"Attempting to download class mapping CSV '{csv_filename}' from repo '{csv_repo}'")
            from huggingface_hub import hf_hub_download

            # First try with the subfolder
            csv_path = None
            try:
                csv_download_kwargs = {"repo_id": csv_repo, "filename": csv_filename}
                if csv_subfolder:
                    csv_download_kwargs["subfolder"] = csv_subfolder
                csv_path = hf_hub_download(**csv_download_kwargs)
                print(f"Downloaded CSV to: {csv_path}")
            except Exception as subfolder_error:
                # If the subfolder download fails, try the repo root
                print(f"CSV not found in subfolder '{csv_subfolder}', trying root folder...")
                try:
                    csv_download_kwargs = {"repo_id": csv_repo, "filename": csv_filename}
                    csv_path = hf_hub_download(**csv_download_kwargs)
                    print(f"Downloaded CSV from root folder to: {csv_path}")
                except Exception:
                    # If both fail, raise the original subfolder error for better context
                    raise subfolder_error
            labels_cache[model_name] = load_labels(csv_path)
        except Exception as e:
            # If CSV loading fails, surface a helpful error - labels are required for predictions
            raise gr.Error(f"Could not load class mapping CSV for model '{model_name}': {e}")
    return model_cache[model_name]
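# Optional warm-up sketch (hypothetical, not part of the app flow): pre-loading the
# default model at startup so the first request doesn't pay the download cost.
# get_model_inference(list(MODELS.keys())[0])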
def predict(image, model_selection, top_k, threshold):
    """
    Main prediction function that takes the UI inputs and returns formatted outputs.
    """
    # Abort early if no image was provided
    if image is None:
        raise gr.Error("No image provided for prediction.")
    # Ensure top_k is an integer for slicing
    top_k = int(top_k)
    # Get the inference object for the selected model - this also ensures its labels are loaded
    inference = get_model_inference(model_selection)
    # Retrieve the labels for this model
    labels = labels_cache.get(model_selection)
    if labels is None:
        # This should not happen because get_model_inference loads labels, but guard anyway
        raise gr.Error(f"Label mapping not found for model '{model_selection}'")
    # Time the inference call
    start_time = time.time()
    logits = inference.predict(image, top_k=top_k, threshold=threshold)
    inference_time = time.time() - start_time
    # Convert logits to probabilities
    probabilities = softmax(logits)
    # Indices sorted by descending score
    all_indices = np.argsort(probabilities)[::-1]
    tags = []
    json_output = {}
    table_data = []
    predictions_found = 0
    for index in all_indices:
        score = probabilities[index]
        if score >= threshold and predictions_found < top_k:
            class_name = labels.get(index, f"Unknown Class #{index}")
            tags.append(class_name)
            json_output[class_name] = float(score)
            # Create a Danbooru search URL as a markdown link
            danbooru_url = f"https://danbooru.donmai.us/posts?tags={class_name.replace(' ', '_')}"
            artist_link = f"[{class_name}]({danbooru_url})"
            # Copy-button HTML uses a span instead of a button
            copy_button = f"<span class='copy-btn' data-copy='{class_name}'>📋</span>"
            # Add row to table: [Rank, Artist (markdown link), Copy Button, Score]
            table_data.append([predictions_found + 1, artist_link, copy_button, f"{score:.2%}"])
            predictions_found += 1
        # Stop early once we have enough predictions
        if predictions_found >= top_k:
            break
    tags_output = ", ".join(tags)
    # Create a DataFrame for display
    if table_data:
        df = pd.DataFrame(table_data, columns=["Rank", "Artist", "", "Score"])
    else:
        df = pd.DataFrame(columns=["Rank", "Artist", "", "Score"])
    # Get the actual device/provider info from the inference object
    if hasattr(inference, "execution_provider"):
        device_info = inference.execution_provider
    else:
        device_info = inference.device
    # Report the backend, device, and elapsed time (4 decimal places)
    time_taken_str = f"- **{MODELS[model_selection]['type']}:** {device_info} | **Time taken:** {inference_time:.4f}s"
    return tags_output, df, json.dumps(json_output, indent=4), time_taken_str
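# Hypothetical direct call outside the UI, assuming a PIL image `img` is in scope;
# returns (tags string, display DataFrame, JSON string, timing markdown):
# tags, table, js_out, timing = predict(img, "Kaloscope v2.0 ONNX", top_k=5, threshold=0.1)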
# --- Gradio Interface ---
with gr.Blocks(
    css="""
    * { box-sizing: border-box; }
    @media (max-width: 1022px) {
        #slider-row-container {
            flex-direction: column !important;
        }
        #slider-row-container > * {
            width: 100% !important;
        }
        #slider-row-container .block {
            width: 100% !important;
        }
    }
    #image-upload {
        max-height: 80vh !important;
        overflow: hidden !important;
        display: flex !important;
        flex-direction: column !important;
    }
    #image-upload .image-container {
        flex: 1 1 auto !important;
        min-height: 0 !important;
        overflow: hidden !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    #image-upload img {
        max-height: 75vh !important;
        max-width: 100% !important;
        width: auto !important;
        height: auto !important;
        object-fit: contain !important;
    }
    #results-table-wrapper {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper .table-wrap,
    #results-table-wrapper .dataframe-wrap {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper table {
        width: 100% !important;
        table-layout: fixed !important;
        border-collapse: collapse !important;
    }
    #results-table-wrapper td,
    #results-table-wrapper th {
        overflow: hidden !important;
        text-overflow: ellipsis !important;
        white-space: nowrap !important;
    }
    #results-table-wrapper td:nth-child(1),
    #results-table-wrapper th:nth-child(1) {
        width: 55px !important;
    }
    #results-table-wrapper td:nth-child(2),
    #results-table-wrapper th:nth-child(2) {
        width: auto !important;
    }
    #results-table-wrapper td:nth-child(3),
    #results-table-wrapper th:nth-child(3) {
        width: 50px !important;
        border-left: none !important;
        text-align: center !important;
        padding: 0 !important;
    }
    #results-table-wrapper td:nth-child(4),
    #results-table-wrapper th:nth-child(4) {
        width: 69px !important;
    }
    #results-table-wrapper th:nth-child(3) {
        background: transparent !important;
    }
    #results-table-wrapper .copy-btn {
        cursor: pointer;
        width: 100%;
        height: 100%;
        display: flex;
        align-items: center;
        justify-content: center;
        user-select: none;
    }
    #results-table-wrapper .copy-btn:hover {
        background: rgba(128, 128, 128, 0.1);
    }
    """,
js="""
function() {
document.addEventListener('click', function(e) {
if (e.target.classList.contains('copy-btn')) {
const text = e.target.getAttribute('data-copy');
navigator.clipboard.writeText(text).then(() => {
const original = e.target.textContent;
e.target.textContent = '✓';
setTimeout(() => { e.target.textContent = original; }, 1000);
});
}
});
// Fix right-click on links in DataFrame
document.addEventListener('contextmenu', function(e) {
// Check if the clicked element or its parent is a link inside the results table
const link = e.target.closest('#results-table-wrapper a');
if (link) {
e.stopPropagation();
// Let the browser's native context menu show for the link
return true;
}
}, true);
// Prevent Gradio's default click handling on links
document.addEventListener('click', function(e) {
const link = e.target.closest('#results-table-wrapper a');
if (link && e.button === 0) {
e.stopPropagation();
// Allow normal link behavior (open in new tab due to markdown)
return true;
}
}, true);
}
""",
) as demo:
gr.Markdown("# Kaloscope Artist Style Classification")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload")
with gr.Column():
submit_btn = gr.Button("Predict")
tags_output = gr.Textbox(label="Predicted Tags", show_copy_button=True)
prettier_output = gr.DataFrame(
elem_id="results-table-wrapper",
# value=[
# [
# 1,
# "[Samplaaaae Artist](https://example.com)",
# "<span class='copy-btn' data-copy='Samplaaaae Artist'>📋</span>",
# "95.00%",
# ],
# [
# 2,
# "[Another Artist](https://example.com)",
# "<span class='copy-btn' data-copy='Another Artist'>📋</span>",
# "90.00%",
# ],
# [
# 3,
# "[Third Artist](https://example.com)",
# "<span class='copy-btn' data-copy='Third Artist'>📋</span>",
# "85.00%",
# ],
# ],
interactive=False,
datatype=["number", "markdown", "html", "str"],
headers=["Rank", "Artist", "", "Score"],
)
json_accordion = gr.Accordion("JSON Output", open=False)
with json_accordion:
json_output = gr.Code(language="json", show_label=False, lines=7)
with gr.Group():
model_selection = gr.Dropdown(
choices=[
(
f"{name}",
# f"{name} | Repo: {MODELS[name].get('repo_id') or 'local'}",
name,
)
for name in MODELS
],
value=list(MODELS.keys())[0],
label="Select Model",
)
with gr.Row(elem_id="slider-row-container"):
top_k_slider = gr.Slider(minimum=1, maximum=25, value=5, step=1, label="Top K")
threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Threshold")
time_display = gr.Markdown() # populated after prediction
gr.Markdown(
"Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) & [heathcliff01/Kaloscope2.0](https://huggingface.co/heathcliff01/Kaloscope2.0) (Original PyTorch releases) "
+ "and [DraconicDragon/Kaloscope-onnx](https://huggingface.co/DraconicDragon/Kaloscope-onnx) (ONNX converted and EMA weights). \n"
+ "OpenVINO™ will be used to accelerate ONNX CPU inference with ONNX CPUExecutionProvider as fallback."
)
submit_btn.click(
fn=predict,
inputs=[image_input, model_selection, top_k_slider, threshold_slider],
outputs=[tags_output, prettier_output, json_output, time_display],
)
if __name__ == "__main__":
    demo.launch()