mochuan zhan commited on
Commit
360f3af
·
1 Parent(s): a2ce9c9
Files changed (2) hide show
  1. app.py +15 -15
  2. vit_model.pth +1 -1
app.py CHANGED
@@ -69,13 +69,14 @@ class ViT(nn.Module):
69
  model = ViT(num_classes=10) # 确保num_classes与你的MNIST任务一致
70
  model_path = "vit_model.pth" # 模型权重文件名
71
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
72
-
73
 
74
  # 定义图像预处理
75
  transform = transforms.Compose([
76
- transforms.Resize((224, 224)), # 适应ViT的输入大小
 
77
  transforms.ToTensor(),
78
- transforms.Normalize((0.5,), (0.5,)) # 根据训练时的归一化参数调整
79
  ])
80
 
81
  # 定义预测函数
@@ -87,18 +88,15 @@ def classify_image(image):
87
  # 确保 image 是一个 PIL 图像
88
  if not isinstance(image, Image.Image):
89
  raise TypeError(f"Expected image to be PIL Image, but got {type(image)}")
 
90
 
91
- # 将图像转换为灰度模式
92
- image = image.convert("L")
93
-
94
- # 反转颜色
95
- image = ImageOps.invert(image)
96
-
97
- # 调整图像大小
98
- image = image.resize((224, 224))
99
-
100
  # 图像预处理
101
  img = transform(image).unsqueeze(0) # 添加批次维度
 
 
 
102
 
103
  # 模型预测
104
  with torch.no_grad():
@@ -106,10 +104,11 @@ def classify_image(image):
106
  probabilities = F.softmax(outputs, dim=1)
107
 
108
  # 获取预测结果
109
- _, predicted = torch.max(outputs, 1)
110
  confidence = probabilities[0][predicted].item()
111
 
112
  # 返回结果字典,包含预测类别和置信度
 
113
  return {str(predicted.item()): confidence}
114
 
115
  # # 创建Gradio界面
@@ -123,11 +122,12 @@ def classify_image(image):
123
 
124
  iface = gr.Interface(
125
  fn=classify_image,
126
- inputs=gr.Sketchpad(crop_size=(256,256), type='pil', image_mode='L', brush=gr.Brush()),
127
  outputs=gr.Label(num_top_classes=1),
128
  title="MNIST Digit Classification with ViT",
129
  description="Use the mouse to hand draw a number and the model will predict the category it belongs to."
130
  )
131
 
132
 
133
- iface.launch()
 
 
69
  model = ViT(num_classes=10) # 确保num_classes与你的MNIST任务一致
70
  model_path = "vit_model.pth" # 模型权重文件名
71
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
72
+ model.eval()
73
 
74
  # 定义图像预处理
75
  transform = transforms.Compose([
76
+ transforms.Grayscale(num_output_channels=1), # 转换为单通道
77
+ transforms.Resize((28, 28)),
78
  transforms.ToTensor(),
79
+ transforms.Normalize((0.1307,), (0.3081,))
80
  ])
81
 
82
  # 定义预测函数
 
88
  # 确保 image 是一个 PIL 图像
89
  if not isinstance(image, Image.Image):
90
  raise TypeError(f"Expected image to be PIL Image, but got {type(image)}")
91
+ # 打印image的数组
92
 
93
+ print(image)
94
+
 
 
 
 
 
 
 
95
  # 图像预处理
96
  img = transform(image).unsqueeze(0) # 添加批次维度
97
+
98
+ image_pil = Image.fromarray(img.numpy().squeeze() * 255).convert('L')
99
+ image_pil.show()
100
 
101
  # 模型预测
102
  with torch.no_grad():
 
104
  probabilities = F.softmax(outputs, dim=1)
105
 
106
  # 获取预测结果
107
+ _, predicted = torch.max(probabilities, 1)
108
  confidence = probabilities[0][predicted].item()
109
 
110
  # 返回结果字典,包含预测类别和置信度
111
+ print(predicted, confidence)
112
  return {str(predicted.item()): confidence}
113
 
114
  # # 创建Gradio界面
 
122
 
123
  iface = gr.Interface(
124
  fn=classify_image,
125
+ inputs=gr.Sketchpad(type='pil', image_mode='L', brush=gr.Brush(default_size=18), crop_size=(600, 600)),
126
  outputs=gr.Label(num_top_classes=1),
127
  title="MNIST Digit Classification with ViT",
128
  description="Use the mouse to hand draw a number and the model will predict the category it belongs to."
129
  )
130
 
131
 
132
+ if __name__ == "__main__":
133
+ iface.launch()
vit_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:223c6c32c2a9d4c274b09c35ef089b358ee7cf1729b9d939fca898db5765dcdb
3
  size 3248655
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4c96b402b7457e05bba3fbf9589f8ee20aaf1bb86a482d4b605ac289cafde68
3
  size 3248655