TangYiJay commited on
Commit
94a8d5e
·
verified ·
1 Parent(s): 7180ec9
Files changed (1) hide show
  1. app.py +15 -26
app.py CHANGED
@@ -4,65 +4,54 @@ from transformers import AutoModelForImageClassification, AutoImageProcessor
4
  from PIL import Image
5
  import torch
6
 
7
- # ---------------- 配置模型列表 ----------------
8
  MODEL_LIST = [
9
- "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
11
- "eunoiawiira-vgg-realwaste-classification",
12
- "ahmzakif/TrashNet-Classification",
13
- "ee8225-group4-vit-trashnet-enhanced",
14
- "harriskr14/trashnet-vit"
15
  ]
16
 
17
  # ---------------- 加载模型 ----------------
18
  models = []
19
  processors = []
20
- loaded_model_names = []
21
- print("🔹 正在加载模型,请稍等...")
22
 
 
23
  for model_name in MODEL_LIST:
24
  try:
25
  processor = AutoImageProcessor.from_pretrained(model_name)
26
  model = AutoModelForImageClassification.from_pretrained(model_name)
27
  model.eval()
28
- # 只加载模型自带 id2label,不手动干预
29
- if not hasattr(model.config, "id2label"):
30
- print(f"⚠️ 模型 {model_name} 没有内置 id2label,预测可能失败")
31
  processors.append(processor)
32
  models.append(model)
33
- loaded_model_names.append(model_name)
34
  print(f"✅ 加载成功: {model_name}")
35
  except Exception as e:
36
  print(f"❌ 加载失败: {model_name}, 错误: {e}")
37
 
38
  # ---------------- 推理函数 ----------------
39
  def classify_image(image: Image.Image):
40
- results = {}
41
- for model_name, processor, model in zip(loaded_model_names, processors, models):
42
  try:
43
  inputs = processor(images=image, return_tensors="pt")
44
  with torch.no_grad():
45
  outputs = model(**inputs)
46
  pred = outputs.logits.argmax(-1).item()
47
- if hasattr(model.config, "id2label"):
48
- label = model.config.id2label[pred]
49
- else:
50
- label = f"⚠️ 无内置 id2label,索引预测: {pred}"
51
- results[model_name] = label
52
  except Exception as e:
53
- results[model_name] = f"❌ 预测失败: {e}"
54
 
55
- # 格式化输出
56
- results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
57
- return results_text
58
 
59
  # ---------------- Gradio 界面 ----------------
60
  iface = gr.Interface(
61
  fn=classify_image,
62
  inputs=gr.Image(type="pil", label="上传图片"),
63
- outputs=[gr.Textbox(label="所有模型预测结果")],
64
- title="垃圾分类全模型检测",
65
- description="上传图片后,每个模型独立输出预测结果,不做任何人工干预。"
66
  )
67
 
68
  if __name__ == "__main__":
 
4
  from PIL import Image
5
  import torch
6
 
7
+ # ---------------- 模型列表 ----------------
8
  MODEL_LIST = [
 
9
  "yangy50/garbage-classification",
10
+ "harriskr14/trashnet-vit",
11
+ "prithivMLmods/Trash-Net"
 
 
12
  ]
13
 
14
  # ---------------- 加载模型 ----------------
15
  models = []
16
  processors = []
17
+ loaded_names = []
 
18
 
19
+ print("🔹 正在加载模型,请稍等...")
20
  for model_name in MODEL_LIST:
21
  try:
22
  processor = AutoImageProcessor.from_pretrained(model_name)
23
  model = AutoModelForImageClassification.from_pretrained(model_name)
24
  model.eval()
 
 
 
25
  processors.append(processor)
26
  models.append(model)
27
+ loaded_names.append(model_name)
28
  print(f"✅ 加载成功: {model_name}")
29
  except Exception as e:
30
  print(f"❌ 加载失败: {model_name}, 错误: {e}")
31
 
32
  # ---------------- 推理函数 ----------------
33
  def classify_image(image: Image.Image):
34
+ results = []
35
+ for name, processor, model in zip(loaded_names, processors, models):
36
  try:
37
  inputs = processor(images=image, return_tensors="pt")
38
  with torch.no_grad():
39
  outputs = model(**inputs)
40
  pred = outputs.logits.argmax(-1).item()
41
+ label = model.config.id2label.get(pred, str(pred))
42
+ results.append(f"{name}: {label}")
 
 
 
43
  except Exception as e:
44
+ results.append(f"{name}: 预测失败 ({e})")
45
 
46
+ return "\n".join(results)
 
 
47
 
48
  # ---------------- Gradio 界面 ----------------
49
  iface = gr.Interface(
50
  fn=classify_image,
51
  inputs=gr.Image(type="pil", label="上传图片"),
52
+ outputs=gr.Textbox(label="模型预测结果", lines=6),
53
+ title="多模型垃圾分类",
54
+ description="使用以下模型进行独立预测:yangy50、harriskr14、prithivMLmods。"
55
  )
56
 
57
  if __name__ == "__main__":