|
|
|
|
|
import gradio as gr |
|
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
# Hugging Face Hub checkpoints run on every uploaded image, in display order.
MODEL_LIST: list[str] = [
    "prithivMLmods/Trash-Net",
    "yangy50/garbage-classification",
]
|
|
|
|
|
# Per-model runtime state, filled by the loading loop: the loaded models,
# their image processors, and the device each model's parameters live on.
# The loop appends to all three together, so entry i of each list refers to
# the same checkpoint.
models: list = []
processors: list = []
devices: list = []
|
|
|
|
|
print("Loading models...")

# Resolve target device and precision once: float16 only on CUDA, float32 on CPU.
_USE_CUDA = torch.cuda.is_available()
_TARGET_DEVICE = torch.device("cuda" if _USE_CUDA else "cpu")
_TARGET_DTYPE = torch.float16 if _USE_CUDA else torch.float32

for model_name in MODEL_LIST:
    processor = model = device = None
    try:
        processor = AutoImageProcessor.from_pretrained(model_name)
        # Move the model explicitly rather than passing device_map="auto":
        # device_map needs the optional `accelerate` package, and without it
        # every model would fail to load on a CUDA machine.
        model = AutoModelForImageClassification.from_pretrained(
            model_name,
            torch_dtype=_TARGET_DTYPE,
        )
        model.to(_TARGET_DEVICE)
        model.eval()
        device = next(model.parameters()).device
    except Exception as e:
        # Best-effort startup: report the failure and keep serving the
        # models that did load.
        print(f"Failed to load {model_name}: {e}")
        processor = model = device = None
    else:
        print(f"Loaded: {model_name}")

    # Always append (None placeholders on failure) so index i of processors,
    # models and devices stays aligned with MODEL_LIST[i]; classify_image zips
    # these lists with MODEL_LIST to name each prediction.
    processors.append(processor)
    models.append(model)
    devices.append(device)
|
|
|
|
|
def _predict_label(image, processor, model, device):
    """Return the top-1 class label that *model* predicts for *image*.

    Any exception from the processor or model propagates to the caller.
    """
    inputs = processor(images=image, return_tensors="pt")
    # The processor produces float32 tensors, but the model may have been
    # loaded in float16 (on CUDA); cast floating-point inputs to the model's
    # dtype to avoid a dtype mismatch. Non-float tensors are only moved.
    model_dtype = model.dtype
    inputs = {
        key: (
            value.to(device=device, dtype=model_dtype)
            if value.is_floating_point()
            else value.to(device)
        )
        for key, value in inputs.items()
    }
    with torch.no_grad():
        logits = model(**inputs).logits
    pred = logits.argmax(-1).item()
    # Fall back to the raw class index if the config has no name for it.
    return model.config.id2label.get(pred, str(pred))


def classify_image(image: Image.Image):
    """Classify *image* with every loaded model and format the predictions.

    Args:
        image: The uploaded PIL image, or ``None`` when the form is submitted
            without an image.

    Returns:
        A multi-line string with one ``"<model>: <label>"`` line per model
        (``"error:<message>"`` when that model failed), followed by the final
        label taken from ``yangy50/garbage-classification`` ("Unknown" if that
        model produced no result).
    """
    base_model = "yangy50/garbage-classification"

    if image is None:
        return "Please upload an image."

    # Uploads can be RGBA, grayscale or palette images; image processors
    # generally expect 3-channel RGB input, so normalise the mode first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    results = {}
    for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
        if processor is None or model is None:
            # Placeholder for a checkpoint that failed to load at startup.
            results[model_name] = "error:model failed to load"
            continue
        try:
            results[model_name] = _predict_label(image, processor, model, device)
        except Exception as e:
            # UI boundary: report the per-model failure instead of failing the request.
            results[model_name] = f"error:{e}"

    results_text = "\n".join(f"{name}: {label}" for name, label in results.items())
    final_label = results.get(base_model, "Unknown")
    results_text += f"\n\nFinal Label (base yangy50): {final_label}"
    return results_text
|
|
|
|
|
# Gradio UI: one PIL image in, one text box with every model's prediction out.
# The description is Markdown; the model list is generated from MODEL_LIST so
# it cannot drift, and the closing sentence is separated by a blank line so it
# is not rendered as a lazy continuation of the last list item.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[gr.Textbox(label="Model Predictions")],
    title="Trash Classification",
    description=(
        "Upload an image, and the following models will classify it:\n"
        + "".join(f"{i}. {name}\n" for i, name in enumerate(MODEL_LIST, start=1))
        + "\nThe final label is based on yangy50/garbage-classification."
    ),
)
|
|
|
|
|
def main() -> None:
    """Start the Gradio server hosting the trash-classification interface."""
    iface.launch()


if __name__ == "__main__":
    main()
|
|
|