# Source: trashnet / app.py by TangYiJay, revision 5c030c4 (verified).
# app.py
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch
# Hugging Face model ids, queried in this order when classifying an image.
MODEL_LIST = [
    "prithivMLmods/Trash-Net",
    "yangy50/garbage-classification",
]

# Parallel lists indexed exactly like MODEL_LIST: entry i belongs to
# MODEL_LIST[i]. A model that fails to load is stored as None in all three
# lists, so one failure never shifts later models onto the wrong name when
# the lists are zipped together with MODEL_LIST.
models = []
processors = []
devices = []

_use_cuda = torch.cuda.is_available()

print("Loading models...")
for model_name in MODEL_LIST:
    try:
        processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForImageClassification.from_pretrained(
            model_name,
            # Half precision only on GPU; CPU inference stays in float32.
            torch_dtype=torch.float16 if _use_cuda else torch.float32,
            device_map="auto" if _use_cuda else None,
        )
        model.eval()
        # Device of the first parameter; inference inputs are moved here.
        device = next(model.parameters()).device
    except Exception as e:
        # Best-effort loading: a broken model must not stop the app from
        # starting. Its slots hold None so the lists stay aligned.
        print(f"Failed to load {model_name}: {e}")
        processor = model = device = None
    else:
        print(f"Loaded: {model_name}")
    processors.append(processor)
    models.append(model)
    devices.append(device)
def classify_image(image: Image.Image):
    """Classify *image* with every model in MODEL_LIST and format the results.

    Each model's top-1 label is reported on its own line as
    ``"<model name>: <label>"``. A model that failed to load, or that raises
    during inference, is reported as ``"error:<message>"`` instead of
    aborting the whole request. A final line repeats the prediction of
    ``yangy50/garbage-classification`` (or ``"Unknown"`` if it has no entry).

    Args:
        image: The uploaded PIL image. Gradio passes ``None`` when the user
            submits without uploading anything.

    Returns:
        A multi-line string for the output textbox.
    """
    if image is None:
        return "Please upload an image."

    # The processors expect 3-channel input; RGBA, grayscale or palette
    # uploads would typically fail normalization, so convert up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    results = {}
    for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
        if model is None or processor is None:
            results[model_name] = "error:model failed to load"
            continue
        try:
            inputs = processor(images=image, return_tensors="pt")
            # Move tensors to the model's device. Floating-point tensors
            # (pixel values) must also match the model's dtype, which is
            # float16 when the models were loaded on CUDA.
            inputs = {
                k: v.to(device=device, dtype=model.dtype) if v.is_floating_point() else v.to(device)
                for k, v in inputs.items()
            }
            with torch.no_grad():
                outputs = model(**inputs)
            pred = outputs.logits.argmax(-1).item()
            results[model_name] = model.config.id2label.get(pred, str(pred))
        except Exception as e:
            # Per-model failure is reported in the output rather than raised.
            results[model_name] = f"error:{e}"

    # One line per model with its prediction.
    results_text = "\n".join(f"{name}: {label}" for name, label in results.items())
    # yangy50/garbage-classification is treated as the authoritative model.
    final_label = results.get("yangy50/garbage-classification", "Unknown")
    results_text += f"\n\nFinal Label (base yangy50): {final_label}"
    return results_text
# Gradio UI: one uploaded image in, one textbox of per-model predictions out.
_image_input = gr.Image(type="pil", label="Upload Image")
_predictions_output = gr.Textbox(label="Model Predictions")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=[_predictions_output],
    title="Trash Classification",
    description="\n".join([
        "Upload an image, and the following models will classify it:",
        "1. prithivMLmods/Trash-Net",
        "2. yangy50/garbage-classification",
        "The final label is based on yangy50/garbage-classification.",
    ]),
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()