# Spaces: Sleeping
import functools
import os
import xml.etree.ElementTree as ET

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Pick the compute device once at import time: CUDA when usable, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class FaceMaskDataset(Dataset):
    """Object-detection dataset pairing images with per-image XML annotations.

    Each image in ``images_dir`` is matched to ``<stem>.xml`` in
    ``annotations_dir``. The XML is expected to contain ``<object>`` elements
    with a ``<name>`` and a ``<bndbox>`` holding ``xmin``/``ymin``/``xmax``/
    ``ymax`` (the Pascal VOC layout).

    Samples without usable annotations are returned as ``(None, None)`` so a
    collate function can drop them.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing the XML annotation files.
        transform: Optional callable applied to the (resized) PIL image.
            NOTE(review): boxes are not passed through ``transform``, so it
            should not contain geometric augmentations.
        resize: ``(width, height)`` every image is resized to, or ``None`` to
            keep the original size. The resize does NOT preserve the aspect
            ratio; bounding boxes are rescaled by the same factors.
    """

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same file on every run and platform.
        self.image_files = sorted(os.listdir(images_dir))

    def __len__(self):
        """Return the number of files found in ``images_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``, or ``(None, None)`` to skip.

        ``target`` is a dict with ``'boxes'`` (float32, N x 4, in the
        coordinates of the returned image) and ``'labels'`` (int64, N).
        """
        file_name = self.image_files[idx]

        # Derive the annotation name from the stem so any image extension
        # (.jpg, .png, .JPG, .jpeg, ...) maps to "<stem>.xml".
        stem, _ = os.path.splitext(file_name)
        annotation_path = os.path.join(self.annotations_dir, stem + ".xml")

        # Check annotations before decoding the image: skipped samples then
        # cost no image I/O.
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {file_name}.")
            return None, None  # Return a tuple with None to skip the image/annotation pair

        try:
            boxes, labels = self.load_annotations(annotation_path)
        except ET.ParseError as err:
            # A malformed XML file is an invalid annotation: skip the sample
            # instead of aborting the whole data-loading loop.
            print(f"Warning: Annotation file {annotation_path} could not be parsed ({err}). Skipping image {file_name}.")
            return None, None
        if boxes is None or labels is None:
            return None, None  # Skip this item if annotations are invalid

        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert().
        with Image.open(os.path.join(self.images_dir, file_name)) as raw_image:
            image = raw_image.convert("RGB")

        if self.resize is not None:
            original_size = image.size  # PIL reports (width, height)
            image = image.resize(self.resize)
            # Keep the boxes aligned with the resized pixels.
            boxes = self._scale_boxes(boxes, original_size, self.resize)

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    @staticmethod
    def _scale_boxes(boxes, original_size, new_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` boxes between image sizes.

        Both sizes are ``(width, height)``. Returns a new tensor.
        """
        orig_w, orig_h = original_size
        new_w, new_h = new_size
        scale_x = new_w / orig_w
        scale_y = new_h / orig_h
        factors = boxes.new_tensor([scale_x, scale_y, scale_x, scale_y])
        return boxes * factors

    def load_annotations(self, annotation_path):
        """Parse one XML annotation file.

        Returns:
            ``(boxes, labels)`` where ``boxes`` is a float32 tensor of
            ``[xmin, ymin, xmax, ymax]`` rows in the coordinates written in
            the file and ``labels`` is an int64 tensor; ``(None, None)`` when
            the file contains no ``<object>`` elements.

        Raises:
            xml.etree.ElementTree.ParseError: If the file is not valid XML.
        """
        root = ET.parse(annotation_path).getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            boxes.append([
                float(bndbox.find('xmin').text),
                float(bndbox.find('ymin').text),
                float(bndbox.find('xmax').text),
                float(bndbox.find('ymax').text),
            ])
            # "mask" -> 1, every other class name -> 0.
            # NOTE(review): torchvision detection models treat label 0 as
            # background, so 0-labelled boxes may not be learned as a class;
            # confirm the intended mapping.
            labels.append(1 if label == "mask" else 0)

        if not boxes:
            return None, None  # If no boxes or labels are found, return None

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
# Collate function for DataLoader: drops skipped samples, then transposes
def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into ``(images, targets)``.

    Samples whose image or target is ``None`` (the dataset's "skip" marker)
    are discarded first. If nothing survives, the result is an empty tuple
    ``()`` rather than a pair, so callers must not unpack it blindly.
    """
    kept = []
    for sample in batch:
        if sample[0] is None or sample[1] is None:
            continue
        kept.append(sample)
    # zip(*kept) transposes [(img, tgt), ...] into (imgs...), (tgts...)
    return tuple(zip(*kept))
# Build the detector: COCO-pretrained Faster R-CNN with a fresh box-prediction head
def load_model(num_classes=2):
    """Create a Faster R-CNN (ResNet-50 FPN) detector and move it to ``device``.

    The backbone and region-proposal network start from torchvision's
    pretrained weights; the box predictor is replaced by a newly constructed
    ``FastRCNNPredictor``, so its weights are untrained. No fine-tuned
    checkpoint is loaded here.

    Args:
        num_classes: Number of output classes of the new head. torchvision
            detection models count the background as class 0, so the default
            of 2 means background plus one foreground class.
            NOTE(review): the dataset in this file labels "no_mask" as 0;
            confirm whether 3 classes (background, no_mask, mask) is intended.

    Returns:
        The model, placed on the module-level ``device``.
    """
    detection = torchvision.models.detection

    # Newer torchvision replaces the deprecated ``pretrained=True`` flag with
    # a weights enum; use it when present and fall back on older releases.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Swap the COCO classification head for one sized to this task.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

    model.to(device)
    return model
# Inference
@functools.lru_cache(maxsize=1)
def _get_inference_model():
    """Build the detector once, in eval mode, and reuse it for every request."""
    model = load_model()
    model.eval()
    return model


def infer(image):
    """Run the detector on one image.

    Args:
        image: Image as a NumPy array accepted by ``PIL.Image.fromarray``.

    Returns:
        ``(boxes, labels)`` as NumPy arrays: ``boxes`` holds
        ``[xmin, ymin, xmax, ymax]`` rows expressed in the pixel coordinates
        of the input ``image``; ``labels`` holds the predicted class ids.
    """
    # Constructing the model per call is slow and, because the box head is
    # freshly initialised, would give different outputs on every request.
    model = _get_inference_model()

    # Force three channels so grayscale / RGBA inputs do not break the model.
    pil_image = Image.fromarray(image).convert("RGB")
    orig_w, orig_h = pil_image.size  # PIL reports (width, height)

    # Resize takes (height, width)
    net_h, net_w = 224, 224
    transform = transforms.Compose([
        transforms.Resize((net_h, net_w)),
        transforms.ToTensor(),
    ])
    batch = transform(pil_image).unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        prediction = model(batch)
    detections = prediction[0]

    # The model reports boxes in the coordinates of the resized tensor; map
    # them back to the image that was actually supplied.
    scale_x = orig_w / net_w
    scale_y = orig_h / net_h
    to_original = torch.tensor([scale_x, scale_y, scale_x, scale_y], dtype=torch.float32)
    boxes = (detections['boxes'].cpu() * to_original).numpy()
    labels = detections['labels'].cpu().numpy()
    return boxes, labels
# Gradio callback
def gradio_interface(image):
    """Run detection on an uploaded image and return a JSON-friendly dict.

    Args:
        image: The uploaded image as a NumPy array, or ``None`` when the form
            is submitted without an image.

    Returns:
        ``{"boxes": [...], "labels": [...]}`` built from plain Python lists.
        Both lists are empty when no image was supplied. Labels follow the
        dataset convention in this file: 0 = no mask, 1 = mask.
    """
    if image is None:
        # Nothing to analyse; answer with an empty result instead of crashing.
        return {"boxes": [], "labels": []}

    boxes, labels = infer(image)

    def as_list(values):
        # Arrays/tensors expose tolist(); plain sequences are copied as-is.
        # Plain lists serialise with any JSON encoder.
        return values.tolist() if hasattr(values, "tolist") else list(values)

    return {"boxes": as_list(boxes), "labels": as_list(labels)}
# Create Gradio interface: a NumPy image in, the detection dict out as JSON
iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")

# Launch only when executed as a script, so importing this module (e.g. to
# reuse the dataset or model code) does not start a web server.
if __name__ == "__main__":
    iface.launch()