"""Gradio demo that classifies waste images with the Trash-Net SigLIP model."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification

# Load the Trash-Net model and its matching image preprocessor.
model_name = "prithivMLmods/Trash-Net"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Class names in the order of the model's output logits (index i -> LABELS[i]).
# NOTE(review): hard-coded as in the original script; confirm against
# model.config.id2label if the checkpoint changes.
LABELS = ("cardboard", "glass", "metal", "paper", "plastic", "trash")


def trash_classification(image):
    """Classify a waste image and return per-class probabilities.

    Args:
        image: The uploaded image as a NumPy array (as produced by
            ``gr.Image(type="numpy")``), or ``None`` when nothing is uploaded.

    Returns:
        A dict mapping each class name in ``LABELS`` to its softmax
        probability rounded to three decimals, or an empty dict when
        ``image`` is ``None``.

    Raises:
        ValueError: If the model yields a different number of class scores
            than there are entries in ``LABELS``.
    """
    if image is None:
        return {}

    # Ensure a 3-channel RGB image before preprocessing.
    pil_image = Image.fromarray(image).convert("RGB")

    # Convert to the tensor inputs the model expects.
    inputs = processor(images=pil_image, return_tensors="pt")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the single batch row explicitly rather than squeeze(), which would
    # collapse to a bare float if the model had only one output class.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    if len(probs) != len(LABELS):
        raise ValueError(
            f"Model returned {len(probs)} class scores but "
            f"{len(LABELS)} labels are defined"
        )

    # Probability for every class, keyed by label.
    return {label: round(p, 3) for label, p in zip(LABELS, probs)}


# Build the Gradio interface.
iface = gr.Interface(
    fn=trash_classification,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(label="Prediction Scores"),
    title="Trash Classification",
    description="Upload an image to classify the type of waste material.",
)

# Launch the app when run as a script.
if __name__ == "__main__":
    iface.launch()