# 3. model_utils.py
# Model management (loading, prediction, and species information)
import logging
import threading

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from dataset_utils import DatasetHandler
class BugClassifier:
    """Classify bug photos with a ViT model and look up species descriptions.

    Species descriptions are loaded on a background thread started from
    ``__init__``. Until that load finishes (or if it fails),
    :meth:`get_species_info` returns its fallback message.
    """

    # Predictions whose top-class probability (in percent) is below this value
    # are reported as "Unknown Insect".
    UNKNOWN_THRESHOLD = 30

    _log = logging.getLogger(__name__)

    def __init__(self, model_path="google/vit-base-patch16-224",
                 processor_path="google/vit-base-patch16-224"):
        """Load the model and image processor, then start loading descriptions.

        Args:
            model_path: Model id or local path passed to
                ``ViTForImageClassification.from_pretrained``.
            processor_path: Model id or local path passed to
                ``ViTImageProcessor.from_pretrained``. It should match the
                preprocessing the model was trained with.
        """
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.eval()
        # Load the processor once here rather than on every predict() call.
        self.processor = ViTImageProcessor.from_pretrained(processor_path)
        self.labels = [
            "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
            "Japanese Beetle", "Garden Spider", "Green Grasshopper",
            "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
        ]
        # NOTE(review): the default checkpoint is an ImageNet classifier, so
        # its output indices do not correspond to self.labels. A model
        # fine-tuned on these 10 species is needed for meaningful labels.
        num_labels = getattr(getattr(self.model, "config", None), "num_labels", None)
        if num_labels is not None and num_labels != len(self.labels):
            self._log.warning(
                "Model has %d output classes but %d labels are defined; "
                "predictions outside the label list are reported as "
                "'Unknown Insect'.",
                num_labels, len(self.labels),
            )
        self.species_descriptions = {}
        self._descriptions_thread = None
        self.load_species_descriptions()

    def load_species_descriptions(self):
        """Start loading species descriptions on a background thread.

        ``self.species_descriptions`` is replaced when the load completes. If
        loading fails, the error is logged and the previous value is kept.

        Returns:
            The started ``threading.Thread``, so callers can ``join()`` it if
            they need the descriptions to be available.
        """
        def _load():
            try:
                handler = DatasetHandler()
                descriptions = handler.load_descriptions(max_records=500)
            except Exception:
                self._log.exception("Failed to load species descriptions")
                return
            # Guard against a None result, which would break get_species_info.
            if descriptions is not None:
                self.species_descriptions = descriptions

        thread = threading.Thread(target=_load, name="species-descriptions-loader")
        thread.start()
        self._descriptions_thread = thread
        return thread

    def predict(self, image):
        """Classify a single image.

        Args:
            image: The image handed to the processor (e.g. a PIL image).

        Returns:
            A ``(label, confidence)`` tuple, where ``confidence`` is the top
            softmax probability as a percentage. The label is
            "Unknown Insect" when confidence is below ``UNKNOWN_THRESHOLD`` or
            the predicted class has no entry in ``self.labels``.
            ``("Error Processing Image", 0.0)`` is returned when ``image`` is
            None or inference fails; this method does not raise.
        """
        if image is None:
            return "Error Processing Image", 0.0
        try:
            # PIL images in RGBA/grayscale mode must be RGB for 3-channel
            # normalisation.
            if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            confidence, predicted_idx = probabilities.max(dim=1)
            confidence = confidence.item() * 100
            index = predicted_idx.item()
        except Exception:
            # UI boundary: keep returning a sentinel, but record the real cause.
            self._log.exception("Failed to classify image")
            return "Error Processing Image", 0.0

        if confidence < self.UNKNOWN_THRESHOLD or not 0 <= index < len(self.labels):
            return "Unknown Insect", confidence
        return self.labels[index], confidence

    def get_species_info(self, species):
        """Return the description for ``species``, or a fallback message.

        The fallback is also returned while the background description load
        is still running.
        """
        return self.species_descriptions.get(
            species, "Information not available. Consider updating your dataset for this species."
        )