File size: 2,344 Bytes
4ea19be c2ca069 04b114a 0d60bc4 dd30521 4ea19be 5c030c4 91c176f 46f7c46 4ea19be f768864 606734b 91c176f 4ea19be f768864 4ea19be f768864 91c176f 4ea19be 637a6ad 8ee0f62 0d60bc4 e4edfab f768864 46f7c46 4ea19be f768864 637a6ad 4ea19be 5c030c4 4ea19be 5c030c4 4ea19be 0d60bc4 91c176f 5c030c4 f768864 637a6ad 5c030c4 f768864 0d60bc4 46f7c46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# app.py
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch
# Model identifiers handed to ``from_pretrained``; predictions are reported in
# this order. "yangy50/garbage-classification" supplies the final label.
MODEL_LIST = [
    "prithivMLmods/Trash-Net",
    "yangy50/garbage-classification",
]
# Parallel lists kept index-aligned with MODEL_LIST: entry i of each list
# belongs to MODEL_LIST[i]. A model that fails to load is recorded as None in
# all three lists so later entries are never shifted onto the wrong name.
models = []
processors = []
devices = []
print("Loading models...")
# Half precision and automatic device placement are only used with CUDA.
_use_cuda = torch.cuda.is_available()
for model_name in MODEL_LIST:
    try:
        processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForImageClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if _use_cuda else torch.float32,
            device_map="auto" if _use_cuda else None,
        )
        model.eval()
        # Device of the first parameter; inference inputs are moved there.
        # Computed before any append so a failure here cannot leave the
        # three lists with different lengths.
        device = next(model.parameters()).device
    except Exception as e:
        # Best-effort loading: keep the app usable with the remaining models,
        # but record placeholders so classify_image (which zips MODEL_LIST
        # with these lists) still pairs every name with its own model.
        print(f"Failed to load {model_name}: {e}")
        processor, model, device = None, None, None
    else:
        print(f"Loaded: {model_name}")
    processors.append(processor)
    models.append(model)
    devices.append(device)
def classify_image(image: Image.Image):
    """Classify *image* with every configured model and format the results.

    Walks ``MODEL_LIST`` in lockstep with the module-level ``processors``,
    ``models`` and ``devices`` lists, so those lists must be index-aligned
    with ``MODEL_LIST``. An entry whose model or processor is ``None`` is
    reported as not loaded instead of being run.

    Args:
        image: The picture to classify. May be ``None`` if nothing was
            uploaded, in which case no model is run.

    Returns:
        One ``"<model name>: <label>"`` line per model, then a blank line and
        the final label taken from ``yangy50/garbage-classification``
        (``"Unknown"`` if that model has no entry). A model that raises during
        inference contributes ``"error:<message>"`` as its label.
    """
    if image is None:
        return "No image provided. Please upload an image."

    results = {}
    for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
        if model is None or processor is None:
            results[model_name] = "error:model not loaded"
            continue
        try:
            inputs = processor(images=image, return_tensors="pt")
            # Models may be loaded in float16 on CUDA while the processor emits
            # float32 pixel values; cast floating tensors to the model's dtype
            # so the first layer does not fail with a dtype mismatch.
            inputs = {
                k: (
                    v.to(device=device, dtype=model.dtype)
                    if v.is_floating_point()
                    else v.to(device)
                )
                for k, v in inputs.items()
            }
            with torch.no_grad():
                outputs = model(**inputs)
            pred = outputs.logits.argmax(-1).item()
            results[model_name] = model.config.id2label[pred]
        except Exception as e:
            # Per-model best effort: one failing model must not hide the others.
            results[model_name] = f"error:{e}"

    # Output each model's prediction on its own line.
    results_text = "\n".join(f"{name}: {label}" for name, label in results.items())
    # The yangy50/garbage-classification prediction is used as the final result.
    final_label = results.get("yangy50/garbage-classification", "Unknown")
    results_text += f"\n\nFinal Label (base yangy50): {final_label}"
    return results_text
# Text shown under the title of the web UI; lists the models in MODEL_LIST
# order and names the one whose prediction becomes the final label.
_DESCRIPTION = "\n".join(
    [
        "Upload an image, and the following models will classify it:",
        "1. prithivMLmods/Trash-Net",
        "2. yangy50/garbage-classification",
        "The final label is based on yangy50/garbage-classification.",
    ]
)

# One PIL image in, one text box with every model's prediction out.
_image_input = gr.Image(type="pil", label="Upload Image")
_prediction_output = gr.Textbox(label="Model Predictions")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=[_prediction_output],
    title="Trash Classification",
    description=_DESCRIPTION,
)


def _main():
    """Start the Gradio app defined by ``iface``."""
    iface.launch()


if __name__ == "__main__":
    _main()
|