TangYiJay commited on
Commit
4ea19be
·
verified ·
1 Parent(s): 46f7c46
Files changed (1) hide show
  1. app.py +41 -18
app.py CHANGED
@@ -1,41 +1,64 @@
 
1
  import gradio as gr
2
  from transformers import AutoModelForImageClassification, AutoImageProcessor
3
  from PIL import Image
4
  import torch
5
 
6
- # 三个模型
7
- MODELS = [
 
8
  "yangy50/garbage-classification",
 
9
  "ahmzakif/TrashNet-Classification",
 
10
  "harriskr14/trashnet-vit"
11
  ]
12
 
13
- processors = []
14
  models = []
15
- for name in MODELS:
16
- p = AutoImageProcessor.from_pretrained(name)
17
- m = AutoModelForImageClassification.from_pretrained(name)
18
- m.eval()
19
- processors.append(p)
20
- models.append(m)
 
 
 
 
 
 
21
 
 
22
  def classify_image(image: Image.Image):
23
  results = {}
24
- for name, p, m in zip(MODELS, processors, models):
25
- inputs = p(images=image, return_tensors="pt")
 
26
  with torch.no_grad():
27
- outputs = m(**inputs)
28
  pred = outputs.logits.argmax(-1).item()
29
- label = m.config.id2label.get(pred, f"id_{pred}")
30
- results[name] = label
31
- return "\n".join([f"{k}: {v}" for k, v in results.items()])
 
 
 
 
32
 
 
33
  iface = gr.Interface(
34
  fn=classify_image,
35
  inputs=gr.Image(type="pil", label="上传图片"),
36
- outputs="text",
37
- title="三模型垃圾分类",
38
- description="使用三个模型独立预测垃圾种类"
 
 
 
 
 
 
39
  )
40
 
41
  if __name__ == "__main__":
 
1
+ # app.py
2
  import gradio as gr
3
  from transformers import AutoModelForImageClassification, AutoImageProcessor
4
  from PIL import Image
5
  import torch
6
 
7
+ # ---------------- 配置模型列表,包含你提到的全部模型 ----------------
8
+ MODEL_LIST = [
9
+ "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
11
+ "eunoiawiira-vgg-realwaste-classification",
12
  "ahmzakif/TrashNet-Classification",
13
+ "ee8225-group4-vit-trashnet-enhanced",
14
  "harriskr14/trashnet-vit"
15
  ]
16
 
17
+ # ---------------- 加载模型 ----------------
18
  models = []
19
+ processors = []
20
+ print("🔹 正在加载模型,请稍等...")
21
+ for model_name in MODEL_LIST:
22
+ try:
23
+ processor = AutoImageProcessor.from_pretrained(model_name)
24
+ model = AutoModelForImageClassification.from_pretrained(model_name)
25
+ model.eval()
26
+ processors.append(processor)
27
+ models.append(model)
28
+ print(f"✅ 加载完成: {model_name}")
29
+ except Exception as e:
30
+ print(f"❌ 加载失败: {model_name}, 错误: {e}")
31
 
32
+ # ---------------- 推理函数 ----------------
33
  def classify_image(image: Image.Image):
34
  results = {}
35
+ # 遍历所有模型进行预测
36
+ for model_name, processor, model in zip(MODEL_LIST, processors, models):
37
+ inputs = processor(images=image, return_tensors="pt")
38
  with torch.no_grad():
39
+ outputs = model(**inputs)
40
  pred = outputs.logits.argmax(-1).item()
41
+ label = model.config.id2label[pred]
42
+ results[model_name] = label
43
+
44
+ # 格式化输出,每个模型的结果单独显示
45
+ results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
46
+
47
+ return results_text
48
 
49
+ # ---------------- Gradio 界面 ----------------
50
  iface = gr.Interface(
51
  fn=classify_image,
52
  inputs=gr.Image(type="pil", label="上传图片"),
53
+ outputs=[gr.Textbox(label="所有模型预测结果")],
54
+ title="垃圾分类多模型检测",
55
+ description="上传图片后,使用以下模型进行垃圾分类,每个模型结果单独输出:\n"
56
+ "1. prithivMLmods/Trash-Net\n"
57
+ "2. yangy50/garbage-classification\n"
58
+ "3. eunoiawiira-vgg-realwaste-classification\n"
59
+ "4. ahmzakif/TrashNet-Classification\n"
60
+ "5. ee8225-group4-vit-trashnet-enhanced\n"
61
+ "6. harriskr14/trashnet-vit"
62
  )
63
 
64
  if __name__ == "__main__":