TangYiJay commited on
Commit
e4edfab
·
verified ·
1 Parent(s): 94a8d5e
Files changed (1) hide show
  1. app.py +23 -39
app.py CHANGED
@@ -1,57 +1,41 @@
1
- # app.py
2
  import gradio as gr
3
  from transformers import AutoModelForImageClassification, AutoImageProcessor
4
  from PIL import Image
5
  import torch
6
 
7
- # ---------------- 模型列表 ----------------
8
- MODEL_LIST = [
9
  "yangy50/garbage-classification",
10
- "harriskr14/trashnet-vit",
11
- "prithivMLmods/Trash-Net"
12
  ]
13
 
14
- # ---------------- 加载模型 ----------------
15
- models = []
16
  processors = []
17
- loaded_names = []
18
-
19
- print("🔹 正在加载模型,请稍等...")
20
- for model_name in MODEL_LIST:
21
- try:
22
- processor = AutoImageProcessor.from_pretrained(model_name)
23
- model = AutoModelForImageClassification.from_pretrained(model_name)
24
- model.eval()
25
- processors.append(processor)
26
- models.append(model)
27
- loaded_names.append(model_name)
28
- print(f"✅ 加载成功: {model_name}")
29
- except Exception as e:
30
- print(f"❌ 加载失败: {model_name}, 错误: {e}")
31
 
32
- # ---------------- 推理函数 ----------------
33
  def classify_image(image: Image.Image):
34
- results = []
35
- for name, processor, model in zip(loaded_names, processors, models):
36
- try:
37
- inputs = processor(images=image, return_tensors="pt")
38
- with torch.no_grad():
39
- outputs = model(**inputs)
40
- pred = outputs.logits.argmax(-1).item()
41
- label = model.config.id2label.get(pred, str(pred))
42
- results.append(f"{name}: {label}")
43
- except Exception as e:
44
- results.append(f"{name}: ❌ 预测失败 ({e})")
45
-
46
- return "\n".join(results)
47
 
48
- # ---------------- Gradio 界面 ----------------
49
  iface = gr.Interface(
50
  fn=classify_image,
51
  inputs=gr.Image(type="pil", label="上传图片"),
52
- outputs=gr.Textbox(label="模型预测结果", lines=6),
53
- title="多模型垃圾分类",
54
- description="使用以下模型进行独立预测:yangy50、harriskr14、prithivMLmods。"
55
  )
56
 
57
  if __name__ == "__main__":
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForImageClassification, AutoImageProcessor
3
  from PIL import Image
4
  import torch
5
 
6
+ # 三个模型
7
+ MODELS = [
8
  "yangy50/garbage-classification",
9
+ "ahmzakif/TrashNet-Classification",
10
+ "harriskr14/trashnet-vit"
11
  ]
12
 
 
 
13
  processors = []
14
+ models = []
15
+ for name in MODELS:
16
+ p = AutoImageProcessor.from_pretrained(name)
17
+ m = AutoModelForImageClassification.from_pretrained(name)
18
+ m.eval()
19
+ processors.append(p)
20
+ models.append(m)
 
 
 
 
 
 
 
21
 
 
22
  def classify_image(image: Image.Image):
23
+ results = {}
24
+ for name, p, m in zip(MODELS, processors, models):
25
+ inputs = p(images=image, return_tensors="pt")
26
+ with torch.no_grad():
27
+ outputs = m(**inputs)
28
+ pred = outputs.logits.argmax(-1).item()
29
+ label = m.config.id2label.get(pred, f"id_{pred}")
30
+ results[name] = label
31
+ return "\n".join([f"{k}: {v}" for k, v in results.items()])
 
 
 
 
32
 
 
33
  iface = gr.Interface(
34
  fn=classify_image,
35
  inputs=gr.Image(type="pil", label="上传图片"),
36
+ outputs="text",
37
+ title="三模型垃圾分类",
38
+ description="使用三个模型独立预测垃圾种类"
39
  )
40
 
41
  if __name__ == "__main__":