TangYiJay commited on
Commit
8ee0f62
·
verified ·
1 Parent(s): e4edfab
Files changed (1) hide show
  1. app.py +66 -23
app.py CHANGED
@@ -1,41 +1,84 @@
 
1
  import gradio as gr
2
- from transformers import AutoModelForImageClassification, AutoImageProcessor
3
  from PIL import Image
4
  import torch
 
5
 
6
- # 三个模型
7
- MODELS = [
8
  "yangy50/garbage-classification",
9
- "ahmzakif/TrashNet-Classification",
10
- "harriskr14/trashnet-vit"
11
  ]
12
 
13
- processors = []
14
  models = []
15
- for name in MODELS:
16
- p = AutoImageProcessor.from_pretrained(name)
17
- m = AutoModelForImageClassification.from_pretrained(name)
18
- m.eval()
19
- processors.append(p)
20
- models.append(m)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def classify_image(image: Image.Image):
23
  results = {}
24
- for name, p, m in zip(MODELS, processors, models):
25
- inputs = p(images=image, return_tensors="pt")
26
- with torch.no_grad():
27
- outputs = m(**inputs)
28
- pred = outputs.logits.argmax(-1).item()
29
- label = m.config.id2label.get(pred, f"id_{pred}")
30
- results[name] = label
31
- return "\n".join([f"{k}: {v}" for k, v in results.items()])
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
33
  iface = gr.Interface(
34
  fn=classify_image,
35
  inputs=gr.Image(type="pil", label="上传图片"),
36
- outputs="text",
37
- title="三模型垃圾分类",
38
- description="使用三个模型独立预测垃圾种类"
39
  )
40
 
41
  if __name__ == "__main__":
 
1
+ # app.py
2
  import gradio as gr
3
+ from transformers import AutoModelForImageClassification, AutoImageProcessor, AutoFeatureExtractor
4
  from PIL import Image
5
  import torch
6
+ from torchvision import transforms
7
 
8
+ # ---------------- 配置模型 ----------------
9
+ MODEL_LIST = [
10
  "yangy50/garbage-classification",
11
+ "harriskr14/trashnet-vit",
12
+ "ahmzakif/TrashNet-Classification"
13
  ]
14
 
15
+ # ---------------- 加载模型 ----------------
16
  models = []
17
+ processors = []
18
+ loaded_model_names = []
19
+
20
+ print("🔹 正在加载模型,请稍等...")
 
 
21
 
22
+ for name in MODEL_LIST:
23
+ try:
24
+ model = AutoModelForImageClassification.from_pretrained(name)
25
+ # 尝试加载 AutoImageProcessor
26
+ try:
27
+ processor = AutoImageProcessor.from_pretrained(name)
28
+ except OSError:
29
+ # 尝试 AutoFeatureExtractor
30
+ try:
31
+ processor = AutoFeatureExtractor.from_pretrained(name)
32
+ except OSError:
33
+ processor = None
34
+ print(f"⚠️ 模型 {name} 没有自带处理器,将使用默认 transforms")
35
+ model.eval()
36
+ models.append(model)
37
+ processors.append(processor)
38
+ loaded_model_names.append(name)
39
+ print(f"✅ 加载成功: {name}")
40
+ except Exception as e:
41
+ print(f"❌ 加载失败: {name}, 错误: {e}")
42
+
43
+ # ---------------- 默认图片处理 ----------------
44
+ default_preprocess = transforms.Compose([
45
+ transforms.Resize((224, 224)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
48
+ std=[0.229, 0.224, 0.225])
49
+ ])
50
+
51
+ # ---------------- 推理函数 ----------------
52
  def classify_image(image: Image.Image):
53
  results = {}
54
+ for name, model, processor in zip(loaded_model_names, models, processors):
55
+ try:
56
+ if processor:
57
+ inputs = processor(images=image, return_tensors="pt")
58
+ else:
59
+ inputs = default_preprocess(image).unsqueeze(0)
60
+ with torch.no_grad():
61
+ outputs = model(**inputs)
62
+ pred_idx = outputs.logits.argmax(-1).item()
63
+ if hasattr(model.config, "id2label"):
64
+ label = model.config.id2label[pred_idx]
65
+ else:
66
+ label = f"索引预测: {pred_idx}"
67
+ results[name] = label
68
+ except Exception as e:
69
+ results[name] = f"❌ 预测失败: {e}"
70
+
71
+ # 格式化输出
72
+ results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
73
+ return results_text
74
 
75
+ # ---------------- Gradio 界面 ----------------
76
  iface = gr.Interface(
77
  fn=classify_image,
78
  inputs=gr.Image(type="pil", label="上传图片"),
79
+ outputs=[gr.Textbox(label="模型预测结果")],
80
+ title="垃圾分类模型检测",
81
+ description="上传图片,每个模型独立预测,输出所有结果。"
82
  )
83
 
84
  if __name__ == "__main__":