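"""Gradio app for Kaloscope artist-style classification.

Loads Kaloscope checkpoints (PyTorch or ONNX) from Hugging Face Hub and tags
uploaded images with the most likely artist styles.
"""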
import json
import os
import time

import gradio as gr
import numpy as np
import pandas as pd
import torch

from hf_repo_utils import auto_register_kaloscope_variants, validate_models
from inference_onnx import ONNXInference, softmax
from inference_pytorch import PyTorchInference

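# lsnet.lsnet_artist supplies the lsnet_xl_artist* architectures referenced by the
# "arch" entries in MODELS below; fail early with a clear error if it is missing.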
try:
    import lsnet.lsnet_artist
except ImportError as e:
    print(f"Error: {e}")
    raise gr.Error("Could not import lsnet.lsnet_artist. Please ensure the lsnet folder is in your workspace.")

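# Registry of available checkpoints. Each entry gives the inference backend
# ("onnx" or "pytorch"), an optional local "path", the Hugging Face repo/file
# to download from, and the model architecture name.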
MODELS = {
    "Kaloscope v2.0 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "subfolder": "v2.0",
        "filename": "kaloscope_2-0.onnx",
        "arch": "lsnet_xl_artist_448",
    },
    "Kaloscope v2.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope2.0",
        "subfolder": "448-90.13",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist_448",
    },
    "Kaloscope v1.1 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-1.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.1": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "subfolder": "224-85.65",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-0.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ema": {
        "type": "pytorch",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "best_checkpoint_ema.pth",
        "arch": "lsnet_xl_artist",
    },
}

MODELS = validate_models(MODELS)
auto_register_kaloscope_variants(MODELS)

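# Default filename of the class-id -> artist-name mapping CSV. It is fetched from
# the model's own repo unless the model config overrides csv_repo_id/csv_filename.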
CSV_FILENAME = "class_mapping.csv"

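# Prefer CUDA when available; fall back to CPU if the CUDA check itself fails.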
try:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except Exception:
    DEVICE = "cpu"


def load_labels(csv_path):
    """
    Loads the class labels from the provided CSV file into a dictionary.
    """
    try:
        df = pd.read_csv(csv_path)
        if "class_id" not in df.columns or "class_name" not in df.columns:
            raise gr.Error("CSV file must have 'class_id' and 'class_name' columns.")
        df["class_name"] = df["class_name"].str.strip("'")
        return dict(zip(df["class_id"], df["class_name"]))
    except FileNotFoundError:
        raise gr.Error(f"CSV file not found at '{csv_path}'")
    except Exception as e:
        raise gr.Error(f"Error reading CSV file: {e}")


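# Per-model caches so repeated predictions do not reload checkpoints or label CSVs.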
labels_cache = {}
model_cache = {}


def get_model_inference(model_name):
    """
    Get or create inference object for the specified model.
    Uses caching to avoid reloading models. Also downloads the model's class mapping CSV
    from the same repo as the model unless explicitly overridden in the model config.
    """
    if model_name not in model_cache:
        if model_name not in MODELS:
            raise gr.Error(f"Unknown model: {model_name}")

        config = MODELS[model_name]
        model_path = config.get("path") or ""

        print(f"Loading model: {model_name} ({config.get('filename', '<no filename>')})")

        if not model_path or not os.path.exists(model_path):
            if "repo_id" in config and "filename" in config:
                target_display = model_path or f"repo {config['repo_id']}"
                print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...")
                try:
                    from huggingface_hub import hf_hub_download

                    download_kwargs = {
                        "repo_id": config["repo_id"],
                        "filename": config["filename"],
                    }
                    if config.get("subfolder"):
                        download_kwargs["subfolder"] = config["subfolder"]
                    if config.get("revision"):
                        download_kwargs["revision"] = config["revision"]

                    model_path = hf_hub_download(**download_kwargs)
                    config["path"] = model_path
                    print(f"Downloaded model to: {model_path}")
                except Exception as e:
                    raise gr.Error(f"Could not load model from local path or Hugging Face: {e}")
            else:
                raise gr.Error(f"Model file not found at: {model_path}")

        if config["type"] == "onnx":
            model_cache[model_name] = ONNXInference(model_path=model_path, model_arch=config["arch"], device=DEVICE)
        elif config["type"] == "pytorch":
            model_cache[model_name] = PyTorchInference(
                checkpoint_path=model_path, model_arch=config["arch"], device=DEVICE
            )
        else:
            raise gr.Error(f"Unknown model type: {config['type']}")

        print(f"Model {model_name} loaded successfully ({config.get('filename', '<no filename>')})")

        try:
            csv_repo = config.get("csv_repo_id") or config.get("repo_id")
            csv_filename = config.get("csv_filename") or CSV_FILENAME
            csv_subfolder = config.get("csv_subfolder") or config.get("subfolder")

            if not csv_repo:
                raise gr.Error(f"No repo available to find class mapping CSV for model '{model_name}'")

            print(f"Attempting to download class mapping CSV '{csv_filename}' from repo '{csv_repo}'")
            from huggingface_hub import hf_hub_download

            csv_path = None
            try:
                csv_download_kwargs = {"repo_id": csv_repo, "filename": csv_filename}
                if csv_subfolder:
                    csv_download_kwargs["subfolder"] = csv_subfolder

                csv_path = hf_hub_download(**csv_download_kwargs)
                print(f"Downloaded CSV to: {csv_path}")
            except Exception as subfolder_error:
                print(f"CSV not found in subfolder '{csv_subfolder}', trying root folder...")
                try:
                    csv_download_kwargs = {"repo_id": csv_repo, "filename": csv_filename}
                    csv_path = hf_hub_download(**csv_download_kwargs)
                    print(f"Downloaded CSV from root folder to: {csv_path}")
                except Exception as root_error:
                    raise subfolder_error

            labels_cache[model_name] = load_labels(csv_path)
        except Exception as e:
            raise gr.Error(f"Could not load class mapping CSV for model '{model_name}': {e}")

    return model_cache[model_name]


def predict(image, model_selection, top_k, threshold):
    """
    Main prediction function wired to the UI.

    Runs the selected model on the image and returns the predicted tags string,
    a results DataFrame, the raw JSON scores, and a timing/device summary.
    """
    if image is None:
        raise gr.Error("No image provided for prediction.")

    top_k = int(top_k)

    inference = get_model_inference(model_selection)

    labels = labels_cache.get(model_selection)
    if labels is None:
        raise gr.Error(f"Label mapping not found for model '{model_selection}'")

    start_time = time.time()
    logits = inference.predict(image, top_k=top_k, threshold=threshold)
    inference_time = time.time() - start_time

    probabilities = softmax(logits)

    all_indices = np.argsort(probabilities)[::-1]

    tags = []
    json_output = {}
    table_data = []
    predictions_found = 0

    for index in all_indices:
        score = probabilities[index]
        if score >= threshold and predictions_found < top_k:
            class_name = labels.get(index, f"Unknown Class #{index}")
            tags.append(class_name)
            json_output[class_name] = float(score)

            danbooru_url = f"https://danbooru.donmai.us/posts?tags={class_name.replace(' ', '_')}"
            artist_link = f"[{class_name}]({danbooru_url})"

            copy_button = f"<span class='copy-btn' data-copy='{class_name}'>📋</span>"

            table_data.append([predictions_found + 1, artist_link, copy_button, f"{score:.2%}"])
            predictions_found += 1

        if predictions_found >= top_k:
            break

    tags_output = ", ".join(tags)

    if table_data:
        df = pd.DataFrame(table_data, columns=["Rank", "Artist", "", "Score"])
    else:
        df = pd.DataFrame(columns=["Rank", "Artist", "", "Score"])

    if hasattr(inference, "execution_provider"):
        device_info = inference.execution_provider
    else:
        device_info = inference.device

    time_taken_str = f"- **{MODELS[model_selection]['type']}:** {device_info} | **Time taken:** {inference_time:.4f}s"

    return tags_output, df, json.dumps(json_output, indent=4), time_taken_str


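# Gradio UI. The custom CSS keeps the results table compact and scroll-free; the
# custom JS adds a copy-to-clipboard button per artist row and lets links inside
# the DataFrame behave like normal links.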
with gr.Blocks(
    css="""
    * { box-sizing: border-box; }
    @media (max-width: 1022px) {
        #slider-row-container {
            flex-direction: column !important;
        }
        #slider-row-container > * {
            width: 100% !important;
        }
        #slider-row-container .block {
            width: 100% !important;
        }
    }
    #image-upload {
        max-height: 80vh !important;
        overflow: hidden !important;
        display: flex !important;
        flex-direction: column !important;
    }
    #image-upload .image-container {
        flex: 1 1 auto !important;
        min-height: 0 !important;
        overflow: hidden !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    #image-upload img {
        max-height: 75vh !important;
        max-width: 100% !important;
        width: auto !important;
        height: auto !important;
        object-fit: contain !important;
    }
    #results-table-wrapper {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper .table-wrap,
    #results-table-wrapper .dataframe-wrap {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper table {
        width: 100% !important;
        table-layout: fixed !important;
        border-collapse: collapse !important;
    }
    #results-table-wrapper td,
    #results-table-wrapper th {
        overflow: hidden !important;
        text-overflow: ellipsis !important;
        white-space: nowrap !important;
    }
    #results-table-wrapper td:nth-child(1),
    #results-table-wrapper th:nth-child(1) {
        width: 55px !important;
    }
    #results-table-wrapper td:nth-child(2),
    #results-table-wrapper th:nth-child(2) {
        width: auto !important;
    }
    #results-table-wrapper td:nth-child(3),
    #results-table-wrapper th:nth-child(3) {
        width: 50px !important;
        border-left: none !important;
        text-align: center !important;
        padding: 0 !important;
    }
    #results-table-wrapper td:nth-child(4),
    #results-table-wrapper th:nth-child(4) {
        width: 69px !important;
    }
    #results-table-wrapper th:nth-child(3) {
        background: transparent !important;
    }
    #results-table-wrapper .copy-btn {
        cursor: pointer;
        width: 100%;
        height: 100%;
        display: flex;
        align-items: center;
        justify-content: center;
        user-select: none;
    }
    #results-table-wrapper .copy-btn:hover {
        background: rgba(128, 128, 128, 0.1);
    }
    """,
    js="""
    function() {
        document.addEventListener('click', function(e) {
            if (e.target.classList.contains('copy-btn')) {
                const text = e.target.getAttribute('data-copy');
                navigator.clipboard.writeText(text).then(() => {
                    const original = e.target.textContent;
                    e.target.textContent = '✓';
                    setTimeout(() => { e.target.textContent = original; }, 1000);
                });
            }
        });

        // Fix right-click on links in DataFrame
        document.addEventListener('contextmenu', function(e) {
            // Check if the clicked element or its parent is a link inside the results table
            const link = e.target.closest('#results-table-wrapper a');
            if (link) {
                e.stopPropagation();
                // Let the browser's native context menu show for the link
                return true;
            }
        }, true);

        // Prevent Gradio's default click handling on links
        document.addEventListener('click', function(e) {
            const link = e.target.closest('#results-table-wrapper a');
            if (link && e.button === 0) {
                e.stopPropagation();
                // Allow normal link behavior (open in new tab due to markdown)
                return true;
            }
        }, true);
    }
    """,
) as demo:
    gr.Markdown("# Kaloscope Artist Style Classification")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload")

        with gr.Column():
            submit_btn = gr.Button("Predict")

            tags_output = gr.Textbox(label="Predicted Tags", show_copy_button=True)
            prettier_output = gr.DataFrame(
                elem_id="results-table-wrapper",
                interactive=False,
                datatype=["number", "markdown", "html", "str"],
                headers=["Rank", "Artist", "", "Score"],
            )
            json_accordion = gr.Accordion("JSON Output", open=False)
            with json_accordion:
                json_output = gr.Code(language="json", show_label=False, lines=7)

            with gr.Group():
                model_selection = gr.Dropdown(
                    choices=[(f"{name}", name) for name in MODELS],
                    value=list(MODELS.keys())[0],
                    label="Select Model",
                )
                with gr.Row(elem_id="slider-row-container"):
                    top_k_slider = gr.Slider(minimum=1, maximum=25, value=5, step=1, label="Top K")
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Threshold")
            time_display = gr.Markdown()

    gr.Markdown(
        "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) & [heathcliff01/Kaloscope2.0](https://huggingface.co/heathcliff01/Kaloscope2.0) (original PyTorch releases) "
        + "and [DraconicDragon/Kaloscope-onnx](https://huggingface.co/DraconicDragon/Kaloscope-onnx) (ONNX-converted and EMA weights). \n"
        + "OpenVINO™ is used to accelerate ONNX inference on CPU, with the ONNX CPUExecutionProvider as a fallback."
    )

    submit_btn.click(
        fn=predict,
        inputs=[image_input, model_selection, top_k_slider, threshold_slider],
        outputs=[tags_output, prettier_output, json_output, time_display],
    )


if __name__ == "__main__":
    demo.launch()