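# Gradio app: fine-tune a torchvision Faster R-CNN for face-mask detection.
# Users upload two ZIP archives (train and val); each is expected to unpack to
# a top-level train/ or val/ folder containing images/ (.jpg or .png) and
# annotations/ (Pascal VOC XML) subfolders.
# Requires: gradio, torch, torchvision, pillow.
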
import os
import zipfile
import xml.etree.ElementTree as ET

import gradio as gr
import torch
import torchvision.models.detection
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset class
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        self.image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]

    def __len__(self):
        return len(self.image_files)
    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_w, orig_h = image.size
        image = image.resize(self.resize)
        annotation_path = os.path.join(
            self.annotations_dir,
            self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml"),
        )
        if not os.path.exists(annotation_path):
            return None, None
        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None
        # Boxes come in original-image coordinates; rescale them to match the
        # resized image, otherwise targets no longer line up with the pixels.
        scale_x = self.resize[0] / orig_w
        scale_y = self.resize[1] / orig_h
        boxes[:, [0, 2]] *= scale_x
        boxes[:, [1, 3]] *= scale_y
        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target
    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            # torchvision detection models reserve label 0 for background,
            # so map mask -> 1 and everything else -> 2 (no/incorrect mask).
            labels.append(1 if label == "mask" else 2)
        if not boxes or not labels:
            return None, None
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
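
# The parser above assumes Pascal VOC-style XML, roughly:
#
#   <annotation>
#     <object>
#       <name>mask</name>
#       <bndbox>
#         <xmin>79</xmin><ymin>105</ymin><xmax>109</xmax><ymax>142</ymax>
#       </bndbox>
#     </object>
#   </annotation>
#
# (Illustrative values; only <name> and the <bndbox> children are read.)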

def collate_fn(batch):
    # Drop samples whose annotation file was missing or empty.
    batch = [b for b in batch if b[0] is not None and b[1] is not None]
    if not batch:
        return [], []
    images, targets = zip(*batch)
    return list(images), list(targets)

def get_model(num_classes):
    # Start from COCO-pretrained weights and swap in a fresh box-predictor head.
    # (weights="DEFAULT" needs torchvision >= 0.13; older versions use pretrained=True.)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def extract_zip(zip_file, extract_to):
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

def train_model(train_zip, val_zip):
    extract_zip(train_zip, './data/train')
    extract_zip(val_zip, './data/val')
    transform = transforms.Compose([transforms.ToTensor()])
    # Paths assume each ZIP unpacks to a top-level train/ or val/ folder
    # containing images/ and annotations/ subfolders.
    train_dataset = FaceMaskDataset(
        images_dir='./data/train/train/images',
        annotations_dir='./data/train/train/annotations',
        transform=transform,
    )
    val_dataset = FaceMaskDataset(
        images_dir='./data/val/val/images',
        annotations_dir='./data/val/val/annotations',
        transform=transform,
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    # val_loader is currently unused; an evaluation pass could be added per epoch.
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)
    # Three classes: background (0), mask (1), no/incorrect mask (2).
    model = get_model(num_classes=3).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    for epoch in range(3):  # Few epochs for demo purposes
        model.train()
        total_loss = 0.0
        for images, targets in train_loader:
            if not images:  # Skip batches where every sample was filtered out
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
            # In train mode the model returns a dict of component losses.
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

    torch.save(model.state_dict(), "model.pth")
    return "Training completed. Model saved as model.pth"

# Gradio upload interface
iface = gr.Interface(
    fn=train_model,
    inputs=[
        # type="filepath" (Gradio 4.x) hands train_model plain path strings
        # that zipfile.ZipFile can open directly.
        gr.File(label="Upload Train ZIP", type="filepath"),
        gr.File(label="Upload Val ZIP", type="filepath"),
    ],
    outputs="text",
)
| if __name__ == "__main__": | |
| iface.launch() | |