File size: 2,344 Bytes
4ea19be
c2ca069
04b114a
0d60bc4
 
dd30521
4ea19be
 
5c030c4
91c176f
 
46f7c46
4ea19be
f768864
606734b
91c176f
4ea19be
 
 
f768864
 
 
 
 
4ea19be
 
 
f768864
91c176f
4ea19be
637a6ad
8ee0f62
0d60bc4
e4edfab
f768864
 
 
 
 
 
46f7c46
4ea19be
 
f768864
637a6ad
4ea19be
5c030c4
4ea19be
 
5c030c4
 
 
4ea19be
0d60bc4
 
 
91c176f
 
5c030c4
f768864
637a6ad
 
 
5c030c4
f768864
0d60bc4
 
 
46f7c46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# app.py
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch

# Hugging Face Hub ids of the image classifiers to run on every upload.
# Order matters: the per-model lists built at load time are paired with
# these names by position.
MODEL_LIST = ["prithivMLmods/Trash-Net", "yangy50/garbage-classification"]

# Parallel lists, kept index-aligned with MODEL_LIST: entry i of each list
# belongs to MODEL_LIST[i]. A model that fails to load gets a None
# placeholder in all three lists so the remaining models stay paired with
# their own names when the lists are zipped together.
models = []
processors = []
devices = []

# The CUDA check does not change between iterations, so evaluate it once.
use_cuda = torch.cuda.is_available()

print("Loading models...")
for model_name in MODEL_LIST:
    try:
        processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForImageClassification.from_pretrained(
            model_name,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        model.eval()
        # Remember where the weights live so inputs can be moved there.
        device = next(model.parameters()).device
    except Exception as e:
        # Best-effort startup: one broken model must not prevent the others
        # from being served, so report the failure and keep going.
        print(f"Failed to load {model_name}: {e}")
        processor = model = device = None
    else:
        print(f"Loaded: {model_name}")
    processors.append(processor)
    models.append(model)
    devices.append(device)

def classify_image(image: Image.Image):
    """Classify an image with every configured model and report the labels.

    Args:
        image: The PIL image to classify, or ``None`` when no image was
            supplied.

    Returns:
        A text report with one ``<model name>: <label>`` line per model,
        followed by a "Final Label" line taken from the
        ``yangy50/garbage-classification`` entry ("Unknown" when that model
        has no entry). A failure in one model is reported inline as
        ``error:<message>`` rather than raised. When ``image`` is ``None`` a
        short message asking for an image is returned instead.
    """
    if image is None:
        return "No image provided. Please upload an image first."

    results = {}
    for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
        # A missing processor or model cannot run inference; report it the
        # same way as any other per-model failure.
        if processor is None or model is None:
            results[model_name] = "error:model is not loaded"
            continue
        try:
            inputs = processor(images=image, return_tensors="pt")
            # The weights may be half precision while the processor emits
            # float32; cast floating-point inputs to the weights' dtype so
            # the forward pass does not fail on a dtype mismatch.
            model_dtype = next(model.parameters()).dtype
            inputs = {
                key: (
                    value.to(device=device, dtype=model_dtype)
                    if value.is_floating_point()
                    else value.to(device)
                )
                for key, value in inputs.items()
            }
            with torch.no_grad():
                outputs = model(**inputs)
            pred = outputs.logits.argmax(-1).item()
            results[model_name] = model.config.id2label[pred]
        except Exception as e:
            # Keep going so the remaining models still produce a result.
            results[model_name] = f"error:{e}"

    # List the result of every model.
    results_text = "\n".join(f"{name}: {label}" for name, label in results.items())

    # The yangy50/garbage-classification prediction is used as the final result.
    final_label = results.get("yangy50/garbage-classification", "Unknown")
    results_text += f"\n\nFinal Label (base yangy50): {final_label}"
    return results_text

# Gradio UI: one image upload in, one text box out.
image_input = gr.Image(type="pil", label="Upload Image")
prediction_box = gr.Textbox(label="Model Predictions")

# Text shown under the page title.
app_description = (
    "Upload an image, and the following models will classify it:\n"
    "1. prithivMLmods/Trash-Net\n"
    "2. yangy50/garbage-classification\n"
    "The final label is based on yangy50/garbage-classification."
)

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=[prediction_box],
    title="Trash Classification",
    description=app_description,
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    iface.launch()