Spaces:
Runtime error
Runtime error
Commit
·
1db5c15
1
Parent(s):
3e73041
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,18 +3,18 @@ import gradio as gr
|
|
| 3 |
import datasets
|
| 4 |
import torch
|
| 5 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 6 |
-
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
| 7 |
|
| 8 |
dataset = datasets.load_dataset('beans')
|
| 9 |
|
| 10 |
extractor = AutoFeatureExtractor.from_pretrained("suresh-subramanian/beans-classification")
|
| 11 |
model = AutoModelForImageClassification.from_pretrained("suresh-subramanian/beans-classification")
|
| 12 |
-
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
| 13 |
|
| 14 |
labels = dataset['train'].features['labels'].names
|
| 15 |
|
| 16 |
def classify(im):
|
| 17 |
-
features =
|
| 18 |
with torch.no_grad():
|
| 19 |
logits = model(features["pixel_values"])[-1]
|
| 20 |
probability = torch.nn.functional.softmax(logits, dim=-1)
|
|
|
|
| 3 |
import datasets
|
| 4 |
import torch
|
| 5 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 6 |
+
# from transformers import ViTFeatureExtractor, ViTForImageClassification
|
| 7 |
|
| 8 |
dataset = datasets.load_dataset('beans')
|
| 9 |
|
| 10 |
extractor = AutoFeatureExtractor.from_pretrained("suresh-subramanian/beans-classification")
|
| 11 |
model = AutoModelForImageClassification.from_pretrained("suresh-subramanian/beans-classification")
|
| 12 |
+
# feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
| 13 |
|
| 14 |
labels = dataset['train'].features['labels'].names
|
| 15 |
|
| 16 |
def classify(im):
|
| 17 |
+
features = extractor(im, return_tensors='pt')
|
| 18 |
with torch.no_grad():
|
| 19 |
logits = model(features["pixel_values"])[-1]
|
| 20 |
probability = torch.nn.functional.softmax(logits, dim=-1)
|