TangYiJay commited on
Commit
04b114a
·
verified ·
1 Parent(s): 8ee0f62
Files changed (1) hide show
  1. app.py +37 -43
app.py CHANGED
@@ -1,15 +1,30 @@
1
  # app.py
2
- import gradio as gr
3
- from transformers import AutoModelForImageClassification, AutoImageProcessor, AutoFeatureExtractor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
  import torch
6
- from torchvision import transforms
7
 
8
- # ---------------- 配置模型 ----------------
9
  MODEL_LIST = [
10
  "yangy50/garbage-classification",
11
- "harriskr14/trashnet-vit",
12
- "ahmzakif/TrashNet-Classification"
13
  ]
14
 
15
  # ---------------- 加载模型 ----------------
@@ -18,55 +33,34 @@ processors = []
18
  loaded_model_names = []
19
 
20
  print("🔹 正在加载模型,请稍等...")
21
-
22
- for name in MODEL_LIST:
23
  try:
24
- model = AutoModelForImageClassification.from_pretrained(name)
25
- # 尝试加载 AutoImageProcessor
26
- try:
27
- processor = AutoImageProcessor.from_pretrained(name)
28
- except OSError:
29
- # 尝试 AutoFeatureExtractor
30
- try:
31
- processor = AutoFeatureExtractor.from_pretrained(name)
32
- except OSError:
33
- processor = None
34
- print(f"⚠️ 模型 {name} 没有自带处理器,将使用默认 transforms")
35
  model.eval()
36
- models.append(model)
37
  processors.append(processor)
38
- loaded_model_names.append(name)
39
- print(f"✅ 加载成功: {name}")
 
40
  except Exception as e:
41
- print(f"❌ 加载失败: {name}, 错误: {e}")
42
-
43
- # ---------------- 默认图片处理 ----------------
44
- default_preprocess = transforms.Compose([
45
- transforms.Resize((224, 224)),
46
- transforms.ToTensor(),
47
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
48
- std=[0.229, 0.224, 0.225])
49
- ])
50
 
51
  # ---------------- 推理函数 ----------------
52
  def classify_image(image: Image.Image):
53
  results = {}
54
- for name, model, processor in zip(loaded_model_names, models, processors):
55
  try:
56
- if processor:
57
- inputs = processor(images=image, return_tensors="pt")
58
- else:
59
- inputs = default_preprocess(image).unsqueeze(0)
60
  with torch.no_grad():
61
  outputs = model(**inputs)
62
- pred_idx = outputs.logits.argmax(-1).item()
63
  if hasattr(model.config, "id2label"):
64
- label = model.config.id2label[pred_idx]
65
  else:
66
- label = f"索引预测: {pred_idx}"
67
- results[name] = label
68
  except Exception as e:
69
- results[name] = f"❌ 预测失败: {e}"
70
 
71
  # 格式化输出
72
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
@@ -76,9 +70,9 @@ def classify_image(image: Image.Image):
76
  iface = gr.Interface(
77
  fn=classify_image,
78
  inputs=gr.Image(type="pil", label="上传图片"),
79
- outputs=[gr.Textbox(label="模型预测结果")],
80
  title="垃圾分类模型检测",
81
- description="上传图片,每个模型独立预测,输出所有结果。"
82
  )
83
 
84
  if __name__ == "__main__":
 
1
  # app.py
2
+ import subprocess
3
+ import sys
4
+
5
+ # ---------------- 自动安装缺失依赖 ----------------
6
+ def install(package):
7
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
8
+
9
+ for pkg in ["torch", "torchvision", "transformers", "gradio", "Pillow"]:
10
+ try:
11
+ __import__(pkg if pkg != "Pillow" else "PIL")
12
+ except ModuleNotFoundError:
13
+ print(f"⚠️ 未找到 {pkg},正在自动安装...")
14
+ install(pkg)
15
+ print(f"✅ {pkg} 安装完成")
16
+
17
+ # ---------------- 导入库 ----------------
18
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
19
  from PIL import Image
20
  import torch
21
+ import gradio as gr
22
 
23
+ # ---------------- 模型列表 ----------------
24
  MODEL_LIST = [
25
  "yangy50/garbage-classification",
26
+ "ahmzakif/TrashNet-Classification",
27
+ "harriskr14/trashnet-vit"
28
  ]
29
 
30
  # ---------------- 加载模型 ----------------
 
33
  loaded_model_names = []
34
 
35
  print("🔹 正在加载模型,请稍等...")
36
+ for model_name in MODEL_LIST:
 
37
  try:
38
+ processor = AutoImageProcessor.from_pretrained(model_name)
39
+ model = AutoModelForImageClassification.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
40
  model.eval()
 
41
  processors.append(processor)
42
+ models.append(model)
43
+ loaded_model_names.append(model_name)
44
+ print(f"✅ 加载成功: {model_name}")
45
  except Exception as e:
46
+ print(f"❌ 加载失败: {model_name}, 错误: {e}")
 
 
 
 
 
 
 
 
47
 
48
  # ---------------- 推理函数 ----------------
49
  def classify_image(image: Image.Image):
50
  results = {}
51
+ for model_name, processor, model in zip(loaded_model_names, processors, models):
52
  try:
53
+ inputs = processor(images=image, return_tensors="pt")
 
 
 
54
  with torch.no_grad():
55
  outputs = model(**inputs)
56
+ pred = outputs.logits.argmax(-1).item()
57
  if hasattr(model.config, "id2label"):
58
+ label = model.config.id2label[pred]
59
  else:
60
+ label = f"⚠️ 无内置 id2label,索引预测: {pred}"
61
+ results[model_name] = label
62
  except Exception as e:
63
+ results[model_name] = f"❌ 预测失败: {e}"
64
 
65
  # 格式化输出
66
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
 
70
  iface = gr.Interface(
71
  fn=classify_image,
72
  inputs=gr.Image(type="pil", label="上传图片"),
73
+ outputs=[gr.Textbox(label="所有模型预测结果")],
74
  title="垃圾分类模型检测",
75
+ description="上传图片后,每个模型独立输出预测结果,不做任何人工干预。"
76
  )
77
 
78
  if __name__ == "__main__":