TangYiJay commited on
Commit
f768864
·
verified ·
1 Parent(s): 6f789e1
Files changed (1) hide show
  1. app.py +32 -13
app.py CHANGED
@@ -3,8 +3,9 @@ import gradio as gr
3
  from transformers import AutoModelForImageClassification, AutoImageProcessor
4
  from PIL import Image
5
  import torch
 
6
 
7
- # ---------------- 配置模型列表,包含你提到的全部模型 ----------------
8
  MODEL_LIST = [
9
  "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
@@ -14,14 +15,21 @@ MODEL_LIST = [
14
  # ---------------- 加载模型 ----------------
15
  models = []
16
  processors = []
 
17
  print("🔹 正在加载模型,请稍等...")
 
18
  for model_name in MODEL_LIST:
19
  try:
20
  processor = AutoImageProcessor.from_pretrained(model_name)
21
- model = AutoModelForImageClassification.from_pretrained(model_name)
 
 
 
 
22
  model.eval()
23
  processors.append(processor)
24
  models.append(model)
 
25
  print(f"✅ 加载完成: {model_name}")
26
  except Exception as e:
27
  print(f"❌ 加载失败: {model_name}, 错误: {e}")
@@ -29,30 +37,41 @@ for model_name in MODEL_LIST:
29
  # ---------------- 推理函数 ----------------
30
  def classify_image(image: Image.Image):
31
  results = {}
32
- # 遍历所有模型进行预测
33
- for model_name, processor, model in zip(MODEL_LIST, processors, models):
34
- inputs = processor(images=image, return_tensors="pt")
35
- with torch.no_grad():
36
- outputs = model(**inputs)
 
37
  pred = outputs.logits.argmax(-1).item()
38
  label = model.config.id2label[pred]
39
  results[model_name] = label
 
 
40
 
41
- # 格式化输出,每个模型的结果单独显示
42
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
43
 
 
 
 
 
 
44
  return results_text
45
 
46
  # ---------------- Gradio 界面 ----------------
47
  iface = gr.Interface(
48
  fn=classify_image,
49
  inputs=gr.Image(type="pil", label="上传图片"),
50
- outputs=[gr.Textbox(label="所有模型预测结果")],
51
  title="垃圾分类多模型检测",
52
- description="上传图片后,使用以下模型进行垃圾分类,每个模型结果单独输出:\n"
53
- "1. prithivMLmods/Trash-Net\n"
54
- "2. yangy50/garbage-classification\n"
55
- "3. eunoiawiira-vgg-realwaste-classification\n"
 
 
 
56
  )
57
 
58
  if __name__ == "__main__":
 
3
  from transformers import AutoModelForImageClassification, AutoImageProcessor
4
  from PIL import Image
5
  import torch
6
+ from collections import Counter
7
 
8
+ # ---------------- 模型列表 ----------------
9
  MODEL_LIST = [
10
  "prithivMLmods/Trash-Net",
11
  "yangy50/garbage-classification",
 
15
  # ---------------- 加载模型 ----------------
16
  models = []
17
  processors = []
18
+ devices = []
19
  print("🔹 正在加载模型,请稍等...")
20
+
21
  for model_name in MODEL_LIST:
22
  try:
23
  processor = AutoImageProcessor.from_pretrained(model_name)
24
+ model = AutoModelForImageClassification.from_pretrained(
25
+ model_name,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ device_map="auto" if torch.cuda.is_available() else None
28
+ )
29
  model.eval()
30
  processors.append(processor)
31
  models.append(model)
32
+ devices.append(next(model.parameters()).device)
33
  print(f"✅ 加载完成: {model_name}")
34
  except Exception as e:
35
  print(f"❌ 加载失败: {model_name}, 错误: {e}")
 
37
  # ---------------- 推理函数 ----------------
38
  def classify_image(image: Image.Image):
39
  results = {}
40
+ for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
41
+ try:
42
+ inputs = processor(images=image, return_tensors="pt")
43
+ inputs = {k: v.to(device) for k, v in inputs.items()}
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
  pred = outputs.logits.argmax(-1).item()
47
  label = model.config.id2label[pred]
48
  results[model_name] = label
49
+ except Exception as e:
50
+ results[model_name] = f"error: {e}"
51
 
52
+ # 格式化输出
53
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
54
 
55
+ # 计算最终标签(投票法)
56
+ valid_labels = [lbl for lbl in results.values() if not lbl.startswith("error")]
57
+ final_label = Counter(valid_labels).most_common(1)[0][0] if valid_labels else "Unknown"
58
+
59
+ results_text += f"\n\n最终标签: {final_label}"
60
  return results_text
61
 
62
  # ---------------- Gradio 界面 ----------------
63
  iface = gr.Interface(
64
  fn=classify_image,
65
  inputs=gr.Image(type="pil", label="上传图片"),
66
+ outputs=[gr.Textbox(label="模型预测结果")],
67
  title="垃圾分类多模型检测",
68
+ description=(
69
+ "上传图片后,使用以下模型进行垃圾分类,每个模型结果单独输出:\n"
70
+ "1. prithivMLmods/Trash-Net\n"
71
+ "2. yangy50/garbage-classification\n"
72
+ "3. eunoiawiira-vgg-realwaste-classification\n"
73
+ "最终标签通过投票法生成"
74
+ )
75
  )
76
 
77
  if __name__ == "__main__":